diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..1e4e193d8 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,10 @@ +root = true + +[*.{c,h}] +indent_style = tab +indent_size = 8 +tab_width = 8 + +[*.proto] +indent_style = space +indent_size = 2 \ No newline at end of file diff --git a/api/agent.h b/api/agent.h index 7398e32ee..a9e8c92f2 100644 --- a/api/agent.h +++ b/api/agent.h @@ -212,3 +212,9 @@ int agent_storage_put( struct agent *agent, const char *name, void *data, size_t size ); + +void * +agent_alloc(struct agent *agent, size_t size); + +void +agent_free(struct agent *agent, void *ptr, size_t size); \ No newline at end of file diff --git a/balancer.yaml b/balancer.yaml index 5b41a410f..fbab1b2a4 100644 --- a/balancer.yaml +++ b/balancer.yaml @@ -10,7 +10,7 @@ packet_handler: - addr: "192.0.2.1" port: 80 proto: TCP # Accepts: TCP, tcp - scheduler: SOURCE_HASH # Accepts: SOURCE_HASH, source_hash, SH, sh + scheduler: sh flags: gre: false fix_mss: true @@ -28,11 +28,6 @@ packet_handler: weight: 100 src_addr: "192.0.2.1" src_mask: "255.255.255.255" - - ip: "10.1.1.2" - port: 0 - weight: 50 - src_addr: "192.0.2.1" - src_mask: "255.255.255.255" peers: - "192.0.2.10" - "192.0.2.11" @@ -42,7 +37,7 @@ packet_handler: - addr: "192.0.2.2" port: 443 proto: tcp # Lowercase also accepted - scheduler: round_robin # Accepts: ROUND_ROBIN, round_robin, RR, rr + scheduler: wrr flags: gre: false fix_mss: true @@ -76,7 +71,7 @@ packet_handler: - addr: "192.0.2.4" port: 53 proto: UDP - scheduler: SH # Short form + scheduler: sh flags: gre: false fix_mss: false @@ -104,7 +99,7 @@ packet_handler: - addr: "192.0.2.3" port: 0 # MUST be 0 for pure_l3 mode proto: tcp - scheduler: sh # Short form lowercase + scheduler: wrr flags: gre: false fix_mss: false @@ -126,7 +121,7 @@ packet_handler: - addr: "2001:db8::1" port: 443 proto: TCP - scheduler: RR + scheduler: wrr flags: gre: false fix_mss: true @@ -160,7 +155,6 @@ packet_handler: tcp_fin: 10 tcp: 60 udp: 30 - default: 60 # State management configuration state: diff --git a/cli/modules/counters/src/main.rs b/cli/modules/counters/src/main.rs index fcb004fd1..70888dc96 100644 --- a/cli/modules/counters/src/main.rs +++ b/cli/modules/counters/src/main.rs @@ -448,16 +448,8 @@ fn format_batch_counter(counter: &PerfCounter, next_min_batch: Option, widt let total_bytes = counter.bytes; // Average latency per packet and per batch - let avg_latency_per_packet = if total_packets > 0 { - counter.summary_latency / total_packets - } else { - 0 - }; - let avg_latency_per_batch = if total_batches > 0 { - counter.summary_latency / total_batches - } else { - 0 - }; + let avg_latency_per_packet = counter.summary_latency.checked_div(total_packets).unwrap_or(0); + let avg_latency_per_batch = counter.summary_latency.checked_div(total_batches).unwrap_or(0); // Line 1: Total batches, packets, and bytes let total_content = format!( diff --git a/common/btree/u32.h b/common/btree/u32.h index 85e4ea277..3d17d1e8a 100644 --- a/common/btree/u32.h +++ b/common/btree/u32.h @@ -294,7 +294,7 @@ btree_u32_init( // Initialize all blocks with maximum value (last element) // This ensures sentinel values for incomplete blocks if (n > 0) { - uint32_t max_val = data[n - 1]; + uint32_t max_val = data[n - 1] ^ 0x80000000; size_t total_values = nblocks * BTREE_U32_BLOCK_SIZE; for (size_t i = 0; i < total_values; ++i) { uint32_t *ptr = (uint32_t *)big_array_get( diff --git a/common/btree/u64.h b/common/btree/u64.h index 072f13916..5df697402 100644 --- a/common/btree/u64.h +++ b/common/btree/u64.h @@ -5,7 +5,6 @@ #include #include #include -#include #include /** @@ -304,7 +303,7 @@ btree_u64_init( // Initialize all blocks with maximum value (last element) // This ensures sentinel values for incomplete blocks if (n > 0) { - uint64_t max_val = data[n - 1]; + uint64_t max_val = data[n - 1] ^ 0x8000000000000000ULL; size_t total_values = nblocks * BTREE_U64_BLOCK_SIZE; for (size_t i = 0; i < total_values; ++i) { uint64_t *ptr = (uint64_t *)big_array_get( diff --git a/common/go/relptr/relptr.go b/common/go/relptr/relptr.go new file mode 100644 index 000000000..a3a413d85 --- /dev/null +++ b/common/go/relptr/relptr.go @@ -0,0 +1,78 @@ +package relptr + +import "unsafe" + +// Deref resolves a relative pointer field to *T. +// +// A relative pointer P stores the offset from its own address to the target: +// +// target = &P + P (when P != 0) +// target = nil (when P == 0) +// +// This mirrors the C macro ADDR_OF from common/memory_address.h. +// The field parameter must be a pointer to a pointer field (e.g. &vs.Reals). +func Deref[T any](field **T) *T { + raw := (*uintptr)(unsafe.Pointer(field)) + offset := *raw + if offset == 0 { + return nil + } + return (*T)(unsafe.Pointer(offset + uintptr(unsafe.Pointer(field)))) +} + +// Slice resolves a relative pointer field to a []T slice over contiguous +// memory. The returned slice shares the underlying memory (zero-copy). +func Slice[T any](field **T, count uint32) []T { + if count == 0 { + return nil + } + ptr := Deref(field) + if ptr == nil { + return nil + } + return unsafe.Slice(ptr, count) +} + +// Set makes a relative pointer field point to ptr. +// +// This mirrors the C macro SET_OFFSET_OF from common/memory_address.h. +func Set[T any](field **T, ptr *T) { + raw := (*uintptr)(unsafe.Pointer(field)) + if ptr == nil { + *raw = 0 + return + } + *raw = uintptr(unsafe.Pointer(ptr)) - uintptr(unsafe.Pointer(field)) +} + +func Equate[T any](dst **T, src **T) { + Set(dst, Deref(src)) +} + +func SetSlice[T any](field **T, ptr []T) { + if len(ptr) == 0 { + Set(field, nil) + return + } + Set(field, &ptr[0]) +} + +// DerefOpaque resolves a relative pointer stored in a uintptr field. +// Use this for opaque pointer fields (e.g. filter, selector) where +// the concrete type is not known to Go. +func DerefOpaque(field *uintptr) unsafe.Pointer { + offset := *field + if offset == 0 { + return nil + } + return unsafe.Pointer(offset + uintptr(unsafe.Pointer(field))) +} + +// SetOpaque makes a uintptr relative pointer field point to addr. +func SetOpaque(field *uintptr, addr unsafe.Pointer) { + if addr == nil { + *field = 0 + return + } + *field = uintptr(addr) - uintptr(unsafe.Pointer(field)) +} diff --git a/common/go/relptr/relptr_test.go b/common/go/relptr/relptr_test.go new file mode 100644 index 000000000..bba27dca0 --- /dev/null +++ b/common/go/relptr/relptr_test.go @@ -0,0 +1,198 @@ +package relptr + +import ( + "testing" + "unsafe" +) + +type testStruct struct { + Value uint64 +} + +func TestDerefAndSet(t *testing.T) { + var field *testStruct + target := testStruct{Value: 42} + + Set(&field, &target) + + got := Deref(&field) + if got == nil { + t.Fatal("Deref returned nil") + } + if got != &target { + t.Fatalf("Deref returned %p, want %p", got, &target) + } + if got.Value != 42 { + t.Fatalf("Value = %d, want 42", got.Value) + } +} + +func TestDerefNil(t *testing.T) { + var field *testStruct // zero bytes = null relative pointer + + got := Deref(&field) + if got != nil { + t.Fatalf("Deref of zero field returned %p, want nil", got) + } +} + +func TestSetNil(t *testing.T) { + var field *testStruct + target := testStruct{Value: 1} + + Set(&field, &target) + raw := *(*uintptr)(unsafe.Pointer(&field)) + if raw == 0 { + t.Fatal("field should be non-zero after Set") + } + + Set(&field, nil) + raw = *(*uintptr)(unsafe.Pointer(&field)) + if raw != 0 { + t.Fatalf("field = %d, want 0 after Set(nil)", raw) + } +} + +func TestSlice(t *testing.T) { + items := [4]testStruct{ + {Value: 10}, + {Value: 20}, + {Value: 30}, + {Value: 40}, + } + + var field *testStruct + Set(&field, &items[0]) + + got := Slice(&field, 4) + if len(got) != 4 { + t.Fatalf("len = %d, want 4", len(got)) + } + for i, v := range got { + want := uint64((i + 1) * 10) + if v.Value != want { + t.Errorf("got[%d].Value = %d, want %d", i, v.Value, want) + } + } + + // Verify zero-copy: mutate through slice, read from original. + got[2].Value = 99 + if items[2].Value != 99 { + t.Fatal("slice is not zero-copy") + } +} + +func TestSliceZeroCount(t *testing.T) { + var field *testStruct + target := testStruct{Value: 1} + Set(&field, &target) + + got := Slice(&field, 0) + if got != nil { + t.Fatalf("Slice with count=0 returned non-nil") + } +} + +func TestSliceNilField(t *testing.T) { + var field *testStruct + + got := Slice(&field, 5) + if got != nil { + t.Fatalf("Slice of nil field returned non-nil") + } +} + +func TestDerefOpaque(t *testing.T) { + var field uintptr + target := testStruct{Value: 77} + + SetOpaque(&field, unsafe.Pointer(&target)) + + got := DerefOpaque(&field) + if got == nil { + t.Fatal("DerefOpaque returned nil") + } + if (*testStruct)(got).Value != 77 { + t.Fatal("wrong value") + } +} + +func TestDerefOpaqueNil(t *testing.T) { + var field uintptr + got := DerefOpaque(&field) + if got != nil { + t.Fatal("expected nil") + } +} + +// TestRelativePointerInStruct simulates a C-like struct with a relative +// pointer field, similar to how balancer_vs.reals works in shared memory. +func TestRelativePointerInStruct(t *testing.T) { + type container struct { + Items *uint64 + Count uint32 + _ uint32 + Metadata uint64 + } + + items := [3]uint64{100, 200, 300} + + var c container + c.Count = 3 + c.Metadata = 0xDEAD + + Set(&c.Items, &items[0]) + + // Resolve and iterate. + slice := Slice(&c.Items, c.Count) + if len(slice) != 3 { + t.Fatalf("len = %d, want 3", len(slice)) + } + for i, v := range slice { + want := uint64((i + 1) * 100) + if v != want { + t.Errorf("slice[%d] = %d, want %d", i, v, want) + } + } + + // Ensure container fields are not corrupted. + if c.Metadata != 0xDEAD { + t.Fatalf("Metadata corrupted: %x", c.Metadata) + } +} + +// TestFieldEmbeddedInContiguousMemory allocates a flat byte buffer and +// uses relative pointers between regions, mimicking shared memory layout. +func TestFieldEmbeddedInContiguousMemory(t *testing.T) { + type header struct { + DataPtr *uint64 + Len uint32 + _ uint32 + } + + // Simulate a contiguous shared memory region. + buf := make([]byte, 256) + + hdr := (*header)(unsafe.Pointer(&buf[0])) + hdr.Len = 4 + + // Place data at offset 64 within the same buffer. + dataStart := (*uint64)(unsafe.Pointer(&buf[64])) + data := unsafe.Slice(dataStart, 4) + data[0] = 1 + data[1] = 2 + data[2] = 3 + data[3] = 4 + + Set(&hdr.DataPtr, dataStart) + + got := Slice(&hdr.DataPtr, hdr.Len) + if len(got) != 4 { + t.Fatalf("len = %d, want 4", len(got)) + } + for i, v := range got { + if v != uint64(i+1) { + t.Errorf("got[%d] = %d, want %d", i, v, i+1) + } + } +} diff --git a/common/rcu.h b/common/rcu.h index 3830c7c9e..0be72b177 100644 --- a/common/rcu.h +++ b/common/rcu.h @@ -180,10 +180,7 @@ typedef struct { * or 1) This packs both fields into a single atomic to reduce cache * traffic */ atomic_uint state; - - /** Padding to cache line size (64 bytes) to prevent false sharing */ - uint8_t pad[64 - sizeof(atomic_uint)]; -} rcu_worker_t; +} __attribute__((aligned(64))) rcu_worker_t; // Bit positions in the packed state field #define RCU_STATE_ACTIVE_BIT 0 @@ -201,11 +198,11 @@ typedef struct { * cache-line alignment of worker states. */ typedef struct { - /** Global epoch counter (0 or 1), flipped during updates */ - atomic_uint global_epoch; - /** Per-worker state array, one entry per worker thread */ rcu_worker_t workers[RCU_WORKERS]; + + /** Global epoch counter (0 or 1), flipped during updates */ + atomic_uint global_epoch; } rcu_t; /** diff --git a/common/ttlmap/detail/iter.h b/common/ttlmap/detail/iter.h index ec8b49ffa..914163983 100644 --- a/common/ttlmap/detail/iter.h +++ b/common/ttlmap/detail/iter.h @@ -31,3 +31,18 @@ } \ __ret; \ }) + +#define __TTLMAP_ITER_NEXT_BUCKET( \ + map_ptr, bucket_idx, key_type, value_type, now, cb, data \ +) \ + __extension__({ \ + __TTLMAP_BUCKET_ITER( \ + map_ptr, \ + (bucket_idx), \ + key_type, \ + value_type, \ + now, \ + cb, \ + data \ + ); \ + }) diff --git a/common/ttlmap/ttlmap.h b/common/ttlmap/ttlmap.h index 1f9554990..41eb3ce5c 100644 --- a/common/ttlmap/ttlmap.h +++ b/common/ttlmap/ttlmap.h @@ -2,7 +2,6 @@ #include -#include "common/memory.h" #include "detail/bucket.h" #include "detail/iter.h" #include "detail/lock.h" @@ -43,6 +42,39 @@ typedef struct ttlmap ttlmap_t; #define TTLMAP_PREFETCH(map_ptr, key_ptr, value_type, ...) \ __TTLMAP_PREFETCH(map_ptr, key_ptr, value_type, ##__VA_ARGS__) +struct ttlmap_bucket_iter { + size_t buckets; + size_t next_bucket; + struct ttlmap *map; +}; + +static inline void +ttlmap_bucket_iter_init(struct ttlmap_bucket_iter *iter, struct ttlmap *map) { + iter->map = map; + iter->next_bucket = 0; + iter->buckets = + map->buckets_exp == (size_t)-1 ? 0 : 1ull << map->buckets_exp; +} + +#define TTLMAP_ITER_NEXT(iter_ptr, key_type, value_type, now, cb, data) \ + __extension__({ \ + int __ret = 1; \ + if ((iter_ptr)->next_bucket == (iter_ptr)->buckets) { \ + __ret = 0; \ + } else { \ + __TTLMAP_ITER_NEXT_BUCKET( \ + (iter_ptr)->map, \ + (iter_ptr)->next_bucket++, \ + key_type, \ + value_type, \ + now, \ + cb, \ + data \ + ); \ + } \ + __ret; \ + }) + //////////////////////////////////////////////////////////////////////////////// static inline void diff --git a/controlplane/ffi/agent.go b/controlplane/ffi/agent.go index b0ffb8b0b..534da5ca9 100644 --- a/controlplane/ffi/agent.go +++ b/controlplane/ffi/agent.go @@ -20,9 +20,12 @@ package ffi //#include "api/agent.h" //#include "controlplane/agent/agent.h" import "C" + import ( "fmt" "unsafe" + + "github.com/c2h5oh/datasize" ) // ModuleConfig is a Go wrapper around a C cp_module pointer, representing a @@ -376,3 +379,60 @@ func (m *Agent) DeleteModuleConfig(configName string) error { } return nil } + +// Alloc allocates memory for a single value of type T and returns a pointer to it. +func Alloc[T any](m *Agent) *T { + var zero T + size := unsafe.Sizeof(zero) + ptr := m.AllocRaw(datasize.ByteSize(size)) + if ptr == nil { + return nil + } + return (*T)(ptr) +} + +// Free frees memory for a single value of type T. +func Free[T any](m *Agent, ptr *T) { + var zero T + size := unsafe.Sizeof(zero) + m.FreeRaw(unsafe.Pointer(ptr), datasize.ByteSize(size)) +} + +// AllocSlice allocates a contiguous block of memory for `count` elements of type T +// and returns it as a Go slice (backed by the allocated memory). +func AllocSlice[T any](m *Agent, count int) []T { + if count == 0 { + return []T{} + } + + var zero T + elemSize := unsafe.Sizeof(zero) + totalSize := elemSize * uintptr(count) + ptr := m.AllocRaw(datasize.ByteSize(totalSize)) + if ptr == nil { + return nil + } + return unsafe.Slice((*T)(ptr), count) +} + +// FreeSlice frees memory previously allocated with AllocSlice. +func FreeSlice[T any](m *Agent, s []T) { + if cap(s) == 0 { + return + } + var zero T + elemSize := unsafe.Sizeof(zero) + totalSize := elemSize * uintptr(cap(s)) + ptr := unsafe.Pointer(unsafe.SliceData(s)) + m.FreeRaw(ptr, datasize.ByteSize(totalSize)) +} + +// AllocRaw still exposes the raw byte-size allocation if needed. +func (m *Agent) AllocRaw(size datasize.ByteSize) unsafe.Pointer { + return C.agent_alloc(m.ptr, C.size_t(size)) +} + +// FreeRaw still exposes the raw byte-size free if needed. +func (m *Agent) FreeRaw(ptr unsafe.Pointer, size datasize.ByteSize) { + C.agent_free(m.ptr, ptr, C.size_t(size)) +} diff --git a/controlplane/ffi/shm.go b/controlplane/ffi/shm.go index 15a3dc799..f4b18b136 100644 --- a/controlplane/ffi/shm.go +++ b/controlplane/ffi/shm.go @@ -13,6 +13,7 @@ import "C" import ( "fmt" + "iter" "unsafe" "github.com/c2h5oh/datasize" @@ -764,7 +765,7 @@ type ModuleReference struct { ModuleName string } -func (m *DPConfig) AllModulePositions(moduleType string) []ModuleReference { +func (m *DPConfig) AllModulePositions(moduleType string) iter.Seq[ModuleReference] { deviceList := m.Devices() pipelineList := m.Pipelines() @@ -779,47 +780,29 @@ func (m *DPConfig) AllModulePositions(moduleType string) []ModuleReference { functions[function.Name] = function.Chains } - count := 0 - for _, device := range deviceList { - pipelineVariants := [][]DevicePipelineInfo{ - device.InputPipelines, - device.OutputPipelines, - } - for _, pipelines := range pipelineVariants { - for _, pipeline := range pipelines { - for _, function := range pipelineFunctions[pipeline.Name] { - for _, chain := range functions[function] { - for _, module := range chain.Modules { - if module.Type == moduleType { - count += 1 - } - } - } - } + return func(yield func(ModuleReference) bool) { + for _, device := range deviceList { + pipelineVariants := [][]DevicePipelineInfo{ + device.InputPipelines, + device.OutputPipelines, } - } - } - - result := make([]ModuleReference, 0, count) - for _, device := range deviceList { - pipelineVariants := [][]DevicePipelineInfo{ - device.InputPipelines, - device.OutputPipelines, - } - for _, pipelines := range pipelineVariants { - for _, pipeline := range pipelines { - for _, function := range pipelineFunctions[pipeline.Name] { - for _, chain := range functions[function] { - for _, module := range chain.Modules { - if module.Type == moduleType { - result = append(result, ModuleReference{ - Device: device.Name, - Pipeline: pipeline.Name, - Function: function, - Chain: chain.Name, - ModuleType: module.Type, - ModuleName: module.Name, - }) + for _, pipelines := range pipelineVariants { + for _, pipeline := range pipelines { + for _, function := range pipelineFunctions[pipeline.Name] { + for _, chain := range functions[function] { + for _, module := range chain.Modules { + if module.Type == moduleType { + if !yield(ModuleReference{ + Device: device.Name, + Pipeline: pipeline.Name, + Function: function, + Chain: chain.Name, + ModuleType: module.Type, + ModuleName: module.Name, + }) { + return + } + } } } } @@ -827,6 +810,4 @@ func (m *DPConfig) AllModulePositions(moduleType string) []ModuleReference { } } } - - return result } diff --git a/controlplane/meson.build b/controlplane/meson.build index 2c5f539ea..55ab3eb3f 100644 --- a/controlplane/meson.build +++ b/controlplane/meson.build @@ -33,7 +33,6 @@ custom_target( # cgo linker deps lib_acl_cp, lib_acl_dp, - lib_balancer_cp, lib_balancer_dp, lib_decap_cp, lib_decap_dp, diff --git a/controlplane/yncp/cfg.go b/controlplane/yncp/cfg.go index a6c878010..5862f8fa9 100644 --- a/controlplane/yncp/cfg.go +++ b/controlplane/yncp/cfg.go @@ -12,7 +12,7 @@ import ( "github.com/yanet-platform/yanet2/controlplane/internal/gateway" acl "github.com/yanet-platform/yanet2/modules/acl/controlplane" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" + balancer "github.com/yanet-platform/yanet2/modules/balancer/controlplane" decap "github.com/yanet-platform/yanet2/modules/decap/controlplane" dscp "github.com/yanet-platform/yanet2/modules/dscp/controlplane" forward "github.com/yanet-platform/yanet2/modules/forward/controlplane" diff --git a/controlplane/yncp/director.go b/controlplane/yncp/director.go index d0f4cbb4b..f7a6f1b9b 100644 --- a/controlplane/yncp/director.go +++ b/controlplane/yncp/director.go @@ -9,7 +9,7 @@ import ( "github.com/yanet-platform/yanet2/controlplane/ffi" "github.com/yanet-platform/yanet2/controlplane/internal/gateway" acl "github.com/yanet-platform/yanet2/modules/acl/controlplane" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" + balancer "github.com/yanet-platform/yanet2/modules/balancer/controlplane" decap "github.com/yanet-platform/yanet2/modules/decap/controlplane" dscp "github.com/yanet-platform/yanet2/modules/dscp/controlplane" forward "github.com/yanet-platform/yanet2/modules/forward/controlplane" @@ -124,7 +124,7 @@ func NewDirector(cfg *Config, options ...DirectorOption) (*Director, error) { return nil, fmt.Errorf("failed to initialize acl built-in module: %w", err) } - balancerModule, err := balancer.NewBalancerModule(cfg.Modules.Balancer, log) + balancerModule, err := balancer.NewModule(cfg.Modules.Balancer, log) if err != nil { return nil, fmt.Errorf("failed to initialize balancer built-in module: %w", err) } diff --git a/filter/compiler/net6_fast.c b/filter/compiler/net6_fast.c index 9beefc990..42013b83d 100644 --- a/filter/compiler/net6_fast.c +++ b/filter/compiler/net6_fast.c @@ -8,8 +8,6 @@ #include "compiler/segments.h" #include "declare.h" #include "rule.h" -#include -#include static int validate_net6_half(uint8_t *bytes) { @@ -183,6 +181,7 @@ init_classifier( value_registry_free(&low_registry); free_classifier_part(&classifier->high, mctx); free_classifier_part(&classifier->low, mctx); + return -1; } value_registry_free(&high_registry); diff --git a/lib/controlplane/agent/agent.c b/lib/controlplane/agent/agent.c index cbe065c04..fa3fa922f 100644 --- a/lib/controlplane/agent/agent.c +++ b/lib/controlplane/agent/agent.c @@ -1942,3 +1942,13 @@ yanet_module_performance_counters_free( free(counters->counters); } } + +void * +agent_alloc(struct agent *agent, size_t size) { + return memory_balloc(&agent->memory_context, size); +} + +void +agent_free(struct agent *agent, void *ptr, size_t size) { + memory_bfree(&agent->memory_context, ptr, size); +} \ No newline at end of file diff --git a/lib/dataplane/time/clock.c b/lib/dataplane/time/clock.c index b3274a5af..4f3a61849 100644 --- a/lib/dataplane/time/clock.c +++ b/lib/dataplane/time/clock.c @@ -19,7 +19,7 @@ tsc_clock_init(struct tsc_clock *clock) { clock->tsc_to_ns = 1024e9 / rte_get_tsc_hz(); clock->real_time_ns -= - clock->timestamp_counter * clock->tsc_to_ns >> 10; + (clock->timestamp_counter >> 10) * clock->tsc_to_ns; return 0; } diff --git a/modules/acl/controlplane/metrics.go b/modules/acl/controlplane/metrics.go index df9aa5dd8..5f5ad9759 100644 --- a/modules/acl/controlplane/metrics.go +++ b/modules/acl/controlplane/metrics.go @@ -127,7 +127,7 @@ func (m *ACLService) collectMetrics() ([]*commonpb.Metric, error) { result := make([]*commonpb.Metric, 0) gaugesEmitted := make(map[string]struct{}) - for _, pos := range positions { + for pos := range positions { configName := pos.ModuleName baseLabels := []*commonpb.Label{ diff --git a/modules/balancer/README.md b/modules/balancer/README.md deleted file mode 100644 index 1854fd3bd..000000000 --- a/modules/balancer/README.md +++ /dev/null @@ -1,423 +0,0 @@ -# Balancer Module - -The balancer module provides Layer 4 (TCP/UDP) load balancing functionality for the YANET platform. It distributes incoming traffic across multiple backend servers (real servers) while maintaining session affinity and providing advanced features like dynamic weight adjustment and ICMP error handling. - -## Overview - -The balancer operates as a packet handler in the YANET dataplane pipeline, intercepting traffic destined for configured virtual services and forwarding it to backend real servers using IP-in-IP (IPIP) or GRE encapsulation. - -### Key Features - -- **Layer 4 Load Balancing**: Distributes TCP and UDP traffic across multiple backend servers -- **Multiple Scheduling Algorithms**: - - Source Hash (session affinity based on client IP/port) - - Weighted Round Robin (even distribution with weight support) -- **Session Tracking**: Maintains connection state to ensure packets of the same flow reach the same backend -- **Dynamic Weight Adjustment**: Weighted Least Connection (WLC) algorithm automatically adjusts real server weights based on active session counts -- **Encapsulation Support**: IPIP and GRE tunneling for forwarding traffic to real servers -- **ICMP Handling**: Processes ICMP echo requests and error messages, with support for multi-balancer broadcasting -- **Source Address Filtering**: Optional restriction of traffic to specific client networks -- **Pure L3 Mode**: Load balance all traffic to an IP address regardless of port -- **One Packet Scheduler (OPS)**: Stateless per-packet scheduling without session tracking -- **Automatic Session Table Resizing**: Dynamically grows session table capacity based on load -- **Real Server Management**: Runtime updates to real server weights and enabled/disabled state - -## Architecture - -### Components - -1. **Packet Handler** ([`dataplane/`](dataplane/)): Fast-path packet processing in the dataplane - - Virtual service matching and selection - - Real server scheduling (SOURCE_HASH, ROUND_ROBIN) - - Session table lookup and creation - - Packet encapsulation (IPIP/GRE) - - ICMP processing and broadcasting - - TCP MSS adjustment - -2. **Control Plane** ([`controlplane/`](controlplane/)): Configuration management and state synchronization - - Configuration validation and updates - - Real server property updates (buffered and immediate) - - Session table management and resizing - - Periodic refresh for statistics and WLC - - State synchronization with dataplane - -3. **Agent Service** ([`agent/`](agent/)): gRPC API for management operations - - Configuration CRUD operations - - Real server updates with buffering support - - Statistics and runtime information queries - - Active session inspection - - Topology graph visualization - -## Configuration - -### Virtual Service - -A virtual service defines a load-balanced endpoint that distributes traffic across multiple real servers: - -```protobuf -message VirtualService { - VsIdentifier id = 1; // IP, port, protocol - VsScheduler scheduler = 2; // SOURCE_HASH or ROUND_ROBIN - repeated AllowedSrc allowed_srcs = 3; // Optional source filtering - repeated Real reals = 4; // Backend servers - VsFlags flags = 5; // Feature flags - repeated Addr peers = 6; // Peer balancer addresses -} -``` - -**Virtual Service Identifier** uniquely identifies a service by: -- IP address (IPv4 or IPv6) -- Port number (0 for pure L3 mode) -- Transport protocol (TCP or UDP) - -### Real Server - -A real server represents a backend that handles forwarded traffic: - -```protobuf -message Real { - RelativeRealIdentifier id = 1; // IP address - uint32 weight = 2; // Scheduling weight (1-65535) - Addr src_addr = 3; // Encapsulation source address - Addr src_mask = 4; // Encapsulation source mask -} -``` - -**Weight** determines traffic distribution: -- **SOURCE_HASH**: Higher weight = more hash buckets = more traffic -- **ROUND_ROBIN**: Weight determines consecutive connection count - -### Feature Flags - -Control virtual service behavior: - -- **`gre`**: Use GRE encapsulation instead of IPIP -- **`fix_mss`**: Adjust TCP MSS to account for encapsulation overhead -- **`ops`**: One Packet Scheduler mode (stateless, no session tracking) -- **`pure_l3`**: Match all traffic to IP regardless of port (port must be 0) -- **`wlc`**: Enable Weighted Least Connection dynamic weight adjustment - -### Session Timeouts - -Configure session lifetime based on protocol and TCP state: - -```protobuf -message SessionsTimeouts { - uint32 tcp_syn_ack = 1; // SYN-ACK state timeout - uint32 tcp_syn = 2; // SYN state timeout - uint32 tcp_fin = 3; // FIN state timeout - uint32 tcp = 4; // Established state timeout - uint32 udp = 5; // UDP session timeout - uint32 default = 6; // Default timeout -} -``` - -### State Configuration - -Controls session table and periodic operations: - -```protobuf -message StateConfig { - uint64 session_table_capacity = 1; // Max concurrent sessions - float session_table_max_load_factor = 2; // Auto-resize threshold (0.7-0.9) - WlcConfig wlc = 3; // WLC algorithm parameters - google.protobuf.Duration refresh_period = 4; // Periodic refresh interval -} -``` - -**Refresh Period** enables periodic operations: -- Session table scanning and statistics updates -- Automatic session table resizing when load exceeds threshold -- WLC weight adjustment (if enabled) -- Set to 0 to disable periodic operations - -### Weighted Least Connection (WLC) - -WLC dynamically adjusts real server weights based on active session distribution: - -```protobuf -message WlcConfig { - uint64 power = 1; // Adjustment aggressiveness (1-16) - uint32 max_weight = 2; // Maximum effective weight cap -} -``` - -**Algorithm**: -``` -ratio = (real_sessions * total_weight) / (total_sessions * real_weight) -wlc_factor = max(1.0, power * (1.0 - ratio)) -effective_weight = min(real_weight * wlc_factor, max_weight) -``` - -**Requirements** for WLC: -1. Set `VsFlags.wlc = true` for the virtual service -2. Configure `StateConfig.refresh_period` to non-zero value -3. Configure `StateConfig.session_table_max_load_factor` -4. Configure `StateConfig.wlc` with power and max_weight - -## Scheduling Algorithms - -### SOURCE_HASH - -Selects real servers based on hash of client source IP and port: -- Provides session affinity (same client → same real) -- Weight affects hash space distribution -- Best for stateful applications requiring client stickiness - -### ROUND_ROBIN - -Maintains monotonic counter and selects reals consecutively: -- Each real receives `weight` consecutive connections -- More even distribution across all reals -- Best for stateless applications - -## Source Address Filtering - -The balancer supports optional source address filtering to restrict access to virtual services based on client IP addresses and source ports. - -### Configuration - -```protobuf -message AllowedSrc { - Net net = 1; // Network prefix (address + mask) - repeated PortsRange ports = 2; // Optional source port ranges -} - -message PortsRange { - uint32 from = 1; // Starting port (inclusive) - uint32 to = 2; // Ending port (inclusive) -} -``` - -### Behavior - -- **Empty `allowed_srcs` list**: All source addresses are denied (no traffic allowed) -- **Non-empty `allowed_srcs` list**: Only traffic from matching sources is accepted -- **Multiple entries**: Evaluated with OR logic (any match allows the packet) - -### Port Filtering - -Each `AllowedSrc` entry can optionally specify source port ranges: - -- **Empty `ports` list**: All source ports are permitted for this network -- **Non-empty `ports` list**: Only source ports within specified ranges are permitted -- **Multiple ranges**: Evaluated with OR logic (any match allows the packet) - -### Matching Logic - -For each incoming packet: -1. If `allowed_srcs` is empty → **DROP** -2. For each `AllowedSrc` entry: - - Check if packet source IP matches the network prefix: `(src_ip & mask) == (net.addr & mask)` - - If `ports` list is empty → **ACCEPT** (IP match is sufficient) - - If `ports` list is non-empty, check if source port falls within any range - - If both IP and port match → **ACCEPT** -3. If no entry matches → **DROP** (increment `packet_src_not_allowed` counter) - -### Use Cases - -**Restrict to trusted networks:** -```protobuf -allowed_srcs: [ - { net: { addr: "10.0.0.0", mask: "255.0.0.0" } } -] -``` - -**Allow only high ports from specific network:** -```protobuf -allowed_srcs: [ - { - net: { addr: "192.168.0.0", mask: "255.255.0.0" } - ports: [ { from: 1024, to: 65535 } ] - } -] -``` - -**Allow specific service ports from multiple networks:** -```protobuf -allowed_srcs: [ - { - net: { addr: "172.16.0.0", mask: "255.240.0.0" } - ports: [ - { from: 80, to: 80 }, // HTTP - { from: 443, to: 443 }, // HTTPS - { from: 8000, to: 9000 } // Custom range - ] - }, - { - net: { addr: "10.0.0.0", mask: "255.0.0.0" } - ports: [ { from: 80, to: 80 }, { from: 443, to: 443 } ] - } -] -``` - -**Allow all IPv4 addresses (no filtering):** -```protobuf -allowed_srcs: [ - { net: { addr: "0.0.0.0", mask: "0.0.0.0" } } -] -``` - -### Statistics - -Packets blocked by source filtering are counted in the `packet_src_not_allowed` counter for the virtual service. This counter helps monitor unauthorized access attempts and validate filtering rules. - -## Packet Processing Flow - -1. **Ingress**: Packet arrives at balancer -2. **Decapsulation**: If destination matches decap address, remove outer IP header -3. **Virtual Service Selection**: Match packet to configured virtual service -4. **Source Filtering**: Check if source address and port are allowed (if `allowed_srcs` configured) -5. **Session Lookup**: Check if session exists in session table -6. **Real Selection**: If new session, select real using scheduler -7. **Session Creation**: Create new session entry (unless OPS mode) -8. **Encapsulation**: Wrap packet in IPIP or GRE tunnel -9. **MSS Adjustment**: Fix TCP MSS if enabled -10. **Forwarding**: Send packet to selected real server - -## ICMP Handling - -The balancer processes two types of ICMP messages: - -### ICMP Echo (Ping) -- Responds to pings for virtual service addresses -- Useful for health checking and monitoring - -### ICMP Error Messages -- Processes errors related to forwarded sessions -- Validates error against known sessions -- Forwards error to appropriate real server if session exists -- Broadcasts to peer balancers if no session found AND packet didn't come from a peer -- Packets that were decapsulated (came from peer balancers) are not re-broadcasted to prevent loops - -## Statistics - -The balancer tracks comprehensive statistics: - -### L4 Statistics -- Incoming/outgoing packet counts -- Virtual service selection failures -- Real server selection failures -- Invalid packet counts - -### ICMP Statistics -- Echo requests and responses -- Error message processing -- Peer broadcasting metrics -- Packet clone operations - -### Per-Virtual-Service Statistics -- Packet and byte counts -- Session creation counts -- Source filtering rejections -- Session table overflow events -- Real server availability issues - -### Per-Real-Server Statistics -- Packet and byte counts forwarded -- Session creation counts -- OPS packet counts -- ICMP error forwarding - -## API Operations - -The balancer provides a gRPC service with the following operations: - -### Configuration Management -- **`UpdateConfig`**: Create or update balancer configuration -- **`ShowConfig`**: Retrieve current configuration -- **`ListConfigs`**: List all balancer instances - -### Real Server Management -- **`UpdateReals`**: Update real server weights and enabled state - - Supports immediate or buffered updates - - Buffered updates applied atomically on flush -- **`FlushRealUpdates`**: Apply all buffered real server updates - -### Monitoring -- **`ShowStats`**: Retrieve packet processing statistics -- **`ShowInfo`**: Get runtime information (session counts, timestamps) -- **`ShowSessions`**: List all active sessions -- **`ShowGraph`**: Visualize balancer topology - -## Session Management - -Sessions are tracked in a hash table with the following properties: - -- **Key**: Client IP, client port, virtual service identifier -- **Value**: Assigned real server, timestamps, timeout -- **Capacity**: Configurable, automatically resized when load exceeds threshold -- **Timeout**: Based on protocol and TCP state -- **Cleanup**: Expired sessions removed during periodic refresh - -## Multi-Balancer Support - -Multiple balancer instances can coordinate for high availability: - -- **Peer Configuration**: List peer balancer addresses in virtual service -- **ICMP Broadcasting**: Error messages broadcasted to all peers -- **Coordinated Processing**: Ensures all balancers handle errors consistently - -## Performance Considerations - -### Session Table Sizing -- Larger capacity = more memory, supports more concurrent connections -- Auto-resize prevents overflow but causes temporary performance impact -- Set initial capacity based on expected peak load - -### Refresh Period -- Shorter period = more responsive, higher CPU overhead -- Longer period = less overhead, slower response to changes -- Typical values: 5-30 seconds for dynamic workloads - -### WLC Configuration -- Higher power = more aggressive adjustment, faster response -- Lower power = gentler adjustment, more stable weights -- Balance between responsiveness and stability - -### Scheduler Selection -- SOURCE_HASH: Better for stateful apps, may have uneven distribution -- ROUND_ROBIN: Better for stateless apps, more even distribution - -## Example Configuration - -```protobuf -BalancerConfig { - packet_handler: { - vs: [ - { - id: { addr: "10.0.0.1", port: 80, proto: TCP } - scheduler: SOURCE_HASH - reals: [ - { id: { ip: "192.168.1.10" }, weight: 100 } - { id: { ip: "192.168.1.11" }, weight: 100 } - { id: { ip: "192.168.1.12" }, weight: 50 } - ] - flags: { fix_mss: true, wlc: true } - } - ] - source_address_v4: "10.0.0.100" - source_address_v6: "2001:db8::100" - sessions_timeouts: { - tcp_syn: 30 - tcp_syn_ack: 30 - tcp_fin: 30 - tcp: 300 - udp: 60 - default: 60 - } - } - state: { - session_table_capacity: 1000000 - session_table_max_load_factor: 0.8 - wlc: { power: 4, max_weight: 500 } - refresh_period: { seconds: 10 } - } -} -``` - -## See Also - -- [Protobuf API Documentation](agent/balancerpb/) - Complete API reference -- [Dataplane Implementation](dataplane/) - Fast-path packet processing -- [Control Plane](controlplane/) - Configuration and state management \ No newline at end of file diff --git a/modules/balancer/agent/agent.c b/modules/balancer/agent/agent.c deleted file mode 100644 index 0058013cd..000000000 --- a/modules/balancer/agent/agent.c +++ /dev/null @@ -1,106 +0,0 @@ -#include "agent.h" -#include "api/agent.h" -#include "controlplane/agent/agent.h" -#include "controlplane/diag/diag.h" -#include "manager.h" -#include "modules/balancer/controlplane/api/balancer.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -const char *agent_name = "balancer"; -const char *storage_name = "balancer_storage"; - -//////////////////////////////////////////////////////////////////////////////// - -struct balancer_agent * -balancer_agent(struct yanet_shm *shm, size_t memory) { - struct agent *agent = agent_reattach(shm, 0, agent_name, memory); - if (agent == NULL) { - PUSH_ERROR("failed to reattach balancer agent"); - return NULL; - } - - if (agent_storage_read(agent, storage_name) == NULL) { - struct balancer_managers managers; - memset(&managers, 0, sizeof(managers)); - if (agent_storage_put( - agent, storage_name, &managers, sizeof(managers) - ) != 0) { - PUSH_ERROR("failed to allocate balancer storage"); - agent_cleanup(agent); - return NULL; - } - } - - return (struct balancer_agent *)agent; -} - -const char * -balancer_agent_take_error(struct balancer_agent *agent) { - return agent_take_error((struct agent *)agent); -} - -//////////////////////////////////////////////////////////////////////////////// - -void -balancer_agent_inspect( - struct balancer_agent *agent, struct agent_inspect *inspect -) { - struct agent *base_agent = (struct agent *)agent; - - // Get memory context statistics - inspect->memory_limit = base_agent->memory_limit; - inspect->memory_usage = base_agent->memory_context.balloc_size - - base_agent->memory_context.bfree_size; - - // Get all managers - struct balancer_managers managers; - balancer_agent_managers(agent, &managers); - - inspect->balancer_count = managers.count; - - if (managers.count == 0) { - inspect->balancers = NULL; - return; - } - - // Allocate array for balancer inspections - inspect->balancers = - calloc(managers.count, sizeof(struct named_balancer_inspect)); - - // Fill in each balancer inspection - for (size_t i = 0; i < managers.count; ++i) { - struct balancer_manager *manager = managers.managers[i]; - inspect->balancers[i].name = balancer_manager_name(manager); - balancer_manager_inspect( - manager, &inspect->balancers[i].inspect - ); - inspect->memory_usage += - inspect->balancers[i].inspect.total_usage; - } -} - -void -balancer_agent_inspect_free(struct agent_inspect *inspect) { - if (inspect == NULL || inspect->balancers == NULL) { - return; - } - - // Free each balancer inspection - for (size_t i = 0; i < inspect->balancer_count; ++i) { - balancer_manager_inspect_free(&inspect->balancers[i].inspect); - } - - // Free the balancers array - free(inspect->balancers); - inspect->balancers = NULL; - inspect->balancer_count = 0; -} - -struct dp_config * -balancer_agent_dp_config(struct balancer_agent *agent) { - return agent_dp_config((struct agent *)agent); -} \ No newline at end of file diff --git a/modules/balancer/agent/agent.h b/modules/balancer/agent/agent.h deleted file mode 100644 index 7422be8b8..000000000 --- a/modules/balancer/agent/agent.h +++ /dev/null @@ -1,163 +0,0 @@ -#pragma once - -#include "modules/balancer/controlplane/api/inspect.h" - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Opaque handle to a balancer agent instance. - * - * The agent is the top-level container that manages multiple balancer managers. - * It coordinates shared memory allocation and provides lifecycle management - * for balancer instances. - * - * Thread-Safety: Not thread-safe. External synchronization required for - * concurrent access. - */ -struct balancer_agent; - -struct yanet_shm; - -/** - * Create a new balancer agent instance. - * - * The agent is responsible for managing multiple balancer managers and - * coordinating their access to shared memory. It allocates the specified - * amount of memory from the provided shared memory region. - * - * @param shm Pointer to the shared memory region to use for allocations. - * @param memory Amount of memory (in bytes) to allocate for the agent. - * @return Newly created agent handle on success, or NULL on error. - */ -struct balancer_agent * -balancer_agent(struct yanet_shm *shm, size_t memory); - -struct balancer_manager; - -/** - * Container for a list of balancer managers. - * - * Used to retrieve all managers currently registered with an agent. - * The managers array is owned by the agent and should not be freed by caller. - */ -struct balancer_managers { - size_t count; // Number of managers in the array - struct balancer_manager **managers; // Array of manager pointers -}; - -/** - * Retrieve all balancer managers registered with the agent. - * - * Fills the provided balancer_managers structure with pointers to all - * currently active managers. The returned array is owned by the agent - * and remains valid until the agent is destroyed or managers are modified. - * - * @param agent Agent handle. - * @param managers Output structure to be filled with manager list. - */ -void -balancer_agent_managers( - struct balancer_agent *agent, struct balancer_managers *managers -); - -struct balancer_manager_config; - -/** - * Create and register a new balancer manager with the agent. - * - * Creates a new manager instance with the specified name and configuration, - * then registers it with the agent. The manager will be included in - * subsequent calls to balancer_agent_managers(). - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_agent_take_error(). - * - * @param agent Agent that will own the manager. - * @param name Human-readable manager name (used for identification). - * @param config Initial configuration for the manager. - * @return Newly created manager handle on success, or NULL on error. - */ -struct balancer_manager * -balancer_agent_new_manager( - struct balancer_agent *agent, - const char *name, - struct balancer_manager_config *config -); - -/** - * Named balancer inspection with name and memory usage details. - */ -struct named_balancer_inspect { - const char *name; - struct balancer_inspect inspect; -}; - -/** - * Agent-level memory inspection. - * - * Provides agent memory usage information and detailed inspection - * for all balancer instances managed by this agent. - */ -struct agent_inspect { - /** Agent memory limit (configured maximum) */ - uint64_t memory_limit; - - /** Current memory usage */ - uint64_t memory_usage; - - /** Number of balancers in the array */ - size_t balancer_count; - - /** Array of balancer inspections */ - struct named_balancer_inspect *balancers; -}; - -/** - * Retrieve agent-level memory inspection. - * - * Fills the provided agent_inspect structure with agent memory usage - * and detailed inspection for all balancer instances. The balancers - * array is allocated and must be freed with balancer_agent_inspect_free(). - * - * @param agent Agent handle. - * @param inspect Output structure to be filled with inspection data. - */ -void -balancer_agent_inspect( - struct balancer_agent *agent, struct agent_inspect *inspect -); - -/** - * Free all allocations inside an agent_inspect structure. - * - * Releases memory allocated by balancer_agent_inspect() for the - * balancers array and nested structures. Safe to call with - * partially-initialized structures; ignores NULL pointers. - * - * @param inspect Structure to release. The struct itself is not freed. - */ -void -balancer_agent_inspect_free(struct agent_inspect *inspect); - -/** - * Retrieve the last diagnostic error message for this agent. - * - * Returns the most recent error message recorded by agent operations. - * After calling this function, the error state is cleared. - * - * Ownership: The returned string is heap-allocated for the caller; you must - * free() it when no longer needed. Returns NULL if no error is available. - * - * @param agent Agent handle. - * @return Null-terminated error message string to be freed by caller, or NULL. - */ -const char * -balancer_agent_take_error(struct balancer_agent *agent); - -struct dp_config; - -struct dp_config * -balancer_agent_dp_config(struct balancer_agent *agent); \ No newline at end of file diff --git a/modules/balancer/agent/balancerpb/graph.proto b/modules/balancer/agent/balancerpb/graph.proto deleted file mode 100644 index 513bb2702..000000000 --- a/modules/balancer/agent/balancerpb/graph.proto +++ /dev/null @@ -1,85 +0,0 @@ -syntax = "proto3"; - -package balancerpb; - -import "modules/balancer/agent/balancerpb/module.proto"; - -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; - -// Real server in graph topology. -// -// Represents a real server's current state within the balancer's -// virtual service topology, including its weight and enabled status. -message GraphReal { - // Real server identifier (relative to its virtual service) - RelativeRealIdentifier identifier = 1; - - // Original configured weight of the real server. - // - // This is the static weight value from the configuration that was - // set via UpdateConfig or UpdateReals. It remains constant unless - // explicitly changed through configuration updates. - // - // This weight is used as the baseline for WLC calculations when - // the WLC flag is enabled for the virtual service. - uint32 weight = 2; - - // Current effective weight after WLC adjustments. - // - // When WLC (Weighted Least Connection) is enabled for the virtual - // service, this field shows the dynamically adjusted weight based - // on current session distribution. The WLC algorithm adjusts weights - // to balance load across real servers according to their active - // session counts. - // - // Formula: - // effective_weight = weight * max(1.0, power * (1.0 - - // connectionsRatio)) - // where: - // - weight: Original configured weight (field above) - // - power: WLC power factor from StateConfig.wlc.power - // - connectionsRatio: - // (real_sessions * total_weight) / (total_sessions * - // real_weight) - // - // When WLC is disabled, effective_weight equals weight. - // - // The effective_weight is capped at StateConfig.wlc.max_weight. - uint32 effective_weight = 3; - - // Whether the real server is currently enabled. - // - // When false, the real receives no new sessions but existing - // sessions may continue to be forwarded to it. - bool enabled = 4; -} - -// Virtual service in graph topology. -// -// Represents a virtual service and all its associated real servers -// with their current states. -message GraphVs { - // Virtual service identifier - VsIdentifier identifier = 1; - - // List of real servers backing this virtual service. - // - // Includes current weight and enabled status for each real. - repeated GraphReal reals = 2; -} - -// Complete balancer topology graph. -// -// Provides a snapshot of the entire balancer configuration showing -// the relationships between virtual services and their real servers, -// along with current operational state (weights, enabled status). -// -// This is useful for: -// - Visualizing the load balancer topology -// - Monitoring real server states -// - Debugging configuration issues -// - Understanding traffic distribution -message Graph { - // List of all virtual services with their real servers - repeated GraphVs virtual_services = 1; -} \ No newline at end of file diff --git a/modules/balancer/agent/balancerpb/info.proto b/modules/balancer/agent/balancerpb/info.proto deleted file mode 100644 index 5314abb02..000000000 --- a/modules/balancer/agent/balancerpb/info.proto +++ /dev/null @@ -1,145 +0,0 @@ -syntax = "proto3"; - -package balancerpb; - -import "google/protobuf/duration.proto"; -import "google/protobuf/timestamp.proto"; -import "modules/balancer/agent/balancerpb/module.proto"; - -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; - -// Real server runtime information. -// -// Provides statistics about a specific real server including -// active session count and last activity timestamp. -message RealInfo { - // Real server identifier - RealIdentifier id = 1; - - // Number of active sessions currently assigned to this real server. - // - // This count is updated asynchronously during periodic refresh - // (controlled by StateConfig.refresh_period), so it may not reflect - // the exact current state. The update frequency depends on the - // configured refresh period. - uint64 active_sessions = 2; - - // Timestamp of the last packet processed for this real server. - // - // Updated when any packet is forwarded to or received from this real. - // Useful for detecting inactive reals or monitoring traffic patterns. - google.protobuf.Timestamp last_packet_timestamp = 3; -} - -// Virtual service runtime information. -// -// Provides statistics about a specific virtual service including -// active session count, last activity, and per-real statistics. -message VsInfo { - // Virtual service identifier - VsIdentifier id = 1; - - // Number of active sessions for this virtual service. - // - // This is the sum of active sessions across all real servers - // backing this virtual service. Updated asynchronously during - // periodic refresh, so may lag behind actual state. - uint64 active_sessions = 2; - - // Timestamp of the last packet processed for this virtual service. - // - // Updated when any packet matches this virtual service. - // Useful for detecting inactive services or monitoring traffic - // patterns. - google.protobuf.Timestamp last_packet_timestamp = 3; - - // Runtime information for each real server backing this virtual - // service. - // - // Provides per-real session counts and activity timestamps. - // The list corresponds to the reals configured in VirtualService. - repeated RealInfo reals = 4; -} - -// Balancer-wide runtime information. -// -// Aggregates statistics across all virtual services and real servers -// managed by this balancer instance. -message BalancerInfo { - // Total number of active sessions across all virtual services. - // - // This is the sum of active sessions for all virtual services - // managed by this balancer. Updated asynchronously during periodic - // refresh (controlled by StateConfig.refresh_period). - // - // Note: This represents sessions tracked by the balancer, not - // necessarily all active connections to real servers (which may - // have additional direct connections). - uint64 active_sessions = 1; - - // Timestamp of the last packet processed by this balancer. - // - // Updated when any packet is processed by any virtual service. - // Useful for monitoring balancer activity and detecting issues. - google.protobuf.Timestamp last_packet_timestamp = 2; - - // Runtime information for each virtual service. - // - // Provides per-VS session counts, activity timestamps, and - // per-real statistics. The list includes all configured virtual - // services, even if they have no active sessions. - repeated VsInfo vs = 3; -} - -//////////////////////////////////////////////////////////////////////////////// - -// Individual session information. -// -// Represents a single active session tracked by the balancer, -// including client information, virtual service mapping, real server -// assignment, and timing information. -message SessionInfo { - // Client source IP address (IPv4 or IPv6) - Addr client_addr = 1; - - // Client source port number - uint32 client_port = 2; - - // Virtual service this session is associated with. - // - // Identifies which virtual service the client connected to. - VsIdentifier vs_id = 3; - - // Real server this session is assigned to. - // - // Identifies which real server is handling this session. - // Once assigned, all packets for this session are forwarded - // to this real (unless the real becomes unavailable). - RealIdentifier real_id = 4; - - // Session creation timestamp. - // - // When the first packet of this session was processed and - // the session was created in the session table. - google.protobuf.Timestamp create_timestamp = 5; - - // Last packet timestamp. - // - // When the most recent packet for this session was processed. - // Used to calculate session age and determine if timeout has expired. - google.protobuf.Timestamp last_packet_timestamp = 6; - - // Session timeout duration. - // - // How long the session remains active without receiving packets. - // The timeout value depends on the protocol and TCP state: - // - TCP SYN: sessions_timeouts.tcp_syn - // - TCP SYN-ACK: sessions_timeouts.tcp_syn_ack - // - TCP FIN: sessions_timeouts.tcp_fin - // - TCP established: sessions_timeouts.tcp - // - UDP: sessions_timeouts.udp - // - // If (current_time - last_packet_timestamp) > timeout, the session - // is removed from the session table during the next cleanup cycle. - google.protobuf.Duration timeout = 7; -} diff --git a/modules/balancer/agent/balancerpb/inspect.proto b/modules/balancer/agent/balancerpb/inspect.proto deleted file mode 100644 index 33339f3e9..000000000 --- a/modules/balancer/agent/balancerpb/inspect.proto +++ /dev/null @@ -1,82 +0,0 @@ -syntax = "proto3"; - -package balancerpb; - -import "modules/balancer/agent/balancerpb/module.proto"; - -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; - -// Memory usage for real servers within a VS -message RealsUsage { - uint64 counters_usage = 1; - uint64 data_usage = 2; - uint64 total_usage = 3; -} - -// Memory usage for a single virtual service -message VsInspect { - uint64 acl_usage = 1; - uint64 ring_usage = 2; - uint64 counters_usage = 3; - RealsUsage reals_usage = 4; - uint64 other_usage = 5; - uint64 total_usage = 6; -} - -// Named VS inspect with identifier -message NamedVsInspect { - VsIdentifier identifier = 1; - VsInspect inspect = 2; -} - -// Memory usage for IPv4 or IPv6 packet handler VS section -message PacketHandlerVsInspect { - uint64 matcher_usage = 1; - uint64 summary_vs_usage = 2; - repeated NamedVsInspect vs_inspects = 3; - uint64 announce_usage = 4; - uint64 index_usage = 5; - uint64 total_usage = 6; -} - -// Complete packet handler memory usage -message PacketHandlerInspect { - PacketHandlerVsInspect vs_ipv4_inspect = 1; - PacketHandlerVsInspect vs_ipv6_inspect = 2; - uint64 summary_vs_usage = 3; - uint64 vs_index_usage = 4; - uint64 reals_index_usage = 5; - uint64 counters_usage = 6; - uint64 decap_usage = 7; - uint64 total_usage = 8; -} - -// State memory usage -message StateInspect { - uint64 session_table_usage = 1; - uint64 total_usage = 2; -} - -// Per-balancer memory inspection -message BalancerInspect { - // Balancer name - string name = 1; - // Packet handler memory usage - PacketHandlerInspect packet_handler_inspect = 2; - // State memory usage - StateInspect state_inspect = 3; - // Other memory usage - uint64 other_usage = 4; - // Total memory usage for this balancer - uint64 total_usage = 5; -} - -// Agent-level inspection (top-level response) -message AgentInspect { - // Agent memory limit (configured maximum) - uint64 memory_limit = 1; - // Current memory usage - uint64 memory_usage = 2; - // List of all balancer inspects - repeated BalancerInspect balancers = 3; -} \ No newline at end of file diff --git a/modules/balancer/agent/balancerpb/meson.build b/modules/balancer/agent/balancerpb/meson.build deleted file mode 100644 index 486c516b5..000000000 --- a/modules/balancer/agent/balancerpb/meson.build +++ /dev/null @@ -1,33 +0,0 @@ -root_dir = meson.project_source_root() -proto_files = [ - join_paths(root_dir, 'modules/balancer/agent/balancerpb/balancer.proto'), - join_paths(root_dir, 'modules/balancer/agent/balancerpb/stats.proto'), - join_paths(root_dir, 'modules/balancer/agent/balancerpb/info.proto'), - join_paths(root_dir, 'modules/balancer/agent/balancerpb/graph.proto'), - join_paths(root_dir, 'modules/balancer/agent/balancerpb/module.proto'), - join_paths(root_dir, 'modules/balancer/agent/balancerpb/inspect.proto'), -] - -protoc_gen = custom_target( - 'balancer-protoc', - output: [ - 'balancer.pb.go', - 'balancer_grpc.pb.go', - 'stats.pb.go', - 'info.pb.go', - 'graph.pb.go', - 'module.pb.go', - 'inspect.pb.go', - ], - input: proto_files, - command: [ - protoc, - '-I', root_dir, - '--experimental_allow_proto3_optional', - '--go_out=paths=source_relative:' + root_dir, - '--go-grpc_out=paths=source_relative:' + root_dir, - '@INPUT@', - ], - build_by_default: true, -) -balancer_protoc_gen = protoc_gen diff --git a/modules/balancer/agent/config.c b/modules/balancer/agent/config.c deleted file mode 100644 index 755708fc9..000000000 --- a/modules/balancer/agent/config.c +++ /dev/null @@ -1,774 +0,0 @@ -#include "common/memory.h" -#include "common/memory_address.h" -#include "common/network.h" -#include "manager.h" -#include "modules/balancer/controlplane/api/balancer.h" -#include "modules/balancer/controlplane/api/vs.h" -#include -#include -#include - -/** - * Clone a named_real_config array from normal pointers to relative pointers. - * - * @param dst Destination pointer (will be set to allocated memory with relative - * pointers) - * @param src Source array with normal pointers - * @param count Number of elements in the array - * @param mctx Memory context for allocation - * @return 0 on success, -1 on error - */ -static int -clone_reals_to_relative( - struct named_real_config **dst, - struct named_real_config *src, - size_t count, - struct memory_context *mctx -) { - if (count == 0) { - SET_OFFSET_OF(dst, NULL); - return 0; - } - - struct named_real_config *reals = - memory_balloc(mctx, sizeof(struct named_real_config) * count); - if (reals == NULL) { - return -1; - } - - // Copy all real configs (they contain no pointers, just embedded - // structs) - memcpy(reals, src, sizeof(struct named_real_config) * count); - - SET_OFFSET_OF(dst, reals); - return 0; -} - -/** - * Clone an allowed_src array from normal pointers to relative pointers. - * Each allowed_src contains an address and an array of port ranges. - */ -static int -clone_allowed_src_to_relative( - struct allowed_sources **dst, - struct allowed_sources *src, - size_t count, - struct memory_context *mctx -) { - if (count == 0) { - SET_OFFSET_OF(dst, NULL); - return 0; - } - - // Allocate array of allowed_src entries - struct allowed_sources *entries = - memory_balloc(mctx, sizeof(struct allowed_sources) * count); - if (entries == NULL) { - return -1; - } - - // For each allowed_src entry, copy addr and clone port ranges - for (size_t i = 0; i < count; i++) { - // Copy networks - entries[i].nets_count = src[i].nets_count; - if (entries[i].nets_count > 0) { - entries[i].nets = memory_balloc( - mctx, sizeof(struct net) * entries[i].nets_count - ); - if (entries[i].nets == NULL) { - // todo: proper cleanup - memory_bfree( - mctx, - entries, - sizeof(struct allowed_sources) * count - ); - return -1; - } - for (size_t net_idx = 0; - net_idx < entries[i].nets_count; - ++net_idx) { - entries[i].nets[net_idx] = src[i].nets[net_idx]; - } - SET_OFFSET_OF(&entries[i].nets, entries[i].nets); - } else { - entries[i].nets = NULL; - } - - // Clone tag string if present - if (src[i].tag != NULL) { - size_t tag_len = strlen(src[i].tag) + 1; - char *tag_copy = memory_balloc(mctx, tag_len); - if (tag_copy == NULL) { - // TODO: proper cleanup - memory_bfree( - mctx, - entries, - sizeof(struct allowed_sources) * count - ); - return -1; - } - memcpy(tag_copy, src[i].tag, tag_len); - entries[i].tag = tag_copy; - SET_OFFSET_OF(&entries[i].tag, entries[i].tag); - } else { - entries[i].tag = NULL; - } - - entries[i].port_ranges_count = src[i].port_ranges_count; - - // Clone the ports_range array - if (src[i].port_ranges_count > 0) { - struct ports_range *ranges = memory_balloc( - mctx, - sizeof(struct ports_range) * - src[i].port_ranges_count - ); - if (ranges == NULL) { - // Cleanup previously allocated ranges - for (size_t j = 0; j < i; j++) { - if (entries[j].port_ranges_count > 0) { - memory_bfree( - mctx, - ADDR_OF(&entries[j] - .port_ranges - ), - sizeof(struct - ports_range - ) * entries[j].port_ranges_count - ); - } - } - // TODO: proper clean up - memory_bfree( - mctx, - entries, - sizeof(struct allowed_sources) * count - ); - return -1; - } - memcpy(ranges, - src[i].port_ranges, - sizeof(struct ports_range) * - src[i].port_ranges_count); - SET_OFFSET_OF(&entries[i].port_ranges, ranges); - } else { - SET_OFFSET_OF(&entries[i].port_ranges, NULL); - } - } - - SET_OFFSET_OF(dst, entries); - return 0; -} - -/** - * Clone a net4_addr array from normal pointers to relative pointers. - */ -static int -clone_net4_addrs_to_relative( - struct net4_addr **dst, - struct net4_addr *src, - size_t count, - struct memory_context *mctx -) { - if (count == 0) { - SET_OFFSET_OF(dst, NULL); - return 0; - } - - struct net4_addr *addrs = - memory_balloc(mctx, sizeof(struct net4_addr) * count); - if (addrs == NULL) { - return -1; - } - - memcpy(addrs, src, sizeof(struct net4_addr) * count); - SET_OFFSET_OF(dst, addrs); - return 0; -} - -/** - * Clone a net6_addr array from normal pointers to relative pointers. - */ -static int -clone_net6_addrs_to_relative( - struct net6_addr **dst, - struct net6_addr *src, - size_t count, - struct memory_context *mctx -) { - if (count == 0) { - SET_OFFSET_OF(dst, NULL); - return 0; - } - - struct net6_addr *addrs = - memory_balloc(mctx, sizeof(struct net6_addr) * count); - if (addrs == NULL) { - return -1; - } - - memcpy(addrs, src, sizeof(struct net6_addr) * count); - SET_OFFSET_OF(dst, addrs); - return 0; -} - -/** - * Clone a vs_config from normal pointers to relative pointers. - */ -static int -clone_vs_config_to_relative( - struct vs_config *dst, - struct vs_config *src, - struct memory_context *mctx -) { - // Copy scalar fields - dst->flags = src->flags; - dst->scheduler = src->scheduler; - dst->real_count = src->real_count; - dst->allowed_src_count = src->allowed_src_count; - dst->peers_v4_count = src->peers_v4_count; - dst->peers_v6_count = src->peers_v6_count; - - // Clone reals array - if (clone_reals_to_relative( - &dst->reals, src->reals, src->real_count, mctx - ) != 0) { - return -1; - } - - // Clone allowed_src array - if (clone_allowed_src_to_relative( - &dst->allowed_src, - src->allowed_src, - src->allowed_src_count, - mctx - ) != 0) { - // TODO: free reals - return -1; - } - - // Clone peers_v4 array - if (clone_net4_addrs_to_relative( - &dst->peers_v4, src->peers_v4, src->peers_v4_count, mctx - ) != 0) { - // TODO: free reals and addr ranges - return -1; - } - - // Clone peers_v6 array - if (clone_net6_addrs_to_relative( - &dst->peers_v6, src->peers_v6, src->peers_v6_count, mctx - ) != 0) { - // TODO: free freals, addr ranges and net4 addrs - return -1; - } - - return 0; -} - -static void -free_vs_config_with_relative_pointers( - struct vs_config *vs_config, struct memory_context *mctx -); - -/** - * Clone a named_vs_config array from normal pointers to relative pointers. - */ -static int -clone_vs_array_to_relative( - struct named_vs_config **dst, - struct named_vs_config *src, - size_t count, - struct memory_context *mctx -) { - if (count == 0) { - SET_OFFSET_OF(dst, NULL); - return 0; - } - - struct named_vs_config *vs_array = - memory_balloc(mctx, sizeof(struct named_vs_config) * count); - if (vs_array == NULL) { - return -1; - } - - for (size_t i = 0; i < count; i++) { - // Copy identifier (no pointers) - vs_array[i].identifier = src[i].identifier; - - // Clone config with nested pointers - if (clone_vs_config_to_relative( - &vs_array[i].config, &src[i].config, mctx - ) != 0) { - for (size_t j = 0; j < i; ++j) { - free_vs_config_with_relative_pointers( - &vs_array[j].config, mctx - ); - } - memory_bfree( - mctx, - vs_array, - sizeof(struct named_vs_config) * count - ); - return -1; - } - } - - SET_OFFSET_OF(dst, vs_array); - return 0; -} - -/** - * Clone packet_handler_config from normal pointers to relative pointers. - */ -static int -clone_handler_config_to_relative( - struct packet_handler_config *dst, - struct packet_handler_config *src, - struct memory_context *mctx -) { - // Copy scalar fields and embedded structs - dst->sessions_timeouts = src->sessions_timeouts; - dst->vs_count = src->vs_count; - dst->source_v4 = src->source_v4; - dst->source_v6 = src->source_v6; - dst->decap_v4_count = src->decap_v4_count; - dst->decap_v6_count = src->decap_v6_count; - - // Clone vs array - if (clone_vs_array_to_relative( - &dst->vs, src->vs, src->vs_count, mctx - ) != 0) { - return -1; - } - - // Clone decap_v4 array - if (clone_net4_addrs_to_relative( - &dst->decap_v4, src->decap_v4, src->decap_v4_count, mctx - ) != 0) { - return -1; - } - - // Clone decap_v6 array - if (clone_net6_addrs_to_relative( - &dst->decap_v6, src->decap_v6, src->decap_v6_count, mctx - ) != 0) { - return -1; - } - - return 0; -} - -/** - * Clone balancer_config from normal pointers to relative pointers. - */ -int -clone_balancer_config_to_relative( - struct balancer_config *dst, - struct balancer_config *src, - struct memory_context *mctx -) { - // Clone handler config - if (clone_handler_config_to_relative( - &dst->handler, &src->handler, mctx - ) != 0) { - return -1; - } - - // Copy state config (no pointers) - dst->state = src->state; - - return 0; -} - -/* ======================================================================== - * Functions for cloning FROM relative pointers TO normal pointers - * ======================================================================== */ - -/** - * Clone a named_real_config array from relative pointers to normal pointers. - */ -static int -clone_reals_from_relative( - struct named_real_config **dst, - struct named_real_config **src_offset, - size_t count -) { - if (count == 0) { - *dst = NULL; - return 0; - } - - struct named_real_config *src = ADDR_OF(src_offset); - struct named_real_config *reals = - calloc(count, sizeof(struct named_real_config)); - if (reals == NULL) { - return -1; - } - - memcpy(reals, src, sizeof(struct named_real_config) * count); - *dst = reals; - return 0; -} - -/** - * Clone an allowed_src array from relative pointers to normal pointers. - * Each allowed_src contains an address and an array of port ranges. - */ -static int -clone_allowed_src_from_relative( - struct allowed_sources **dst, - struct allowed_sources **src_offset, - size_t count -) { - if (count == 0) { - *dst = NULL; - return 0; - } - - struct allowed_sources *src = ADDR_OF(src_offset); - struct allowed_sources *entries = - calloc(count, sizeof(struct allowed_sources)); - if (entries == NULL) { - return -1; - } - - // For each allowed_src entry, copy addr and clone port ranges - for (size_t i = 0; i < count; i++) { - // Copy networks - entries[i].nets_count = src[i].nets_count; - if (entries[i].nets_count > 0) { - entries[i].nets = - calloc(entries[i].nets_count, - sizeof(struct net6)); - struct net *nets = ADDR_OF(&src[i].nets); - for (size_t net_idx = 0; - net_idx < entries[i].nets_count; - ++net_idx) { - entries[i].nets[net_idx] = nets[net_idx]; - } - } - - entries[i].port_ranges_count = src[i].port_ranges_count; - - // Clone the ports_range array - if (src[i].port_ranges_count > 0) { - struct ports_range *src_ranges = - ADDR_OF(&src[i].port_ranges); - struct ports_range *ranges = - calloc(src[i].port_ranges_count, - sizeof(struct ports_range)); - if (ranges == NULL) { - // Cleanup previously allocated ranges - for (size_t j = 0; j < i; j++) { - free(entries[j].port_ranges); - } - free(entries); - return -1; - } - memcpy(ranges, - src_ranges, - sizeof(struct ports_range) * - src[i].port_ranges_count); - entries[i].port_ranges = ranges; - } else { - entries[i].port_ranges = NULL; - } - - // Clone tag string if present - if (src[i].tag != NULL) { - const char *src_tag = ADDR_OF(&src[i].tag); - entries[i].tag = strdup(src_tag); - } else { - entries[i].tag = NULL; - } - } - - *dst = entries; - return 0; -} - -/** - * Clone a net4_addr array from relative pointers to normal pointers. - */ -static int -clone_net4_addrs_from_relative( - struct net4_addr **dst, struct net4_addr **src_offset, size_t count -) { - if (count == 0) { - *dst = NULL; - return 0; - } - - struct net4_addr *src = ADDR_OF(src_offset); - struct net4_addr *addrs = calloc(count, sizeof(struct net4_addr)); - if (addrs == NULL) { - return -1; - } - - memcpy(addrs, src, sizeof(struct net4_addr) * count); - *dst = addrs; - return 0; -} - -/** - * Clone a net6_addr array from relative pointers to normal pointers. - */ -static int -clone_net6_addrs_from_relative( - struct net6_addr **dst, struct net6_addr **src_offset, size_t count -) { - if (count == 0) { - *dst = NULL; - return 0; - } - - struct net6_addr *src = ADDR_OF(src_offset); - struct net6_addr *addrs = calloc(count, sizeof(struct net6_addr)); - if (addrs == NULL) { - return -1; - } - - memcpy(addrs, src, sizeof(struct net6_addr) * count); - *dst = addrs; - return 0; -} - -/** - * Clone a vs_config from relative pointers to normal pointers. - */ -static int -clone_vs_config_from_relative(struct vs_config *dst, struct vs_config *src) { - // Copy scalar fields - dst->flags = src->flags; - dst->scheduler = src->scheduler; - dst->real_count = src->real_count; - dst->allowed_src_count = src->allowed_src_count; - dst->peers_v4_count = src->peers_v4_count; - dst->peers_v6_count = src->peers_v6_count; - - // Clone reals array - if (clone_reals_from_relative( - &dst->reals, &src->reals, src->real_count - ) != 0) { - return -1; - } - - // Clone allowed_src array - if (clone_allowed_src_from_relative( - &dst->allowed_src, &src->allowed_src, src->allowed_src_count - ) != 0) { - free(dst->reals); - return -1; - } - - // Clone peers_v4 array - if (clone_net4_addrs_from_relative( - &dst->peers_v4, &src->peers_v4, src->peers_v4_count - ) != 0) { - free(dst->reals); - free(dst->allowed_src); - return -1; - } - - // Clone peers_v6 array - if (clone_net6_addrs_from_relative( - &dst->peers_v6, &src->peers_v6, src->peers_v6_count - ) != 0) { - free(dst->reals); - free(dst->allowed_src); - free(dst->peers_v4); - return -1; - } - - return 0; -} - -/** - * Clone a named_vs_config array from relative pointers to normal pointers. - */ -static int -clone_vs_array_from_relative( - struct named_vs_config **dst, - struct named_vs_config **src_offset, - size_t count -) { - if (count == 0) { - *dst = NULL; - return 0; - } - - struct named_vs_config *src = ADDR_OF(src_offset); - struct named_vs_config *vs_array = - calloc(count, sizeof(struct named_vs_config)); - - for (size_t i = 0; i < count; i++) { - // Copy identifier (no pointers) - vs_array[i].identifier = src[i].identifier; - - // Clone config with nested pointers - clone_vs_config_from_relative( - &vs_array[i].config, &src[i].config - ); - } - - *dst = vs_array; - return 0; -} - -/** - * Clone packet_handler_config from relative pointers to normal pointers. - */ -int -packet_handler_config_from_relative( - struct packet_handler_config *dst, struct packet_handler_config *src -) { - // Copy scalar fields and embedded structs - dst->sessions_timeouts = src->sessions_timeouts; - dst->vs_count = src->vs_count; - dst->source_v4 = src->source_v4; - dst->source_v6 = src->source_v6; - dst->decap_v4_count = src->decap_v4_count; - dst->decap_v6_count = src->decap_v6_count; - - // Clone vs array - if (clone_vs_array_from_relative(&dst->vs, &src->vs, src->vs_count) != - 0) { - return -1; - } - - // Clone decap_v4 array - if (clone_net4_addrs_from_relative( - &dst->decap_v4, &src->decap_v4, src->decap_v4_count - ) != 0) { - // Cleanup vs array - if (dst->vs) { - for (size_t i = 0; i < dst->vs_count; i++) { - free(dst->vs[i].config.reals); - free(dst->vs[i].config.allowed_src); - free(dst->vs[i].config.peers_v4); - free(dst->vs[i].config.peers_v6); - } - free(dst->vs); - } - return -1; - } - - // Clone decap_v6 array - if (clone_net6_addrs_from_relative( - &dst->decap_v6, &src->decap_v6, src->decap_v6_count - ) != 0) { - // Cleanup - if (dst->vs) { - for (size_t i = 0; i < dst->vs_count; i++) { - free(dst->vs[i].config.reals); - free(dst->vs[i].config.allowed_src); - free(dst->vs[i].config.peers_v4); - free(dst->vs[i].config.peers_v6); - } - free(dst->vs); - } - free(dst->decap_v4); - return -1; - } - - return 0; -} - -/** - * Clone balancer_config from relative pointers to normal pointers. - */ -int -clone_balancer_config_from_relative( - struct balancer_config *dst, struct balancer_config *src -) { - // Clone handler config - if (packet_handler_config_from_relative(&dst->handler, &src->handler) != - 0) { - return -1; - } - - // Copy state config (no pointers) - dst->state = src->state; - - return 0; -} - -/** - * Free a vs_config with relative pointers (allocated in agent memory). - */ -static void -free_vs_config_with_relative_pointers( - struct vs_config *cfg, struct memory_context *mctx -) { - // Free reals array - if (cfg->real_count > 0 && cfg->reals != NULL) { - struct named_real_config *reals = ADDR_OF(&cfg->reals); - memory_bfree( - mctx, - reals, - sizeof(struct named_real_config) * cfg->real_count - ); - cfg->reals = NULL; - cfg->real_count = 0; - } - - // Free allowed_src array (with nested port ranges) - if (cfg->allowed_src_count > 0 && cfg->allowed_src != NULL) { - struct allowed_sources *entries = ADDR_OF(&cfg->allowed_src); - - // First, free each nested ports_range array and tag strings - for (size_t i = 0; i < cfg->allowed_src_count; i++) { - if (entries[i].port_ranges_count > 0 && - entries[i].port_ranges != NULL) { - struct ports_range *ranges = - ADDR_OF(&entries[i].port_ranges); - memory_bfree( - mctx, - ranges, - sizeof(struct ports_range) * - entries[i].port_ranges_count - ); - } - // Free tag string if present - if (entries[i].tag != NULL) { - const char *tag = ADDR_OF(&entries[i].tag); - size_t tag_len = strlen(tag) + 1; - memory_bfree(mctx, (void *)tag, tag_len); - } - } - - // Then free the allowed_src array itself - memory_bfree( - mctx, - entries, - sizeof(struct allowed_sources) * cfg->allowed_src_count - ); - cfg->allowed_src = NULL; - cfg->allowed_src_count = 0; - } - - // Free peers_v4 array - if (cfg->peers_v4_count > 0 && cfg->peers_v4 != NULL) { - struct net4_addr *addrs = ADDR_OF(&cfg->peers_v4); - memory_bfree( - mctx, - addrs, - sizeof(struct net4_addr) * cfg->peers_v4_count - ); - cfg->peers_v4 = NULL; - cfg->peers_v4_count = 0; - } - - // Free peers_v6 array - if (cfg->peers_v6_count > 0 && cfg->peers_v6 != NULL) { - struct net6_addr *addrs = ADDR_OF(&cfg->peers_v6); - memory_bfree( - mctx, - addrs, - sizeof(struct net6_addr) * cfg->peers_v6_count - ); - cfg->peers_v6 = NULL; - cfg->peers_v6_count = 0; - } -} diff --git a/modules/balancer/agent/go/agent.go b/modules/balancer/agent/go/agent.go deleted file mode 100644 index cb06d0673..000000000 --- a/modules/balancer/agent/go/agent.go +++ /dev/null @@ -1,296 +0,0 @@ -// Package balancer provides the load balancer agent implementation for YANET. -// This package manages balancer instances, virtual services, and real servers, -// coordinating between the control plane and data plane for packet distribution. -// -// The BalancerAgent manages multiple BalancerManager instances, each representing -// a separate load balancer configuration with its own virtual services and real servers. -package balancer - -import ( - "fmt" - "sync" - "time" - - "github.com/c2h5oh/datasize" - "github.com/yanet-platform/yanet2/common/commonpb" - "github.com/yanet-platform/yanet2/common/go/metrics" - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "go.uber.org/zap" -) - -type BalancerAgent struct { - handle *ffi.BalancerAgent - managers map[string]*BalancerManager - - mu sync.Mutex - - handlersMetrics handlersMetrics - - log *zap.SugaredLogger -} - -func NewBalancerAgent( - shm *yanet.SharedMemory, - memory datasize.ByteSize, - log *zap.SugaredLogger, -) (*BalancerAgent, error) { - if log == nil { - return nil, fmt.Errorf("logger cannot be nil") - } - - handle, err := ffi.NewBalancerAgent(shm, uint(memory.Bytes())) - if err != nil { - return nil, err - } - managerHandles := handle.Managers() - managers := make(map[string]*BalancerManager) - for _, managerHandle := range managerHandles { - manager := NewBalancerManager(&managerHandle, log) - managers[manager.Name()] = manager - } - return &BalancerAgent{ - handle: handle, - managers: managers, - mu: sync.Mutex{}, - log: log, - handlersMetrics: newHandlersMetrics(), - }, nil -} - -func (a *BalancerAgent) NewBalancerManager( - name string, - config *balancerpb.BalancerConfig, -) error { - a.mu.Lock() - defer a.mu.Unlock() - - tracker := newHandlerMetricTracker( - "create", - &a.handlersMetrics, - defaultLatencyBoundsMS, - metrics.Labels{"config": name}, - ) - defer tracker.Fix() - - a.log.Infow("creating new balancer manager", "name", name) - - if _, ok := a.managers[name]; ok { - a.log.Warnw("balancer manager already exists", "name", name) - return fmt.Errorf( - "balancer manager with name '%s' already exists", - name, - ) - } - - // Convert and validate config - managerConfig, err := ProtoToManagerConfig(config) - if err != nil { - a.log.Errorw("failed to convert config", "name", name, "error", err) - return fmt.Errorf("config is invalid: %w", err) - } - - managerHandle, err := a.handle.NewManager(name, managerConfig) - if err != nil { - a.log.Errorw( - "failed to create balancer manager", - "name", - name, - "error", - err, - ) - return fmt.Errorf("failed to create new balancer manager: %v", err) - } - - a.managers[name] = NewBalancerManager( - managerHandle, - a.log.With("balancer", name), - ) - a.log.Infow("balancer manager created successfully", "name", name) - return nil -} - -func (a *BalancerAgent) BalancerManager(name string) (*BalancerManager, error) { - a.mu.Lock() - defer a.mu.Unlock() - manager, ok := a.managers[name] - if !ok { - return nil, fmt.Errorf( - "balancer manager with name '%s' not found", - name, - ) - } - return manager, nil -} - -func (a *BalancerAgent) Managers() []string { - a.mu.Lock() - defer a.mu.Unlock() - res := []string{} - for name := range a.managers { - res = append(res, name) - } - return res -} - -func (a *BalancerAgent) Inspect() *balancerpb.AgentInspect { - a.mu.Lock() - defer a.mu.Unlock() - - ffiInspect := a.handle.Inspect() - return ConvertAgentInspectToProto(ffiInspect) -} - -func (a *BalancerAgent) Metrics() ([]*commonpb.Metric, error) { - dpConfig := a.handle.DPConfig() - positions := dpConfig.AllModulePositions("balancer") - - managers := make([]*BalancerManager, 0, len(positions)) - { - a.mu.Lock() - - for idx := range positions { - position := &positions[idx] - manager := a.managers[positions[idx].ModuleName] - if manager == nil { - a.log.Warnw( - "metrics: balancer manager not found", - "config", - position.ModuleName, - ) - } - managers = append(managers, manager) - } - - a.mu.Unlock() - } - - result := make([]*commonpb.Metric, 0, len(managers)*200) - - for idx := range positions { - manager := managers[idx] - if manager == nil { - continue - } - position := positions[idx] - ref := balancerpb.PacketHandlerRef{ - Device: &position.Device, - Pipeline: &position.Pipeline, - Function: &position.Function, - Chain: &position.Chain, - } - - metrics, err := manager.Metrics(time.Now(), &ref) - if err != nil { - a.log.Errorf("failed to get metrics", "balancer", manager.Name()) - } else { - result = append(result, metrics...) - } - } - - // append agent metrics - result = append(result, a.handlersMetrics.collect()...) - - return result, nil -} - -// StatsEntries enumerates dataplane balancer positions, -// optionally filters by balancer name and packet-handler ref fields, selects the -// corresponding manager for each position, and returns a list of (name, ref, -// stats) entries. -// -// Filtering rules: -// - if name is specified: only positions with ModuleName == name are included -// - for PacketHandlerRef: each specified field (device/pipeline/function/chain) is matched by strict equality. -func (a *BalancerAgent) StatsEntries( - name *string, - refFilter *balancerpb.PacketHandlerRef, -) ([]*balancerpb.StatsEntry, error) { - dpConfig := a.handle.DPConfig() - positions := dpConfig.AllModulePositions("balancer") - - // Snapshot managers under lock to avoid holding agent mutex during per-position stats reads. - managersByName := make(map[string]*BalancerManager, len(a.managers)) - { - a.mu.Lock() - for k, v := range a.managers { - managersByName[k] = v - } - a.mu.Unlock() - } - - matchesRef := func(posDevice, posPipeline, posFunction, posChain string) bool { - if refFilter == nil { - return true - } - if refFilter.Device != nil && *refFilter.Device != posDevice { - return false - } - if refFilter.Pipeline != nil && *refFilter.Pipeline != posPipeline { - return false - } - if refFilter.Function != nil && *refFilter.Function != posFunction { - return false - } - if refFilter.Chain != nil && *refFilter.Chain != posChain { - return false - } - return true - } - - entries := make([]*balancerpb.StatsEntry, 0) - - for idx := range positions { - position := &positions[idx] - - // Optional manager-name filter - if name != nil && position.ModuleName != *name { - continue - } - - // Optional packet-handler ref filter - if !matchesRef(position.Device, position.Pipeline, position.Function, position.Chain) { - continue - } - - manager := managersByName[position.ModuleName] - if manager == nil { - a.log.Warnw( - "stats: balancer manager not found", - "config", - position.ModuleName, - ) - continue - } - - ref := &balancerpb.PacketHandlerRef{ - Device: &position.Device, - Pipeline: &position.Pipeline, - Function: &position.Function, - Chain: &position.Chain, - } - - stats, err := manager.Stats(ref) - if err != nil { - a.log.Warnw( - "failed to get stats for position", - "config", position.ModuleName, - "device", position.Device, - "pipeline", position.Pipeline, - "function", position.Function, - "chain", position.Chain, - "error", err, - ) - continue - } - - entries = append(entries, &balancerpb.StatsEntry{ - Name: position.ModuleName, - Ref: ref, - Stats: stats, - }) - } - - return entries, nil -} diff --git a/modules/balancer/agent/go/agent_test.go b/modules/balancer/agent/go/agent_test.go deleted file mode 100644 index 69d2f3d59..000000000 --- a/modules/balancer/agent/go/agent_test.go +++ /dev/null @@ -1,603 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/durationpb" -) - -func TestBalancerAgent(t *testing.T) { - // Create mock Yanet instance - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 27, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: "eth0", - }, - }, - }) - require.NoError(t, err, "failed to initialize mock") - require.NotNil(t, m, "mock is nil") - defer m.Free() - - // Create logger for tests - log := zap.NewNop().Sugar() - - // Create balancer agent - agent, err := NewBalancerAgent(m.SharedMemory(), 32*datasize.MB, log) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - // Verify initial state - no managers - managers := agent.Managers() - assert.Empty(t, managers, "expected no managers initially") - - // Define first manager configuration with zero refresh period - firstManagerConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{ - FixMss: true, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.1.1.1"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.1.1.3").AsSlice()}, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.13.11.215").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: nil, - RefreshPeriod: durationpb.New(0), // Zero refresh period - Wlc: nil, - }, - } - - t.Run("NewBalancerManager_First", func(t *testing.T) { - err := agent.NewBalancerManager("balancer0", firstManagerConfig) - require.NoError(t, err, "failed to create first manager") - }) - - t.Run("BalancerManager_First", func(t *testing.T) { - manager, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get first manager") - require.NotNil(t, manager, "manager is nil") - assert.Equal(t, "balancer0", manager.Name()) - }) - - t.Run("Managers_One", func(t *testing.T) { - managers := agent.Managers() - assert.Len(t, managers, 1, "expected one manager") - assert.Contains(t, managers, "balancer0") - }) - - // Define second manager configuration with zero refresh period - secondManagerConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 15, - TcpSyn: 25, - TcpFin: 20, - Tcp: 69, - Udp: 15, - Default: 25, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.20.30.40").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{ - Ops: true, - Gre: true, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.20.30.40"). - AsSlice(), - }, - Port: 8443, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.17.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 150, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.2.2.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.2.2.3").AsSlice()}, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.20.30.40").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::a").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.15.12.216").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(2000); return &v }(), - SessionTableMaxLoadFactor: nil, - RefreshPeriod: durationpb.New(0), // Zero refresh period - Wlc: nil, - }, - } - - t.Run("NewBalancerManager_Second", func(t *testing.T) { - err := agent.NewBalancerManager("balancer1", secondManagerConfig) - require.NoError(t, err, "failed to create second manager") - }) - - t.Run("Managers_Two", func(t *testing.T) { - managers := agent.Managers() - assert.Len(t, managers, 2, "expected two managers") - assert.Contains(t, managers, "balancer0") - assert.Contains(t, managers, "balancer1") - }) - - t.Run("BalancerManager_Both", func(t *testing.T) { - // Test retrieving each manager - manager0, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get balancer0") - assert.Equal(t, "balancer0", manager0.Name()) - - manager1, err := agent.BalancerManager("balancer1") - require.NoError(t, err, "failed to get balancer1") - assert.Equal(t, "balancer1", manager1.Name()) - }) - - t.Run("NewBalancerManager_DuplicateName", func(t *testing.T) { - // Attempt to create manager with existing name - err := agent.NewBalancerManager("balancer0", firstManagerConfig) - require.Error( - t, - err, - "expected error when creating manager with duplicate name", - ) - assert.Contains( - t, - err.Error(), - "already exists", - "error should mention manager already exists", - ) - }) - - t.Run("BalancerManager_NonExistent", func(t *testing.T) { - // Attempt to retrieve non-existent manager - _, err := agent.BalancerManager("nonexistent") - require.Error( - t, - err, - "expected error when retrieving non-existent manager", - ) - assert.Contains( - t, - err.Error(), - "not found", - "error should mention manager not found", - ) - }) - - t.Run("UpdateManager_First", func(t *testing.T) { - // Update first manager configuration - newSessionTimeouts := balancerpb.SessionsTimeouts{ - TcpSynAck: 30, // Changed from 10 - TcpSyn: 40, // Changed from 20 - TcpFin: 35, // Changed from 15 - Tcp: 69, // Changed from 100 - Udp: 21, // Changed from 11 - Default: 39, // Changed from 19 - } - - update := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &newSessionTimeouts, - }, - } - - manager, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager") - - // Get config before update for comparison - configBeforeUpdate := manager.Config() - - _, err = manager.Update(update, m.CurrentTime()) - require.NoError(t, err, "failed to update manager") - - // Verify update by retrieving manager again - manager, err = agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager after update") - - newConfig := manager.Config() - - // Verify only the session timeouts were updated - assert.Equal( - t, - newSessionTimeouts.TcpSynAck, - newConfig.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - assert.Equal( - t, - newSessionTimeouts.TcpSyn, - newConfig.PacketHandler.SessionsTimeouts.TcpSyn, - ) - assert.Equal( - t, - newSessionTimeouts.TcpFin, - newConfig.PacketHandler.SessionsTimeouts.TcpFin, - ) - assert.Equal( - t, - newSessionTimeouts.Tcp, - newConfig.PacketHandler.SessionsTimeouts.Tcp, - ) - assert.Equal( - t, - newSessionTimeouts.Udp, - newConfig.PacketHandler.SessionsTimeouts.Udp, - ) - assert.Equal( - t, - newSessionTimeouts.Default, - newConfig.PacketHandler.SessionsTimeouts.Default, - ) - - // Verify other fields remain unchanged (compare with config before update) - assert.Equal( - t, - configBeforeUpdate.PacketHandler.Vs[0].AllowedSrcs[0].Nets[0].Addr.Bytes, - newConfig.PacketHandler.Vs[0].AllowedSrcs[0].Nets[0].Addr.Bytes, - ) - assert.Equal( - t, - configBeforeUpdate.State.SessionTableMaxLoadFactor, - newConfig.State.SessionTableMaxLoadFactor, - ) - }) - - t.Run("UpdateManager_ConsecutiveCalls", func(t *testing.T) { - // Verify consecutive calls return updated config - manager1, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager (first call)") - config1 := manager1.Config() - - manager2, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager (second call)") - config2 := manager2.Config() - - // Both calls should return the same updated values - assert.Equal( - t, - config1.PacketHandler.SessionsTimeouts.TcpSynAck, - config2.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - assert.Equal( - t, - config1.PacketHandler.SessionsTimeouts.Tcp, - config2.PacketHandler.SessionsTimeouts.Tcp, - ) - assert.Equal( - t, - uint32(30), - config2.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - assert.Equal(t, uint32(69), config2.PacketHandler.SessionsTimeouts.Tcp) - }) - - t.Run("UpdateManager_Second", func(t *testing.T) { - // Update second manager configuration - update source addresses - newSourceV4 := &balancerpb.Addr{ - Bytes: netip.MustParseAddr("30.30.40.50"). - AsSlice(), - // Changed from 20, 20, 30, 40 - } - newSourceV6 := &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::14"). - AsSlice(), - // Changed last byte from 10 to 20 - } - - update := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: newSourceV4, - SourceAddressV6: newSourceV6, - }, - } - - manager, err := agent.BalancerManager("balancer1") - require.NoError(t, err, "failed to get manager") - - // Get config before update for comparison - configBeforeUpdate := manager.Config() - - _, err = manager.Update(update, m.CurrentTime()) - require.NoError(t, err, "failed to update manager") - - // Verify update by retrieving manager again - manager, err = agent.BalancerManager("balancer1") - require.NoError(t, err, "failed to get manager after update") - - newConfig := manager.Config() - - // Verify only the source addresses were updated - assert.Equal( - t, - newSourceV4.Bytes, - newConfig.PacketHandler.SourceAddressV4.Bytes, - ) - assert.Equal( - t, - newSourceV6.Bytes, - newConfig.PacketHandler.SourceAddressV6.Bytes, - ) - - // Verify other fields remain unchanged (compare with config before update) - assert.Equal( - t, - configBeforeUpdate.PacketHandler.SessionsTimeouts, - newConfig.PacketHandler.SessionsTimeouts, - ) - // Verify decap addresses remain unchanged - assert.Equal( - t, - configBeforeUpdate.PacketHandler.DecapAddresses, - newConfig.PacketHandler.DecapAddresses, - ) - // Verify allowed sources remain unchanged (using new ACL structure) - assert.Equal( - t, - configBeforeUpdate.PacketHandler.Vs[0].AllowedSrcs[0].Nets[0].Addr.Bytes, - newConfig.PacketHandler.Vs[0].AllowedSrcs[0].Nets[0].Addr.Bytes, - ) - assert.Equal( - t, - configBeforeUpdate.PacketHandler.Vs[0].Id, - newConfig.PacketHandler.Vs[0].Id, - ) - assert.Equal( - t, - configBeforeUpdate.State.SessionTableMaxLoadFactor, - newConfig.State.SessionTableMaxLoadFactor, - ) - assert.Equal( - t, - configBeforeUpdate.State.SessionTableCapacity, - newConfig.State.SessionTableCapacity, - ) - }) - - t.Run("VerifyTagInConfig", func(t *testing.T) { - // Test that tag field is properly shown in config - manager, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager") - - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.NotNil( - t, - config.PacketHandler, - "packet handler should not be nil", - ) - require.Len( - t, - config.PacketHandler.Vs, - 1, - "should have 1 virtual service", - ) - require.Len( - t, - config.PacketHandler.Vs[0].AllowedSrcs, - 1, - "should have 1 allowed source", - ) - - // Verify tag is nil since it wasn't specified in the config - assert.Nil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "tag should be nil when not specified", - ) - }) - - t.Run("UpdateWithTag", func(t *testing.T) { - // Update config with a specific tag value - update := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{ - FixMss: true, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.1.1.1"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "54321"; return &s }(), // Set a specific tag - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.1.1.3").AsSlice()}, - }, - }, - }, - }, - } - - manager, err := agent.BalancerManager("balancer0") - require.NoError(t, err, "failed to get manager") - - _, err = manager.Update(update, m.CurrentTime()) - require.NoError(t, err, "failed to update manager") - - // Verify the tag is updated - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.NotNil( - t, - config.PacketHandler, - "packet handler should not be nil", - ) - require.Len( - t, - config.PacketHandler.Vs, - 1, - "should have 1 virtual service", - ) - require.Len( - t, - config.PacketHandler.Vs[0].AllowedSrcs, - 1, - "should have 1 allowed source", - ) - - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "tag should not be nil", - ) - assert.Equal( - t, - "54321", - *config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "tag should be updated to 54321", - ) - }) -} diff --git a/modules/balancer/agent/go/conversion.go b/modules/balancer/agent/go/conversion.go deleted file mode 100644 index 7e507fffe..000000000 --- a/modules/balancer/agent/go/conversion.go +++ /dev/null @@ -1,1631 +0,0 @@ -package balancer - -import ( - "fmt" - "net/netip" - - "github.com/yanet-platform/yanet2/common/go/xnetip" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" -) - -// Protobuf to FFI conversions - -func NewRealUpdateFromProto( - update *balancerpb.RealUpdate, -) (*ffi.RealUpdate, error) { - if update.RealId == nil || update.RealId.Vs == nil || - update.RealId.Real == nil { - return nil, fmt.Errorf("incomplete real identifier in update") - } - - vip, ok := netip.AddrFromSlice(update.RealId.Vs.Addr.Bytes) - if !ok { - return nil, fmt.Errorf("incorrect virtual service IP") - } - realIP, ok := netip.AddrFromSlice(update.RealId.Real.Ip.Bytes) - if !ok { - return nil, fmt.Errorf("incorrect real ip") - } - - proto := ffi.VsTransportProtoUDP - if update.RealId.Vs.Proto == balancerpb.TransportProto_TCP { - proto = ffi.VsTransportProtoTCP - } - - // Use the real port as specified (don't default to VS port) - realPort := uint16(update.RealId.Real.Port) - - result := &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: vip, - Port: uint16(update.RealId.Vs.Port), - TransportProto: proto, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: realIP, - Port: realPort, - }, - }, - Weight: ffi.DontUpdateRealWeight, - Enabled: ffi.DontUpdateRealEnabled, - } - - if update.Weight != nil { - result.Weight = uint16(*update.Weight) - } - - if update.Enable != nil { - if *update.Enable { - result.Enabled = 1 - } else { - result.Enabled = 0 - } - } - - return result, nil -} - -func ProtoToFFIConfig( - config *balancerpb.BalancerConfig, -) (ffi.BalancerConfig, error) { - if config.PacketHandler == nil { - return ffi.BalancerConfig{}, fmt.Errorf( - "packet_handler is required in CREATE mode", - ) - } - if config.State == nil { - return ffi.BalancerConfig{}, fmt.Errorf( - "state config is required in CREATE mode", - ) - } - if config.State.SessionTableCapacity == nil { - return ffi.BalancerConfig{}, fmt.Errorf( - "session_table_capacity is required in CREATE mode", - ) - } - if config.State.SessionTableMaxLoadFactor == nil { - return ffi.BalancerConfig{}, fmt.Errorf( - "session_table_max_load_factor is required in CREATE mode", - ) - } - if config.State.RefreshPeriod == nil { - return ffi.BalancerConfig{}, fmt.Errorf( - "refresh_period is required in CREATE mode", - ) - } - - handlerConfig, err := ProtoToHandlerConfig(config.PacketHandler) - if err != nil { - return ffi.BalancerConfig{}, err - } - - return ffi.BalancerConfig{ - State: ffi.StateConfig{ - TableCapacity: uint(*config.State.SessionTableCapacity), - }, - Handler: handlerConfig, - }, nil -} - -// ProtoToManagerConfig converts protobuf config to FFI manager config for CREATE mode -// Validates that all required fields are present -func ProtoToManagerConfig( - config *balancerpb.BalancerConfig, -) (*ffi.BalancerManagerConfig, error) { - // Validate required fields - if config == nil { - return nil, fmt.Errorf("config is nil") - } - if config.PacketHandler == nil { - return nil, fmt.Errorf("packet_handler is required") - } - if config.State == nil { - return nil, fmt.Errorf("state config is required") - } - if config.State.SessionTableCapacity == nil { - return nil, fmt.Errorf("session_table_capacity is required") - } - - // Check if any of refresh_period, max_load_factor, or wlc is present - hasRefreshPeriod := config.State.RefreshPeriod != nil - isRefreshPeriodValued := hasRefreshPeriod && - config.State.RefreshPeriod.AsDuration() != 0 - hasMaxLoadFactor := config.State.SessionTableMaxLoadFactor != nil - hasWlc := config.State.Wlc != nil - - // If any one is present, all three must be present - if isRefreshPeriodValued || hasMaxLoadFactor || hasWlc { - if !hasRefreshPeriod { - return nil, fmt.Errorf( - "refresh_period is required when max_load_factor or wlc is specified", - ) - } - if !hasMaxLoadFactor { - return nil, fmt.Errorf( - "session_table_max_load_factor is required when refresh_period or wlc is specified", - ) - } - if !hasWlc { - return nil, fmt.Errorf( - "wlc config is required when refresh_period or session_table_max_load_factor is specified", - ) - } - } - - // Convert handler config - handlerConfig, err := ProtoToHandlerConfig(config.PacketHandler) - if err != nil { - return nil, fmt.Errorf("failed to convert handler config: %w", err) - } - - // Create FFI balancer config - balancerConfig := ffi.BalancerConfig{ - State: ffi.StateConfig{ - TableCapacity: uint(*config.State.SessionTableCapacity), - }, - Handler: handlerConfig, - } - - // Create WLC configuration - wlcConfig, err := createWlcConfig(config) - if err != nil { - return nil, fmt.Errorf("failed to create WLC config: %w", err) - } - - // Create manager config - managerConfig := &ffi.BalancerManagerConfig{ - Balancer: balancerConfig, - Wlc: wlcConfig, - } - - // Set refresh period and max load factor if present - if hasRefreshPeriod { - managerConfig.RefreshPeriod = config.State.RefreshPeriod.AsDuration() - } else { - managerConfig.RefreshPeriod = 0 - } - if hasMaxLoadFactor { - managerConfig.MaxLoadFactor = *config.State.SessionTableMaxLoadFactor - } else { - managerConfig.MaxLoadFactor = 0.0 - } - - return managerConfig, nil -} - -func ProtoToHandlerConfig( - config *balancerpb.PacketHandlerConfig, -) (ffi.PacketHandlerConfig, error) { - // Validate required fields (non-optional in UPDATE mode) - if config.SessionsTimeouts == nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "sessions_timeouts is required", - ) - } - if config.SourceAddressV4 == nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "source_address_v4 is required", - ) - } - if config.SourceAddressV6 == nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "source_address_v6 is required", - ) - } - if config.DecapAddresses == nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "decap_addresses is required (can be empty list)", - ) - } - if config.Vs == nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "vs (virtual services) is required", - ) - } - - // Convert session timeouts - timeouts := ffi.SessionsTimeouts{ - TCPSynAck: config.SessionsTimeouts.TcpSynAck, - TCPSyn: config.SessionsTimeouts.TcpSyn, - TCPFin: config.SessionsTimeouts.TcpFin, - TCP: config.SessionsTimeouts.Tcp, - UDP: config.SessionsTimeouts.Udp, - Default: config.SessionsTimeouts.Default, - } - - // Convert source addresses - var sourceV4, sourceV6 netip.Addr - if len(config.SourceAddressV4.Bytes) == 4 { - sourceV4 = netip.AddrFrom4([4]byte(config.SourceAddressV4.Bytes)) - } else { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "source_address_v4 must be a valid IPv4 address", - ) - } - if len(config.SourceAddressV6.Bytes) == 16 { - sourceV6 = netip.AddrFrom16([16]byte(config.SourceAddressV6.Bytes)) - } else { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "source_address_v6 must be a valid IPv6 address", - ) - } - - // Convert decap addresses - decapV4 := make([]netip.Addr, 0) - decapV6 := make([]netip.Addr, 0) - for _, addrMsg := range config.DecapAddresses { - if addrMsg != nil { - if addr, ok := netip.AddrFromSlice(addrMsg.Bytes); ok { - if addr.Is4() { - decapV4 = append(decapV4, addr) - } else { - decapV6 = append(decapV6, addr) - } - } - } - } - - // Convert virtual services - virtualServices := make([]ffi.VsConfig, 0, len(config.Vs)) - for _, protoVs := range config.Vs { - vsConfig, err := protoToVsConfig(protoVs) - if err != nil { - return ffi.PacketHandlerConfig{}, fmt.Errorf( - "failed to convert VS: %w", - err, - ) - } - virtualServices = append(virtualServices, vsConfig) - } - - return ffi.PacketHandlerConfig{ - SessionsTimeouts: timeouts, - VirtualServices: virtualServices, - SourceV4: sourceV4, - SourceV6: sourceV6, - DecapV4: decapV4, - DecapV6: decapV6, - }, nil -} - -func protoToVsConfig( - protoVs *balancerpb.VirtualService, -) (ffi.VsConfig, error) { - if protoVs.Id == nil || protoVs.Id.Addr == nil { - return ffi.VsConfig{}, fmt.Errorf("invalid VS identifier") - } - - // Convert VS address - vsAddr, ok := netip.AddrFromSlice(protoVs.Id.Addr.Bytes) - if !ok { - return ffi.VsConfig{}, fmt.Errorf("invalid VS address") - } - - // Convert proto - var proto ffi.VsTransportProto - if protoVs.Id.Proto == balancerpb.TransportProto_TCP { - proto = ffi.VsTransportProtoTCP - } else { - proto = ffi.VsTransportProtoUDP - } - - // Convert flags - flags := ffi.VsFlags{} - if protoVs.Flags != nil { - flags.GRE = protoVs.Flags.Gre - flags.OPS = protoVs.Flags.Ops - flags.PureL3 = protoVs.Flags.PureL3 - flags.FixMSS = protoVs.Flags.FixMss - } - - // Convert scheduler - var scheduler ffi.VsScheduler - if protoVs.Scheduler == balancerpb.VsScheduler_ROUND_ROBIN { - scheduler = ffi.VsSchedulerRoundRobin - } else { - scheduler = ffi.VsSchedulerSourceHash - } - - // Convert reals - reals := make([]ffi.RealConfig, 0, len(protoVs.Reals)) - for _, protoReal := range protoVs.Reals { - realConfig, err := protoToRealConfig(protoReal) - if err != nil { - return ffi.VsConfig{}, fmt.Errorf( - "failed to convert real: %w", - err, - ) - } - reals = append(reals, realConfig) - } - - // Convert allowed sources - allowedSrc := make([]ffi.AllowedSources, 0, len(protoVs.AllowedSrcs)) - for i, protoAllowedSrc := range protoVs.AllowedSrcs { - if protoAllowedSrc == nil { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d] is nil", - i, - ) - } - if protoAllowedSrc.Nets == nil { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d].net is nil", - i, - ) - } - - nets := make([]xnetip.NetWithMask, len(protoAllowedSrc.Nets)) - - for j, protoNet := range protoAllowedSrc.Nets { - if protoNet == nil { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d].net[%d] is nil", - i, j, - ) - } - - // Convert network address - addr, ok := netip.AddrFromSlice(protoNet.Addr.Bytes) - if !ok { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d]: invalid network address", - i, - ) - } - - // Validate IP version matches VS address - if vsAddr.Is4() != addr.Is4() { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d].net[%d]: IP version mismatch - VS is %s but allowed_src network is %s", - i, - j, - func() string { - if vsAddr.Is4() { - return "IPv4" - } - return "IPv6" - }(), - func() string { - if addr.Is4() { - return "IPv4" - } - return "IPv6" - }(), - ) - } - - // Convert mask bytes - var maskBytes []byte - if protoNet.Mask != nil { - maskBytes = protoNet.Mask.Bytes - } - - // Validate mask length - expectedLen := 4 - if addr.Is6() { - expectedLen = 16 - } - if len(maskBytes) != expectedLen { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d]: invalid mask length: got %d, expected %d", - i, len(maskBytes), expectedLen, - ) - } - - // Create NetWithMask - net, err := xnetip.NewNetWithMask(addr, maskBytes) - if err != nil { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d]: failed to create network: %w", - i, err, - ) - } - - nets[j] = net - } - - // Convert port ranges - portRanges := make([]ffi.PortRange, 0, len(protoAllowedSrc.Ports)) - for j, protoPortRange := range protoAllowedSrc.Ports { - if protoPortRange == nil { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d].ports[%d] is nil", - i, j, - ) - } - if protoPortRange.From > protoPortRange.To { - return ffi.VsConfig{}, fmt.Errorf( - "allowed_src[%d].ports[%d]: invalid range: from=%d > to=%d", - i, j, protoPortRange.From, protoPortRange.To, - ) - } - portRanges = append(portRanges, ffi.PortRange{ - From: uint16(protoPortRange.From), - To: uint16(protoPortRange.To), - }) - } - - // Convert tag from protobuf *string to Go string - var tag string - if protoAllowedSrc.Tag != nil { - tag = *protoAllowedSrc.Tag - } - - allowedSrc = append(allowedSrc, ffi.AllowedSources{ - Nets: nets, - PortRanges: portRanges, - Tag: tag, - }) - } - - // Convert peers - var peersV4, peersV6 []netip.Addr - for _, peerMsg := range protoVs.Peers { - if peerMsg != nil { - if peer, ok := netip.AddrFromSlice(peerMsg.Bytes); ok { - if peer.Is4() { - peersV4 = append(peersV4, peer) - } else { - peersV6 = append(peersV6, peer) - } - } - } - } - - return ffi.VsConfig{ - Identifier: ffi.VsIdentifier{ - Addr: vsAddr, - Port: uint16(protoVs.Id.Port), - TransportProto: proto, - }, - Flags: flags, - Scheduler: scheduler, - Reals: reals, - AllowedSources: allowedSrc, - PeersV4: peersV4, - PeersV6: peersV6, - }, nil -} - -func protoToRealConfig( - protoReal *balancerpb.Real, -) (ffi.RealConfig, error) { - if protoReal.Id == nil || protoReal.Id.Ip == nil { - return ffi.RealConfig{}, fmt.Errorf("invalid real identifier") - } - - realAddr, ok := netip.AddrFromSlice(protoReal.Id.Ip.Bytes) - if !ok { - return ffi.RealConfig{}, fmt.Errorf("invalid real address") - } - - // Validate weight - if protoReal.Weight == 0 { - return ffi.RealConfig{}, fmt.Errorf( - "invalid real weight: weight must be at least 1", - ) - } - - var srcNet xnetip.NetWithMask - if protoReal.SrcAddr != nil && protoReal.SrcMask != nil { - srcAddr, ok := netip.AddrFromSlice(protoReal.SrcAddr.Bytes) - if !ok { - return ffi.RealConfig{}, fmt.Errorf("invalid source address") - } - - // Accept arbitrary masks (no validation for contiguous bits) - maskBytes := protoReal.SrcMask.Bytes - - // Validate mask length matches address type - expectedLen := 4 - if srcAddr.Is6() { - expectedLen = 16 - } - if len(maskBytes) != expectedLen { - return ffi.RealConfig{}, fmt.Errorf( - "invalid source mask length: got %d, expected %d", - len(maskBytes), expectedLen, - ) - } - - var err error - srcNet, err = xnetip.NewNetWithMask(srcAddr, maskBytes) - if err != nil { - return ffi.RealConfig{}, fmt.Errorf( - "invalid source network: %w", - err, - ) - } - } - - return ffi.RealConfig{ - Identifier: ffi.RelativeRealIdentifier{ - Addr: realAddr, - Port: uint16(protoReal.Id.Port), - }, - Src: srcNet, - Weight: uint16(protoReal.Weight), - }, nil -} - -// createWlcConfig creates WLC configuration from protobuf config -// Returns error if WLC is enabled but required fields are missing -func createWlcConfig( - config *balancerpb.BalancerConfig, -) (ffi.BalancerManagerWlcConfig, error) { - wlc := ffi.BalancerManagerWlcConfig{ - Power: 0, - MaxRealWeight: 0, - Vs: []uint32{}, - } - - // Check if any VS has WLC enabled - hasWlcEnabled := false - if config.PacketHandler != nil { - for _, vs := range config.PacketHandler.Vs { - if vs.Flags != nil && vs.Flags.Wlc { - hasWlcEnabled = true - break - } - } - } - - // If WLC is enabled, validate configuration - if hasWlcEnabled { - if config.State == nil || config.State.Wlc == nil { - return wlc, fmt.Errorf( - "wlc config is required when WLC flag is enabled on virtual services", - ) - } - - if config.State.Wlc.Power == nil { - return wlc, fmt.Errorf("wlc.power is required when WLC is enabled") - } - if config.State.Wlc.MaxWeight == nil { - return wlc, fmt.Errorf( - "wlc.max_weight is required when WLC is enabled", - ) - } - - wlc.Power = uint(*config.State.Wlc.Power) - wlc.MaxRealWeight = uint(*config.State.Wlc.MaxWeight) - - // Collect VS indices that have WLC enabled - for i, vs := range config.PacketHandler.Vs { - if vs.Flags != nil && vs.Flags.Wlc { - wlc.Vs = append(wlc.Vs, uint32(i)) - } - } - } - - // Always apply WLC config if provided (even if no VS has WLC enabled) - if config.State != nil && config.State.Wlc != nil { - if config.State.Wlc.Power != nil { - wlc.Power = uint(*config.State.Wlc.Power) - } else { - wlc.Power = 0 - } - if config.State.Wlc.MaxWeight != nil { - wlc.MaxRealWeight = uint(*config.State.Wlc.MaxWeight) - } else { - wlc.MaxRealWeight = 0 - } - } - - return wlc, nil -} - -// mergeBalancerConfig merges new config with current config for UPDATE mode -// Returns merged config with all required fields filled recursively -func mergeBalancerConfig( - newConfig *balancerpb.BalancerConfig, - currentConfig *ffi.BalancerManagerConfig, -) (*balancerpb.BalancerConfig, error) { - merged := &balancerpb.BalancerConfig{} - - // Recursively merge State first to get WLC config - merged.State = mergeStateConfig( - newConfig.State, - currentConfig, - ) - - merged.PacketHandler = mergePacketHandlerConfig( - newConfig.PacketHandler, - ¤tConfig.Balancer.Handler, - ¤tConfig.Wlc, - ) - - return merged, nil -} - -// mergePacketHandlerConfig recursively merges packet handler fields -// If newHandler is nil, returns current handler converted to proto with WLC info -// Otherwise, merges each field individually, using current values for nil fields -func mergePacketHandlerConfig( - newHandler *balancerpb.PacketHandlerConfig, - currentHandler *ffi.PacketHandlerConfig, - wlcConfig *ffi.BalancerManagerWlcConfig, -) *balancerpb.PacketHandlerConfig { - if newHandler == nil { - return convertPacketHandlerToProtoWithWlc(currentHandler, wlcConfig) - } - - merged := &balancerpb.PacketHandlerConfig{} - - // Merge sessions_timeouts - if newHandler.SessionsTimeouts != nil { - merged.SessionsTimeouts = newHandler.SessionsTimeouts - } else { - merged.SessionsTimeouts = &balancerpb.SessionsTimeouts{ - TcpSynAck: currentHandler.SessionsTimeouts.TCPSynAck, - TcpSyn: currentHandler.SessionsTimeouts.TCPSyn, - TcpFin: currentHandler.SessionsTimeouts.TCPFin, - Tcp: currentHandler.SessionsTimeouts.TCP, - Udp: currentHandler.SessionsTimeouts.UDP, - Default: currentHandler.SessionsTimeouts.Default, - } - } - - // Merge source_address_v4 - if newHandler.SourceAddressV4 != nil { - merged.SourceAddressV4 = newHandler.SourceAddressV4 - } else { - merged.SourceAddressV4 = &balancerpb.Addr{ - Bytes: currentHandler.SourceV4.AsSlice(), - } - } - - // Merge source_address_v6 - if newHandler.SourceAddressV6 != nil { - merged.SourceAddressV6 = newHandler.SourceAddressV6 - } else { - merged.SourceAddressV6 = &balancerpb.Addr{ - Bytes: currentHandler.SourceV6.AsSlice(), - } - } - - // Merge decap_addresses (array replacement) - if newHandler.DecapAddresses != nil { - merged.DecapAddresses = newHandler.DecapAddresses - } else { - // Convert current decap addresses - decapAddrs := make( - []*balancerpb.Addr, - 0, - len(currentHandler.DecapV4)+len(currentHandler.DecapV6), - ) - for _, addr := range currentHandler.DecapV4 { - decapAddrs = append(decapAddrs, &balancerpb.Addr{Bytes: addr.AsSlice()}) - } - for _, addr := range currentHandler.DecapV6 { - decapAddrs = append(decapAddrs, &balancerpb.Addr{Bytes: addr.AsSlice()}) - } - merged.DecapAddresses = decapAddrs - } - - // Merge vs (virtual services - array replacement) - if newHandler.Vs != nil { - merged.Vs = newHandler.Vs - } else { - // Convert current virtual services with WLC info - wlcEnabledVs := make(map[uint32]bool) - for _, vsIdx := range wlcConfig.Vs { - wlcEnabledVs[vsIdx] = true - } - - vs := make([]*balancerpb.VirtualService, 0, len(currentHandler.VirtualServices)) - for i := range currentHandler.VirtualServices { - wlcEnabled := wlcEnabledVs[uint32(i)] - vs = append(vs, convertVsConfigToProtoWithWlc(¤tHandler.VirtualServices[i], wlcEnabled)) - } - merged.Vs = vs - } - - return merged -} - -// mergeStateConfig recursively merges state configuration -// If newState is nil, returns current state converted to proto -// Otherwise, merges each field individually, using current values for nil fields -func mergeStateConfig( - newState *balancerpb.StateConfig, - currentConfig *ffi.BalancerManagerConfig, -) *balancerpb.StateConfig { - if newState == nil { - // Return entire current state - capacity := uint64(currentConfig.Balancer.State.TableCapacity) - return &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: ¤tConfig.MaxLoadFactor, - RefreshPeriod: durationpb.New( - currentConfig.RefreshPeriod, - ), - Wlc: convertWlcConfigToProto( - ¤tConfig.Wlc, - ), - } - } - - merged := &balancerpb.StateConfig{} - - // Merge session_table_capacity - if newState.SessionTableCapacity != nil { - merged.SessionTableCapacity = newState.SessionTableCapacity - } else { - capacity := uint64(currentConfig.Balancer.State.TableCapacity) - merged.SessionTableCapacity = &capacity - } - - // Merge session_table_max_load_factor - if newState.SessionTableMaxLoadFactor != nil { - merged.SessionTableMaxLoadFactor = newState.SessionTableMaxLoadFactor - } else { - merged.SessionTableMaxLoadFactor = ¤tConfig.MaxLoadFactor - } - - // Merge refresh_period - if newState.RefreshPeriod != nil { - merged.RefreshPeriod = newState.RefreshPeriod - } else { - merged.RefreshPeriod = durationpb.New(currentConfig.RefreshPeriod) - } - - // Recursively merge WLC - merged.Wlc = mergeWlcConfig(newState.Wlc, ¤tConfig.Wlc) - - return merged -} - -// mergeWlcConfig recursively merges WLC configuration -// If newWlc is nil, returns current WLC converted to proto -// Otherwise, merges each field individually, using current values for nil fields -func mergeWlcConfig( - newWlc *balancerpb.WlcConfig, - currentWlc *ffi.BalancerManagerWlcConfig, -) *balancerpb.WlcConfig { - if newWlc == nil { - return convertWlcConfigToProto(currentWlc) - } - - merged := &balancerpb.WlcConfig{} - - // Merge power - if newWlc.Power != nil { - merged.Power = newWlc.Power - } else { - if currentWlc.Power != 0 { - power := uint64(currentWlc.Power) - merged.Power = &power - } - } - - // Merge max_weight - if newWlc.MaxWeight != nil { - merged.MaxWeight = newWlc.MaxWeight - } else { - if currentWlc.MaxRealWeight != 0 { - maxWeight := uint32(currentWlc.MaxRealWeight) - merged.MaxWeight = &maxWeight - } - } - - return merged -} - -// FFI to Protobuf conversions - -func ConvertFFIProtoToProto( - proto ffi.VsTransportProto, -) balancerpb.TransportProto { - if proto == ffi.VsTransportProtoTCP { - return balancerpb.TransportProto_TCP - } - return balancerpb.TransportProto_UDP -} - -func ConvertBalancerInfoToProto( - info *ffi.BalancerInfo, -) *balancerpb.BalancerInfo { - vsInfo := make([]*balancerpb.VsInfo, 0, len(info.Vs)) - for i := range info.Vs { - vsInfo = append(vsInfo, ConvertVsInfoToProto(&info.Vs[i])) - } - - return &balancerpb.BalancerInfo{ - ActiveSessions: info.ActiveSessions, - LastPacketTimestamp: timestamppb.New(info.LastPacketTimestamp), - Vs: vsInfo, - } -} - -func ConvertVsInfoToProto(info *ffi.VsInfo) *balancerpb.VsInfo { - reals := make([]*balancerpb.RealInfo, 0, len(info.Reals)) - for i := range info.Reals { - reals = append(reals, ConvertRealInfoToProto(&info.Reals[i])) - } - - return &balancerpb.VsInfo{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: info.Identifier.Addr.AsSlice(), - }, - Port: uint32(info.Identifier.Port), - Proto: ConvertFFIProtoToProto(info.Identifier.TransportProto), - }, - ActiveSessions: info.ActiveSessions, - LastPacketTimestamp: timestamppb.New(info.LastPacketTimestamp), - Reals: reals, - } -} - -func ConvertRealInfoToProto(info *ffi.RealInfo) *balancerpb.RealInfo { - return &balancerpb.RealInfo{ - Id: &balancerpb.RealIdentifier{ - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: info.Dst.AsSlice(), - }, - }, - }, - ActiveSessions: info.ActiveSessions, - LastPacketTimestamp: timestamppb.New(info.LastPacketTimestamp), - } -} - -func ConvertSessionInfoToProto( - identifier *ffi.SessionIdentifier, - info *ffi.SessionInfo, -) *balancerpb.SessionInfo { - return &balancerpb.SessionInfo{ - LastPacketTimestamp: timestamppb.New(info.LastPacketTimestamp), - CreateTimestamp: timestamppb.New(info.CreateTimestamp), - Timeout: durationpb.New(info.Timeout), - ClientAddr: &balancerpb.Addr{ - Bytes: identifier.ClientIP.AsSlice(), - }, - ClientPort: uint32(identifier.ClientPort), - VsId: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: identifier.Real.VsIdentifier.Addr.AsSlice(), - }, - Port: uint32(identifier.Real.VsIdentifier.Port), - Proto: ConvertFFIProtoToProto( - identifier.Real.VsIdentifier.TransportProto, - ), - }, - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: identifier.Real.VsIdentifier.Addr.AsSlice(), - }, - Port: uint32(identifier.Real.VsIdentifier.Port), - Proto: ConvertFFIProtoToProto( - identifier.Real.VsIdentifier.TransportProto, - ), - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: identifier.Real.Relative.Addr.AsSlice(), - }, - Port: uint32(identifier.Real.Relative.Port), - }, - }, - } -} - -func ConvertBalancerStatsToProto( - stats *ffi.BalancerStats, -) *balancerpb.BalancerStats { - vsStats := make([]*balancerpb.NamedVsStats, 0, len(stats.Vs)) - for i := range stats.Vs { - // Convert real stats for this VS - realStats := make( - []*balancerpb.NamedRealStats, - 0, - len(stats.Vs[i].Reals), - ) - for j := range stats.Vs[i].Reals { - realStats = append(realStats, &balancerpb.NamedRealStats{ - Real: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: stats.Vs[i].Identifier.Addr.AsSlice(), - }, - Port: uint32(stats.Vs[i].Identifier.Port), - Proto: ConvertFFIProtoToProto( - stats.Vs[i].Identifier.TransportProto, - ), - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: stats.Vs[i].Reals[j].Dst.AsSlice(), - }, - }, - }, - Stats: ConvertRealStatsToProto(&stats.Vs[i].Reals[j].Stats), - }) - } - - // Convert allowed sources stats for this VS - allowedSourcesStats := make( - []*balancerpb.AllowedSourcesStats, - 0, - len(stats.Vs[i].AllowedSources), - ) - for j := range stats.Vs[i].AllowedSources { - allowedSourcesStats = append( - allowedSourcesStats, - &balancerpb.AllowedSourcesStats{ - Tag: stats.Vs[i].AllowedSources[j].Tag, - Passes: stats.Vs[i].AllowedSources[j].Passes, - }, - ) - } - - vsStats = append(vsStats, &balancerpb.NamedVsStats{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: stats.Vs[i].Identifier.Addr.AsSlice(), - }, - Port: uint32(stats.Vs[i].Identifier.Port), - Proto: ConvertFFIProtoToProto( - stats.Vs[i].Identifier.TransportProto, - ), - }, - Stats: ConvertVsStatsToProto(&stats.Vs[i].Stats), - Reals: realStats, - AllowedSources: allowedSourcesStats, - }) - } - - return &balancerpb.BalancerStats{ - L4: ConvertL4StatsToProto(&stats.L4), - Icmpv4: ConvertIcmpStatsToProto(&stats.IcmpIpv4), - Icmpv6: ConvertIcmpStatsToProto(&stats.IcmpIpv6), - Common: ConvertCommonStatsToProto(&stats.Common), - Vs: vsStats, - } -} - -func ConvertL4StatsToProto(stats *ffi.L4Stats) *balancerpb.L4Stats { - return &balancerpb.L4Stats{ - IncomingPackets: stats.IncomingPackets, - SelectVsFailed: stats.SelectVsFailed, - InvalidPackets: stats.InvalidPackets, - SelectRealFailed: stats.SelectRealFailed, - OutgoingPackets: stats.OutgoingPackets, - } -} - -func ConvertIcmpStatsToProto( - stats *ffi.IcmpStats, -) *balancerpb.IcmpStats { - return &balancerpb.IcmpStats{ - IncomingPackets: stats.IncomingPackets, - SrcNotAllowed: stats.SrcNotAllowed, - EchoResponses: stats.EchoResponses, - PayloadTooShortIp: stats.PayloadTooShortIP, - UnmatchingSrcFromOriginal: stats.UnmatchingSrcFromOriginal, - PayloadTooShortPort: stats.PayloadTooShortPort, - UnexpectedTransport: stats.UnexpectedTransport, - UnrecognizedVs: stats.UnrecognizedVs, - ForwardedPackets: stats.ForwardedPackets, - BroadcastedPackets: stats.BroadcastedPackets, - PacketClonesSent: stats.PacketClonesSent, - PacketClonesReceived: stats.PacketClonesReceived, - PacketCloneFailures: stats.PacketCloneFailures, - } -} - -func ConvertCommonStatsToProto( - stats *ffi.CommonStats, -) *balancerpb.CommonStats { - return &balancerpb.CommonStats{ - IncomingPackets: stats.IncomingPackets, - IncomingBytes: stats.IncomingBytes, - UnexpectedNetworkProto: stats.UnexpectedNetworkProto, - DecapSuccessful: stats.DecapSuccessful, - DecapFailed: stats.DecapFailed, - OutgoingPackets: stats.OutgoingPackets, - OutgoingBytes: stats.OutgoingBytes, - } -} - -func ConvertVsStatsToProto(stats *ffi.VsStats) *balancerpb.VsStats { - return &balancerpb.VsStats{ - IncomingPackets: stats.IncomingPackets, - IncomingBytes: stats.IncomingBytes, - PacketSrcNotAllowed: stats.PacketSrcNotAllowed, - NoReals: stats.NoReals, - OpsPackets: stats.OpsPackets, - SessionTableOverflow: stats.SessionTableOverflow, - EchoIcmpPackets: stats.EchoIcmpPackets, - ErrorIcmpPackets: stats.ErrorIcmpPackets, - RealIsDisabled: stats.RealIsDisabled, - RealIsRemoved: stats.RealIsRemoved, - NotRescheduledPackets: stats.NotRescheduledPackets, - BroadcastedIcmpPackets: stats.BroadcastedIcmpPackets, - CreatedSessions: stats.CreatedSessions, - OutgoingPackets: stats.OutgoingPackets, - OutgoingBytes: stats.OutgoingBytes, - } -} - -func ConvertRealStatsToProto( - stats *ffi.RealStats, -) *balancerpb.RealStats { - return &balancerpb.RealStats{ - PacketsRealDisabled: stats.PacketsRealDisabled, - OpsPackets: stats.OpsPackets, - ErrorIcmpPackets: stats.ErrorIcmpPackets, - CreatedSessions: stats.CreatedSessions, - Packets: stats.Packets, - Bytes: stats.Bytes, - } -} - -// ConvertGraphToProtoWithConfig converts FFI graph to protobuf with proper weight mapping. -// Weight in the result comes from config (original configured weight). -// EffectiveWeight in the result comes from graph (current effective weight after WLC adjustments). -func ConvertGraphToProtoWithConfig( - graph *ffi.BalancerGraph, - config *ffi.BalancerManagerConfig, -) *balancerpb.Graph { - if graph == nil { - return &balancerpb.Graph{} - } - - // Build a lookup map for config weights: VS identifier -> Real identifier -> weight - configWeights := buildConfigWeightsMap(config) - - vsServices := make([]*balancerpb.GraphVs, 0, len(graph.VirtualServices)) - for i := range graph.VirtualServices { - vsServices = append( - vsServices, - convertGraphVsToProtoWithConfig( - &graph.VirtualServices[i], - configWeights, - ), - ) - } - - return &balancerpb.Graph{ - VirtualServices: vsServices, - } -} - -// vsRealKey creates a unique key for a real within a VS context -type vsRealKey struct { - vsAddr string - vsPort uint16 - vsProto ffi.VsTransportProto - realAddr string - realPort uint16 -} - -// buildConfigWeightsMap builds a map from VS+Real identifiers to config weights -func buildConfigWeightsMap( - config *ffi.BalancerManagerConfig, -) map[vsRealKey]uint16 { - weights := make(map[vsRealKey]uint16) - if config == nil { - return weights - } - - for _, vs := range config.Balancer.Handler.VirtualServices { - for _, real := range vs.Reals { - key := vsRealKey{ - vsAddr: vs.Identifier.Addr.String(), - vsPort: vs.Identifier.Port, - vsProto: vs.Identifier.TransportProto, - realAddr: real.Identifier.Addr.String(), - realPort: real.Identifier.Port, - } - weights[key] = real.Weight - } - } - - return weights -} - -func convertGraphVsToProtoWithConfig( - vs *ffi.GraphVs, - configWeights map[vsRealKey]uint16, -) *balancerpb.GraphVs { - reals := make([]*balancerpb.GraphReal, 0, len(vs.Reals)) - for i := range vs.Reals { - // Look up config weight for this real - key := vsRealKey{ - vsAddr: vs.Identifier.Addr.String(), - vsPort: vs.Identifier.Port, - vsProto: vs.Identifier.TransportProto, - realAddr: vs.Reals[i].Identifier.Addr.String(), - realPort: vs.Reals[i].Identifier.Port, - } - - configWeight := uint16(0) - if w, ok := configWeights[key]; ok { - configWeight = w - } - - reals = append(reals, &balancerpb.GraphReal{ - Identifier: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: vs.Reals[i].Identifier.Addr.AsSlice(), - }, - Port: uint32(vs.Reals[i].Identifier.Port), - }, - // Weight = config weight (original configured weight) - Weight: uint32(configWeight), - // EffectiveWeight = graph weight (current effective weight after WLC) - EffectiveWeight: uint32(vs.Reals[i].Weight), - Enabled: vs.Reals[i].Enabled, - }) - } - - return &balancerpb.GraphVs{ - Identifier: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vs.Identifier.Addr.AsSlice(), - }, - Port: uint32(vs.Identifier.Port), - Proto: ConvertFFIProtoToProto(vs.Identifier.TransportProto), - }, - Reals: reals, - } -} - -// ConvertBalancerConfigToProto converts FFI manager config to protobuf -func ConvertBalancerConfigToProto( - config *ffi.BalancerManagerConfig, -) *balancerpb.BalancerConfig { - if config == nil { - return &balancerpb.BalancerConfig{} - } - - // Convert packet handler with WLC config - handler := convertPacketHandlerToProtoWithWlc( - &config.Balancer.Handler, - &config.Wlc, - ) - - // Convert state config - capacity := uint64(config.Balancer.State.TableCapacity) - loadFactor := config.MaxLoadFactor - refreshPeriod := durationpb.New(config.RefreshPeriod) - - return &balancerpb.BalancerConfig{ - PacketHandler: handler, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: &loadFactor, - RefreshPeriod: refreshPeriod, - Wlc: convertWlcConfigToProto(&config.Wlc), - }, - } -} - -func convertPacketHandlerToProtoWithWlc( - handler *ffi.PacketHandlerConfig, - wlcConfig *ffi.BalancerManagerWlcConfig, -) *balancerpb.PacketHandlerConfig { - // Build a set of VS indices that have WLC enabled - wlcEnabledVs := make(map[uint32]bool) - if wlcConfig != nil { - for _, vsIdx := range wlcConfig.Vs { - wlcEnabledVs[vsIdx] = true - } - } - - // Convert virtual services - vs := make([]*balancerpb.VirtualService, 0, len(handler.VirtualServices)) - for i := range handler.VirtualServices { - wlcEnabled := wlcEnabledVs[uint32(i)] - vs = append( - vs, - convertVsConfigToProtoWithWlc( - &handler.VirtualServices[i], - wlcEnabled, - ), - ) - } - - // Convert decap addresses - decapAddrs := make( - []*balancerpb.Addr, - 0, - len(handler.DecapV4)+len(handler.DecapV6), - ) - for _, addr := range handler.DecapV4 { - decapAddrs = append(decapAddrs, &balancerpb.Addr{Bytes: addr.AsSlice()}) - } - for _, addr := range handler.DecapV6 { - decapAddrs = append(decapAddrs, &balancerpb.Addr{Bytes: addr.AsSlice()}) - } - - return &balancerpb.PacketHandlerConfig{ - Vs: vs, - SourceAddressV4: &balancerpb.Addr{Bytes: handler.SourceV4.AsSlice()}, - SourceAddressV6: &balancerpb.Addr{Bytes: handler.SourceV6.AsSlice()}, - DecapAddresses: decapAddrs, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: handler.SessionsTimeouts.TCPSynAck, - TcpSyn: handler.SessionsTimeouts.TCPSyn, - TcpFin: handler.SessionsTimeouts.TCPFin, - Tcp: handler.SessionsTimeouts.TCP, - Udp: handler.SessionsTimeouts.UDP, - Default: handler.SessionsTimeouts.Default, - }, - } -} - -func convertVsConfigToProtoWithWlc( - vs *ffi.VsConfig, - wlcEnabled bool, -) *balancerpb.VirtualService { - // Convert reals - reals := make([]*balancerpb.Real, 0, len(vs.Reals)) - for i := range vs.Reals { - reals = append(reals, convertRealConfigToProto(&vs.Reals[i])) - } - - // Convert allowed sources - allowedSrcs := make([]*balancerpb.AllowedSources, 0, len(vs.AllowedSources)) - for _, allowedSrc := range vs.AllowedSources { - // Convert networks - nets := make([]*balancerpb.Net, 0, len(allowedSrc.Nets)) - for _, net := range allowedSrc.Nets { - nets = append(nets, &balancerpb.Net{ - Addr: &balancerpb.Addr{Bytes: net.Addr.AsSlice()}, - Mask: &balancerpb.Addr{Bytes: net.MaskBytes()}, - }) - } - - // Convert port ranges - protoPortRanges := make( - []*balancerpb.PortsRange, - 0, - len(allowedSrc.PortRanges), - ) - for _, portRange := range allowedSrc.PortRanges { - protoPortRanges = append(protoPortRanges, &balancerpb.PortsRange{ - From: uint32(portRange.From), - To: uint32(portRange.To), - }) - } - - // Convert tag from Go string to protobuf *string - var protoTag *string - if allowedSrc.Tag != "" { - tagCopy := allowedSrc.Tag - protoTag = &tagCopy - } - - allowedSrcs = append(allowedSrcs, &balancerpb.AllowedSources{ - Nets: nets, - Ports: protoPortRanges, - Tag: protoTag, - }) - } - - // Convert peers - peers := make([]*balancerpb.Addr, 0, len(vs.PeersV4)+len(vs.PeersV6)) - for _, peer := range vs.PeersV4 { - peers = append(peers, &balancerpb.Addr{Bytes: peer.AsSlice()}) - } - for _, peer := range vs.PeersV6 { - peers = append(peers, &balancerpb.Addr{Bytes: peer.AsSlice()}) - } - - scheduler := balancerpb.VsScheduler_SOURCE_HASH - if vs.Scheduler == ffi.VsSchedulerRoundRobin { - scheduler = balancerpb.VsScheduler_ROUND_ROBIN - } - - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: vs.Identifier.Addr.AsSlice()}, - Port: uint32(vs.Identifier.Port), - Proto: ConvertFFIProtoToProto(vs.Identifier.TransportProto), - }, - Scheduler: scheduler, - AllowedSrcs: allowedSrcs, - Reals: reals, - Flags: &balancerpb.VsFlags{ - Gre: vs.Flags.GRE, - FixMss: vs.Flags.FixMSS, - Ops: vs.Flags.OPS, - PureL3: vs.Flags.PureL3, - Wlc: wlcEnabled, - }, - Peers: peers, - } -} - -func convertRealConfigToProto(real *ffi.RealConfig) *balancerpb.Real { - srcAddr := real.Src.Addr.AsSlice() - srcMask := real.Src.MaskBytes() - - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real.Identifier.Addr.AsSlice()}, - Port: uint32(real.Identifier.Port), - }, - Weight: uint32(real.Weight), - SrcAddr: &balancerpb.Addr{Bytes: srcAddr}, - SrcMask: &balancerpb.Addr{Bytes: srcMask}, - } -} - -func convertWlcConfigToProto( - wlc *ffi.BalancerManagerWlcConfig, -) *balancerpb.WlcConfig { - if wlc == nil || (wlc.Power == 0 && wlc.MaxRealWeight == 0) { - return nil - } - power := uint64(wlc.Power) - maxWeight := uint32(wlc.MaxRealWeight) - return &balancerpb.WlcConfig{ - Power: &power, - MaxWeight: &maxWeight, - } -} - -// ConvertProtoToFFIPacketHandlerRef converts protobuf ref to FFI -func ConvertProtoToFFIPacketHandlerRef( - ref *balancerpb.PacketHandlerRef, -) *ffi.PacketHandlerRef { - if ref == nil { - return &ffi.PacketHandlerRef{} - } - - result := &ffi.PacketHandlerRef{} - if ref.Device != nil { - device := *ref.Device - result.Device = &device - } - if ref.Pipeline != nil { - pipeline := *ref.Pipeline - result.Pipeline = &pipeline - } - if ref.Function != nil { - function := *ref.Function - result.Function = &function - } - if ref.Chain != nil { - chain := *ref.Chain - result.Chain = &chain - } - return result -} - -// ConvertUpdateInfoToProto converts FFI update info to protobuf -func ConvertUpdateInfoToProto( - info *ffi.UpdateInfo, created bool, -) *balancerpb.UpdateInfo { - if info == nil { - return nil - } - - // Convert VS identifiers - vsIdentifiers := make([]*balancerpb.VsIdentifier, 0, len(info.ACLReusedVs)) - for i := range info.ACLReusedVs { - vsIdentifiers = append(vsIdentifiers, &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: info.ACLReusedVs[i].Addr.AsSlice(), - }, - Port: uint32(info.ACLReusedVs[i].Port), - Proto: ConvertFFIProtoToProto(info.ACLReusedVs[i].TransportProto), - }) - } - - return &balancerpb.UpdateInfo{ - Created: created, - VsIpv4MatcherReused: info.VsIpv4MatcherReused, - VsIpv6MatcherReused: info.VsIpv6MatcherReused, - VsAclReuses: vsIdentifiers, - } -} - -// ConvertFFIRealUpdateToProto converts FFI real update to protobuf -func ConvertFFIRealUpdateToProto( - update *ffi.RealUpdate, -) *balancerpb.RealUpdate { - result := &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: update.Identifier.VsIdentifier.Addr.AsSlice(), - }, - Port: uint32(update.Identifier.VsIdentifier.Port), - Proto: ConvertFFIProtoToProto( - update.Identifier.VsIdentifier.TransportProto, - ), - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: update.Identifier.Relative.Addr.AsSlice(), - }, - Port: uint32(update.Identifier.Relative.Port), - }, - }, - } - - if update.Weight != ffi.DontUpdateRealWeight { - weight := uint32(update.Weight) - result.Weight = &weight - } - - if update.Enabled != ffi.DontUpdateRealEnabled { - enabled := update.Enabled != 0 - result.Enable = &enabled - } - - return result -} - -// ConvertAgentInspectToProto converts FFI agent inspect to protobuf -func ConvertAgentInspectToProto( - inspect *ffi.AgentInspect, -) *balancerpb.AgentInspect { - if inspect == nil { - return &balancerpb.AgentInspect{} - } - - balancers := make([]*balancerpb.BalancerInspect, 0, len(inspect.Balancers)) - for i := range inspect.Balancers { - balancers = append( - balancers, - ConvertNamedBalancerInspectToProto(&inspect.Balancers[i]), - ) - } - - return &balancerpb.AgentInspect{ - MemoryUsage: inspect.MemoryUsage, - MemoryLimit: inspect.MemoryLimit, - Balancers: balancers, - } -} - -// ConvertNamedBalancerInspectToProto converts FFI named balancer inspect to protobuf -func ConvertNamedBalancerInspectToProto( - inspect *ffi.NamedBalancerInspect, -) *balancerpb.BalancerInspect { - if inspect == nil { - return &balancerpb.BalancerInspect{} - } - - return &balancerpb.BalancerInspect{ - Name: inspect.Name, - PacketHandlerInspect: ConvertPacketHandlerInspectToProto( - &inspect.Inspect.PacketHandler, - ), - StateInspect: ConvertStateInspectToProto( - &inspect.Inspect.State, - ), - OtherUsage: inspect.Inspect.OtherUsage, - TotalUsage: inspect.Inspect.TotalUsage, - } -} - -// ConvertPacketHandlerInspectToProto converts FFI packet handler inspect to protobuf -func ConvertPacketHandlerInspectToProto( - inspect *ffi.PacketHandlerInspect, -) *balancerpb.PacketHandlerInspect { - if inspect == nil { - return &balancerpb.PacketHandlerInspect{} - } - - return &balancerpb.PacketHandlerInspect{ - VsIpv4Inspect: ConvertPacketHandlerVsInspectToProto( - &inspect.VsIpv4Inspect, - ), - VsIpv6Inspect: ConvertPacketHandlerVsInspectToProto( - &inspect.VsIpv6Inspect, - ), - SummaryVsUsage: inspect.SummaryVsUsage, - VsIndexUsage: inspect.VsIndexUsage, - RealsIndexUsage: inspect.RealsIndexUsage, - CountersUsage: inspect.CountersUsage, - DecapUsage: inspect.DecapUsage, - TotalUsage: inspect.TotalUsage, - } -} - -// ConvertPacketHandlerVsInspectToProto converts FFI packet handler VS inspect to protobuf -func ConvertPacketHandlerVsInspectToProto( - inspect *ffi.PacketHandlerVsInspect, -) *balancerpb.PacketHandlerVsInspect { - if inspect == nil { - return &balancerpb.PacketHandlerVsInspect{} - } - - vsInspects := make([]*balancerpb.NamedVsInspect, 0, len(inspect.VsInspects)) - for i := range inspect.VsInspects { - vsInspects = append( - vsInspects, - ConvertNamedVsInspectToProto(&inspect.VsInspects[i]), - ) - } - - return &balancerpb.PacketHandlerVsInspect{ - MatcherUsage: inspect.MatcherUsage, - SummaryVsUsage: inspect.SummaryVsUsage, - VsInspects: vsInspects, - AnnounceUsage: inspect.AnnounceUsage, - IndexUsage: inspect.IndexUsage, - TotalUsage: inspect.TotalUsage, - } -} - -// ConvertNamedVsInspectToProto converts FFI named VS inspect to protobuf -func ConvertNamedVsInspectToProto( - inspect *ffi.NamedVsInspect, -) *balancerpb.NamedVsInspect { - if inspect == nil { - return &balancerpb.NamedVsInspect{} - } - - return &balancerpb.NamedVsInspect{ - Identifier: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: inspect.Identifier.Addr.AsSlice(), - }, - Port: uint32(inspect.Identifier.Port), - Proto: ConvertFFIProtoToProto(inspect.Identifier.TransportProto), - }, - Inspect: ConvertVsInspectToProto(&inspect.Inspect), - } -} - -// ConvertVsInspectToProto converts FFI VS inspect to protobuf -func ConvertVsInspectToProto( - inspect *ffi.VsInspect, -) *balancerpb.VsInspect { - if inspect == nil { - return &balancerpb.VsInspect{} - } - - return &balancerpb.VsInspect{ - AclUsage: inspect.ACLUsage, - RingUsage: inspect.RingUsage, - CountersUsage: inspect.CountersUsage, - RealsUsage: ConvertRealsUsageToProto(&inspect.RealsUsage), - OtherUsage: inspect.OtherUsage, - TotalUsage: inspect.TotalUsage, - } -} - -// ConvertStateInspectToProto converts FFI state inspect to protobuf -func ConvertStateInspectToProto( - inspect *ffi.StateInspect, -) *balancerpb.StateInspect { - if inspect == nil { - return &balancerpb.StateInspect{} - } - - return &balancerpb.StateInspect{ - SessionTableUsage: inspect.SessionTableUsage, - TotalUsage: inspect.TotalUsage, - } -} - -// ConvertRealsUsageToProto converts FFI reals usage to protobuf -func ConvertRealsUsageToProto( - usage *ffi.RealsUsage, -) *balancerpb.RealsUsage { - if usage == nil { - return &balancerpb.RealsUsage{} - } - - return &balancerpb.RealsUsage{ - CountersUsage: usage.CountersUsage, - DataUsage: usage.DataUsage, - TotalUsage: usage.TotalUsage, - } -} diff --git a/modules/balancer/agent/go/conversion_test.go b/modules/balancer/agent/go/conversion_test.go deleted file mode 100644 index d08d23fdb..000000000 --- a/modules/balancer/agent/go/conversion_test.go +++ /dev/null @@ -1,2103 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xnetip" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "google.golang.org/protobuf/types/known/durationpb" -) - -// TestNewRealUpdateFromProto_Valid tests valid real update conversion -func TestNewRealUpdateFromProto_Valid(t *testing.T) { - tests := []struct { - name string - proto *balancerpb.RealUpdate - expected *ffi.RealUpdate - }{ - { - name: "Complete update with weight and enable", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 8080, - }, - }, - Weight: ptrUint32(100), - Enable: ptrBool(true), - }, - expected: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 8080, - }, - }, - Weight: 100, - Enabled: 1, - }, - }, - { - name: "Update with only weight", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_UDP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.1").AsSlice(), - }, - Port: 0, - }, - }, - Weight: ptrUint32(200), - }, - expected: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 443, - TransportProto: ffi.VsTransportProtoUDP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("172.16.0.1"), - Port: 0, - }, - }, - Weight: 200, - Enabled: ffi.DontUpdateRealEnabled, - }, - }, - { - name: "Update with only enable", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - Port: 53, - Proto: balancerpb.TransportProto_UDP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::100"). - AsSlice(), - }, - Port: 5353, - }, - }, - Enable: ptrBool(false), - }, - expected: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 5353, - }, - }, - Weight: ffi.DontUpdateRealWeight, - Enabled: 0, - }, - }, - { - name: "Update with neither weight nor enable", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.10.10.10").AsSlice(), - }, - Port: 8080, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.10.10.11").AsSlice(), - }, - Port: 9090, - }, - }, - }, - expected: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.10.10.10"), - Port: 8080, - TransportProto: ffi.VsTransportProtoTCP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.10.10.11"), - Port: 9090, - }, - }, - Weight: ffi.DontUpdateRealWeight, - Enabled: ffi.DontUpdateRealEnabled, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := NewRealUpdateFromProto(tt.proto) - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestNewRealUpdateFromProto_Errors tests error cases -func TestNewRealUpdateFromProto_Errors(t *testing.T) { - tests := []struct { - name string - proto *balancerpb.RealUpdate - }{ - { - name: "Nil real_id", - proto: &balancerpb.RealUpdate{ - RealId: nil, - }, - }, - { - name: "Nil VS in real_id", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: nil, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - }, - }, - }, - }, - { - name: "Nil real in real_id", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: nil, - }, - }, - }, - { - name: "Invalid VS IP address", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.1.2.3").AsSlice()[:3], - }, // Invalid length - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - }, - }, - }, - }, - { - name: "Invalid real IP address", - proto: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.1.2").AsSlice()[:2], - }, // Invalid length - }, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewRealUpdateFromProto(tt.proto) - assert.Error(t, err) - }) - } -} - -// TestProtoToFFIConfig_Valid tests valid config conversion -func TestProtoToFFIConfig_Valid(t *testing.T) { - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1"). - AsSlice(), - }, - Port: 8080, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - }, - Flags: &balancerpb.VsFlags{ - FixMss: true, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("192.168.1.1").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - }, - } - - result, err := ProtoToFFIConfig(config) - require.NoError(t, err) - assert.Equal(t, uint(1000), result.State.TableCapacity) - assert.Len(t, result.Handler.VirtualServices, 1) - assert.Equal(t, netip.MustParseAddr("10.0.0.1"), result.Handler.SourceV4) - assert.Equal(t, netip.MustParseAddr("2001:db8::1"), result.Handler.SourceV6) -} - -// TestProtoToFFIConfig_MissingRequiredFields tests missing field validation -func TestProtoToFFIConfig_MissingRequiredFields(t *testing.T) { - tests := []struct { - name string - config *balancerpb.BalancerConfig - }{ - { - name: "Missing packet_handler", - config: &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - }, - }, - }, - { - name: "Missing state", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - }, - { - name: "Missing session_table_capacity", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - }, - }, - }, - { - name: "Missing session_table_max_load_factor", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - RefreshPeriod: durationpb.New(5 * time.Second), - }, - }, - }, - { - name: "Missing refresh_period", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ProtoToFFIConfig(tt.config) - assert.Error(t, err) - }) - } -} - -// TestProtoToManagerConfig_WLCValidation tests WLC field interdependencies -func TestProtoToManagerConfig_WLCValidation(t *testing.T) { - baseConfig := func() *balancerpb.BalancerConfig { - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - }, - } - } - - tests := []struct { - name string - config *balancerpb.BalancerConfig - shouldErr bool - }{ - { - name: "All three WLC fields present - valid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - c.State.RefreshPeriod = durationpb.New(5 * time.Second) - c.State.SessionTableMaxLoadFactor = ptrFloat32(0.75) - c.State.Wlc = &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - } - return c - }(), - shouldErr: false, - }, - { - name: "None present - valid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - return c - }(), - shouldErr: false, - }, - { - name: "Only refresh_period - invalid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - c.State.RefreshPeriod = durationpb.New(5 * time.Second) - return c - }(), - shouldErr: true, - }, - { - name: "Only max_load_factor - invalid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - c.State.SessionTableMaxLoadFactor = ptrFloat32(0.75) - return c - }(), - shouldErr: true, - }, - { - name: "Only WLC - invalid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - c.State.Wlc = &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - } - return c - }(), - shouldErr: true, - }, - { - name: "refresh_period and max_load_factor without WLC - invalid", - config: func() *balancerpb.BalancerConfig { - c := baseConfig() - c.State.RefreshPeriod = durationpb.New(5 * time.Second) - c.State.SessionTableMaxLoadFactor = ptrFloat32(0.75) - return c - }(), - shouldErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ProtoToManagerConfig(tt.config) - if tt.shouldErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -// TestProtoToHandlerConfig_Valid tests handler config conversion -func TestProtoToHandlerConfig_Valid(t *testing.T) { - config := &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{}, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("192.168.1.1").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::100").AsSlice()}, - }, - } - - result, err := ProtoToHandlerConfig(config) - require.NoError(t, err) - assert.Equal(t, uint32(10), result.SessionsTimeouts.TCPSynAck) - assert.Len(t, result.VirtualServices, 1) - assert.Len(t, result.DecapV4, 1) - assert.Len(t, result.DecapV6, 1) -} - -// TestProtoToHandlerConfig_MissingFields tests missing field validation -func TestProtoToHandlerConfig_MissingFields(t *testing.T) { - tests := []struct { - name string - config *balancerpb.PacketHandlerConfig - }{ - { - name: "Missing sessions_timeouts", - config: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - { - name: "Missing source_address_v4", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - { - name: "Missing source_address_v6", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - { - name: "Missing decap_addresses", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - }, - }, - { - name: "Missing vs", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ProtoToHandlerConfig(tt.config) - assert.Error(t, err) - }) - } -} - -// TestProtoToHandlerConfig_InvalidSourceAddresses tests invalid source address validation -func TestProtoToHandlerConfig_InvalidSourceAddresses(t *testing.T) { - tests := []struct { - name string - config *balancerpb.PacketHandlerConfig - }{ - { - name: "Invalid IPv4 source address length", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice()[:3], - }, // Only 3 bytes - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - { - name: "Invalid IPv6 source address length", - config: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{}, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, // Only 4 bytes - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ProtoToHandlerConfig(tt.config) - assert.Error(t, err) - }) - } -} - -// TestProtoToRealConfig_ZeroWeight tests weight validation -func TestProtoToRealConfig_ZeroWeight(t *testing.T) { - real := &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 8080, - }, - Weight: 0, // Invalid - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0").AsSlice(), - }, - } - - _, err := protoToRealConfig(real) - assert.Error(t, err) - assert.Contains(t, err.Error(), "weight must be at least 1") -} - -// TestCreateWlcConfig tests WLC config creation -func TestCreateWlcConfig(t *testing.T) { - tests := []struct { - name string - config *balancerpb.BalancerConfig - expected ffi.BalancerManagerWlcConfig - shouldErr bool - }{ - { - name: "No WLC enabled", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Flags: &balancerpb.VsFlags{Wlc: false}, - }, - }, - }, - State: &balancerpb.StateConfig{}, - }, - expected: ffi.BalancerManagerWlcConfig{ - Power: 0, - MaxRealWeight: 0, - Vs: []uint32{}, - }, - shouldErr: false, - }, - { - name: "WLC enabled with valid config", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Flags: &balancerpb.VsFlags{Wlc: true}, - }, - { - Flags: &balancerpb.VsFlags{Wlc: false}, - }, - { - Flags: &balancerpb.VsFlags{Wlc: true}, - }, - }, - }, - State: &balancerpb.StateConfig{ - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - }, - }, - }, - expected: ffi.BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - Vs: []uint32{0, 2}, - }, - shouldErr: false, - }, - { - name: "WLC enabled but missing power", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Flags: &balancerpb.VsFlags{Wlc: true}, - }, - }, - }, - State: &balancerpb.StateConfig{ - Wlc: &balancerpb.WlcConfig{ - MaxWeight: ptrUint32(1000), - }, - }, - }, - shouldErr: true, - }, - { - name: "WLC enabled but missing max_weight", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Flags: &balancerpb.VsFlags{Wlc: true}, - }, - }, - }, - State: &balancerpb.StateConfig{ - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(2), - }, - }, - }, - shouldErr: true, - }, - { - name: "WLC enabled but no config", - config: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Flags: &balancerpb.VsFlags{Wlc: true}, - }, - }, - }, - State: &balancerpb.StateConfig{}, - }, - shouldErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := createWlcConfig(tt.config) - if tt.shouldErr { - assert.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -// TestMergeBalancerConfig tests config merging for UPDATE mode -func TestMergeBalancerConfig(t *testing.T) { - currentConfig := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - SessionsTimeouts: ffi.SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 50, - Default: 30, - }, - VirtualServices: []ffi.VsConfig{}, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("::1"), - DecapV4: []netip.Addr{}, - DecapV6: []netip.Addr{}, - }, - State: ffi.StateConfig{ - TableCapacity: 1000, - }, - }, - RefreshPeriod: 5 * time.Second, - MaxLoadFactor: 0.75, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - Vs: []uint32{}, - }, - } - - tests := []struct { - name string - newConfig *balancerpb.BalancerConfig - verify func(t *testing.T, result *balancerpb.BalancerConfig) - }{ - { - name: "Full update with all fields", - newConfig: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 15, - TcpSyn: 25, - TcpFin: 20, - Tcp: 120, - Udp: 60, - Default: 40, - }, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(2000), - SessionTableMaxLoadFactor: ptrFloat32(0.85), - RefreshPeriod: durationpb.New(10 * time.Second), - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(3), - MaxWeight: ptrUint32(2000), - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.BalancerConfig) { - assert.Equal( - t, - uint32(15), - result.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - assert.Equal( - t, - uint64(2000), - *result.State.SessionTableCapacity, - ) - assert.Equal( - t, - float32(0.85), - *result.State.SessionTableMaxLoadFactor, - ) - assert.Equal( - t, - 10*time.Second, - result.State.RefreshPeriod.AsDuration(), - ) - }, - }, - { - name: "Partial update - only state fields", - newConfig: &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(3000), - }, - }, - verify: func(t *testing.T, result *balancerpb.BalancerConfig) { - // Handler should be from current config - assert.Equal( - t, - uint32(10), - result.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - // State capacity should be updated - assert.Equal( - t, - uint64(3000), - *result.State.SessionTableCapacity, - ) - // Other state fields should be from current config - assert.Equal( - t, - float32(0.75), - *result.State.SessionTableMaxLoadFactor, - ) - }, - }, - { - name: "Partial update - only handler", - newConfig: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 25, - TcpSyn: 35, - TcpFin: 30, - Tcp: 150, - Udp: 70, - Default: 50, - }, - Vs: []*balancerpb.VirtualService{}, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - }, - verify: func(t *testing.T, result *balancerpb.BalancerConfig) { - // Handler should be updated - assert.Equal( - t, - uint32(25), - result.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - // State should be from current config - assert.Equal( - t, - uint64(1000), - *result.State.SessionTableCapacity, - ) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := mergeBalancerConfig(tt.newConfig, currentConfig) - require.NoError(t, err) - tt.verify(t, result) - }) - } -} - -// TestConvertFFIRealUpdateToProto tests FFI to proto conversion -func TestConvertFFIRealUpdateToProto(t *testing.T) { - tests := []struct { - name string - update *ffi.RealUpdate - expected *balancerpb.RealUpdate - }{ - { - name: "Update with weight and enabled", - update: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 8080, - }, - }, - Weight: 100, - Enabled: 1, - }, - expected: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 8080, - }, - }, - Weight: ptrUint32(100), - Enable: ptrBool(true), - }, - }, - { - name: "Update with DontUpdateRealWeight", - update: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 443, - TransportProto: ffi.VsTransportProtoUDP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("172.16.0.1"), - Port: 8443, - }, - }, - Weight: ffi.DontUpdateRealWeight, - Enabled: 0, - }, - expected: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_UDP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.1").AsSlice(), - }, - Port: 8443, - }, - }, - Weight: nil, - Enable: ptrBool(false), - }, - }, - { - name: "Update with DontUpdateRealEnabled", - update: &ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - VsIdentifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Relative: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 5353, - }, - }, - Weight: 200, - Enabled: ffi.DontUpdateRealEnabled, - }, - expected: &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - Port: 53, - Proto: balancerpb.TransportProto_UDP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::100"). - AsSlice(), - }, - Port: 5353, - }, - }, - Weight: ptrUint32(200), - Enable: nil, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ConvertFFIRealUpdateToProto(tt.update) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestProtoToVsConfig_PureL3Mode tests Pure L3 mode validation -func TestProtoToVsConfig_PureL3Mode(t *testing.T) { - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 0, // Must be 0 for Pure L3 - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{}, - Flags: &balancerpb.VsFlags{ - PureL3: true, - }, - } - - result, err := protoToVsConfig(vs) - require.NoError(t, err) - assert.True(t, result.Flags.PureL3) - assert.Equal(t, uint16(0), result.Identifier.Port) -} - -// TestProtoToVsConfig_AllFlags tests all flag combinations -func TestProtoToVsConfig_AllFlags(t *testing.T) { - tests := []struct { - name string - flags *balancerpb.VsFlags - }{ - { - name: "All flags enabled", - flags: &balancerpb.VsFlags{ - Gre: true, - FixMss: true, - Ops: true, - PureL3: false, // Can't be true with port != 0 - Wlc: true, - }, - }, - { - name: "All flags disabled", - flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - }, - { - name: "Mixed flags", - flags: &balancerpb.VsFlags{ - Gre: true, - FixMss: false, - Ops: true, - PureL3: false, - Wlc: false, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{}, - Flags: tt.flags, - } - - result, err := protoToVsConfig(vs) - require.NoError(t, err) - assert.Equal(t, tt.flags.Gre, result.Flags.GRE) - assert.Equal(t, tt.flags.FixMss, result.Flags.FixMSS) - assert.Equal(t, tt.flags.Ops, result.Flags.OPS) - assert.Equal(t, tt.flags.PureL3, result.Flags.PureL3) - }) - } -} - -// TestProtoToVsConfig_Schedulers tests both scheduler types -func TestProtoToVsConfig_Schedulers(t *testing.T) { - tests := []struct { - name string - scheduler balancerpb.VsScheduler - expected ffi.VsScheduler - }{ - { - name: "SOURCE_HASH", - scheduler: balancerpb.VsScheduler_SOURCE_HASH, - expected: ffi.VsSchedulerSourceHash, - }, - { - name: "ROUND_ROBIN", - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - expected: ffi.VsSchedulerRoundRobin, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: tt.scheduler, - Reals: []*balancerpb.Real{}, - } - - result, err := protoToVsConfig(vs) - require.NoError(t, err) - assert.Equal(t, tt.expected, result.Scheduler) - }) - } -} - -// TestProtoToVsConfig_EmptyArrays tests with empty optional arrays -func TestProtoToVsConfig_EmptyArrays(t *testing.T) { - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{}, - AllowedSrcs: []*balancerpb.AllowedSources{}, - Peers: []*balancerpb.Addr{}, - } - - result, err := protoToVsConfig(vs) - require.NoError(t, err) - assert.Empty(t, result.Reals) - assert.Empty(t, result.AllowedSources) - assert.Empty(t, result.PeersV4) - assert.Empty(t, result.PeersV6) -} - -// TestProtoToRealConfig_SourcePrefix tests source prefix conversion -func TestProtoToRealConfig_SourcePrefix(t *testing.T) { - tests := []struct { - name string - srcAddr []byte - srcMask []byte - expected xnetip.NetWithMask - }{ - { - name: "IPv4 /24", - srcAddr: netip.MustParseAddr("172.16.0.0").AsSlice(), - srcMask: []byte{255, 255, 255, 0}, - expected: xnetip.FromPrefix(netip.MustParsePrefix("172.16.0.0/24")), - }, - { - name: "IPv4 /16", - srcAddr: netip.MustParseAddr("10.0.0.0").AsSlice(), - srcMask: []byte{255, 255, 0, 0}, - expected: xnetip.FromPrefix(netip.MustParsePrefix("10.0.0.0/16")), - }, - { - name: "IPv6 /64", - srcAddr: netip.MustParseAddr("2001:db8::").AsSlice(), - srcMask: netip.MustParseAddr("ffff:ffff:ffff:ffff::").AsSlice(), - expected: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/64")), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - real := &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 8080, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{Bytes: tt.srcAddr}, - SrcMask: &balancerpb.Addr{Bytes: tt.srcMask}, - } - - result, err := protoToRealConfig(real) - require.NoError(t, err) - assert.Equal(t, tt.expected, result.Src) - }) - } -} - -// Helper functions -func ptrUint32(v uint32) *uint32 { - return &v -} - -func ptrUint64(v uint64) *uint64 { - return &v -} - -func ptrFloat32(v float32) *float32 { - return &v -} - -func ptrBool(v bool) *bool { - return &v -} - -// TestConvertPacketHandlerToProtoWithWlc tests WLC-aware packet handler conversion -func TestConvertPacketHandlerToProtoWithWlc(t *testing.T) { - handler := &ffi.PacketHandlerConfig{ - SessionsTimeouts: ffi.SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 50, - Default: 30, - }, - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.101"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.102"), - Port: 8080, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{}, - DecapV6: []netip.Addr{}, - } - - tests := []struct { - name string - wlcConfig *ffi.BalancerManagerWlcConfig - verify func(t *testing.T, result *balancerpb.PacketHandlerConfig) - }{ - { - name: "No WLC config", - wlcConfig: nil, - verify: func(t *testing.T, result *balancerpb.PacketHandlerConfig) { - require.Len(t, result.Vs, 3) - assert.False( - t, - result.Vs[0].Flags.Wlc, - "VS0 should have WLC=false", - ) - assert.False( - t, - result.Vs[1].Flags.Wlc, - "VS1 should have WLC=false", - ) - assert.False( - t, - result.Vs[2].Flags.Wlc, - "VS2 should have WLC=false", - ) - }, - }, - { - name: "WLC enabled for VS 0 and 2", - wlcConfig: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 2}, - }, - verify: func(t *testing.T, result *balancerpb.PacketHandlerConfig) { - require.Len(t, result.Vs, 3) - assert.True( - t, - result.Vs[0].Flags.Wlc, - "VS0 should have WLC=true", - ) - assert.False( - t, - result.Vs[1].Flags.Wlc, - "VS1 should have WLC=false", - ) - assert.True( - t, - result.Vs[2].Flags.Wlc, - "VS2 should have WLC=true", - ) - }, - }, - { - name: "WLC enabled for all VSs", - wlcConfig: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 1, 2}, - }, - verify: func(t *testing.T, result *balancerpb.PacketHandlerConfig) { - require.Len(t, result.Vs, 3) - assert.True( - t, - result.Vs[0].Flags.Wlc, - "VS0 should have WLC=true", - ) - assert.True( - t, - result.Vs[1].Flags.Wlc, - "VS1 should have WLC=true", - ) - assert.True( - t, - result.Vs[2].Flags.Wlc, - "VS2 should have WLC=true", - ) - }, - }, - { - name: "Empty WLC VS list", - wlcConfig: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{}, - }, - verify: func(t *testing.T, result *balancerpb.PacketHandlerConfig) { - require.Len(t, result.Vs, 3) - assert.False( - t, - result.Vs[0].Flags.Wlc, - "VS0 should have WLC=false", - ) - assert.False( - t, - result.Vs[1].Flags.Wlc, - "VS1 should have WLC=false", - ) - assert.False( - t, - result.Vs[2].Flags.Wlc, - "VS2 should have WLC=false", - ) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertPacketHandlerToProtoWithWlc(handler, tt.wlcConfig) - require.NotNil(t, result) - tt.verify(t, result) - }) - } -} - -// TestConvertVsConfigToProtoWithWlc tests WLC-aware VS config conversion -func TestConvertVsConfigToProtoWithWlc(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Flags: ffi.VsFlags{ - GRE: true, - FixMSS: false, - OPS: true, - PureL3: false, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - AllowedSources: []ffi.AllowedSources{}, - PeersV4: []netip.Addr{}, - PeersV6: []netip.Addr{}, - } - - tests := []struct { - name string - wlcEnabled bool - verify func(t *testing.T, result *balancerpb.VirtualService) - }{ - { - name: "WLC disabled", - wlcEnabled: false, - verify: func(t *testing.T, result *balancerpb.VirtualService) { - require.NotNil(t, result.Flags) - assert.False(t, result.Flags.Wlc, "WLC should be false") - assert.True(t, result.Flags.Gre, "GRE should be preserved") - assert.True(t, result.Flags.Ops, "OPS should be preserved") - }, - }, - { - name: "WLC enabled", - wlcEnabled: true, - verify: func(t *testing.T, result *balancerpb.VirtualService) { - require.NotNil(t, result.Flags) - assert.True(t, result.Flags.Wlc, "WLC should be true") - assert.True(t, result.Flags.Gre, "GRE should be preserved") - assert.True(t, result.Flags.Ops, "OPS should be preserved") - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertVsConfigToProtoWithWlc(vsConfig, tt.wlcEnabled) - require.NotNil(t, result) - tt.verify(t, result) - }) - } -} - -// TestConvertBalancerConfigToProto_WithWlc tests full config conversion with WLC -func TestConvertBalancerConfigToProto_WithWlc(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - SessionsTimeouts: ffi.SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 50, - Default: 30, - }, - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.100", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.101", - ), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{}, - DecapV6: []netip.Addr{}, - }, - State: ffi.StateConfig{ - TableCapacity: 1000, - }, - }, - RefreshPeriod: 5 * time.Second, - MaxLoadFactor: 0.75, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, // Only first VS has WLC enabled - }, - } - - result := ConvertBalancerConfigToProto(config) - require.NotNil(t, result) - require.NotNil(t, result.PacketHandler) - require.Len(t, result.PacketHandler.Vs, 2) - - // Verify WLC flags - assert.True( - t, - result.PacketHandler.Vs[0].Flags.Wlc, - "VS0 should have WLC=true", - ) - assert.False( - t, - result.PacketHandler.Vs[1].Flags.Wlc, - "VS1 should have WLC=false", - ) - - // Verify state config - require.NotNil(t, result.State) - assert.Equal(t, uint64(1000), *result.State.SessionTableCapacity) - assert.Equal(t, float32(0.75), *result.State.SessionTableMaxLoadFactor) - assert.Equal(t, 5*time.Second, result.State.RefreshPeriod.AsDuration()) - - // Verify WLC config - require.NotNil(t, result.State.Wlc) - assert.Equal(t, uint64(10), *result.State.Wlc.Power) - assert.Equal(t, uint32(1000), *result.State.Wlc.MaxWeight) -} - -// TestAllowedSourcesTagConversion tests that tag field is properly converted -func TestAllowedSourcesTagConversion(t *testing.T) { - tests := []struct { - name string - protoConfig *balancerpb.BalancerConfig - verifyTag func(t *testing.T, config *ffi.BalancerManagerConfig) - }{ - { - name: "Tag field set to non-zero value", - protoConfig: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{}, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "12345"; return &s }(), // Non-zero tag - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - }, - }, - }, - verifyTag: func(t *testing.T, config *ffi.BalancerManagerConfig) { - require.Len(t, config.Balancer.Handler.VirtualServices, 1) - require.Len( - t, - config.Balancer.Handler.VirtualServices[0].AllowedSources, - 1, - ) - assert.Equal( - t, - "12345", - config.Balancer.Handler.VirtualServices[0].AllowedSources[0].Tag, - "Tag should be 12345", - ) - }, - }, - { - name: "Tag field not specified (defaults to zero)", - protoConfig: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{}, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - // Tag not specified - should default to 0 - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - }, - }, - }, - verifyTag: func(t *testing.T, config *ffi.BalancerManagerConfig) { - require.Len(t, config.Balancer.Handler.VirtualServices, 1) - require.Len( - t, - config.Balancer.Handler.VirtualServices[0].AllowedSources, - 1, - ) - assert.Equal( - t, - "", - config.Balancer.Handler.VirtualServices[0].AllowedSources[0].Tag, - "Tag should default to empty string", - ) - }, - }, - { - name: "Multiple allowed sources with different tags", - protoConfig: &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 50, - Default: 30, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.100"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{}, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "100"; return &s }(), - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "200"; return &s }(), - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - // Tag not specified - should be 0 - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: ptrUint64(1000), - SessionTableMaxLoadFactor: ptrFloat32(0.75), - RefreshPeriod: durationpb.New(5 * time.Second), - Wlc: &balancerpb.WlcConfig{ - Power: ptrUint64(2), - MaxWeight: ptrUint32(1000), - }, - }, - }, - verifyTag: func(t *testing.T, config *ffi.BalancerManagerConfig) { - require.Len(t, config.Balancer.Handler.VirtualServices, 1) - require.Len( - t, - config.Balancer.Handler.VirtualServices[0].AllowedSources, - 3, - ) - assert.Equal( - t, - "100", - config.Balancer.Handler.VirtualServices[0].AllowedSources[0].Tag, - "First tag should be 100", - ) - assert.Equal( - t, - "200", - config.Balancer.Handler.VirtualServices[0].AllowedSources[1].Tag, - "Second tag should be 200", - ) - assert.Equal( - t, - "", - config.Balancer.Handler.VirtualServices[0].AllowedSources[2].Tag, - "Third tag should be empty string", - ) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config, err := ProtoToManagerConfig(tt.protoConfig) - require.NoError(t, err, "failed to convert config") - tt.verifyTag(t, config) - }) - } -} - -// TestAllowedSourcesTagRoundTrip tests bidirectional conversion of tag field -func TestAllowedSourcesTagRoundTrip(t *testing.T) { - originalConfig := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - SessionsTimeouts: ffi.SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 50, - Default: 30, - }, - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.100", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Scheduler: ffi.VsSchedulerRoundRobin, - Reals: []ffi.RealConfig{}, - AllowedSources: []ffi.AllowedSources{ - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - ), - }, - PortRanges: []ffi.PortRange{}, - Tag: "12345", - }, - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - ), - }, - PortRanges: []ffi.PortRange{}, - Tag: "0", // Zero tag - }, - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/12"), - ), - }, - PortRanges: []ffi.PortRange{}, - Tag: "99999", - }, - }, - PeersV4: []netip.Addr{}, - PeersV6: []netip.Addr{}, - Flags: ffi.VsFlags{}, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("::1"), - DecapV4: []netip.Addr{}, - DecapV6: []netip.Addr{}, - }, - State: ffi.StateConfig{ - TableCapacity: 1000, - }, - }, - RefreshPeriod: 5 * time.Second, - MaxLoadFactor: 0.75, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - Vs: []uint32{}, - }, - } - - // Convert to proto - protoConfig := ConvertBalancerConfigToProto(originalConfig) - require.NotNil(t, protoConfig) - require.NotNil(t, protoConfig.PacketHandler) - require.Len(t, protoConfig.PacketHandler.Vs, 1) - require.Len(t, protoConfig.PacketHandler.Vs[0].AllowedSrcs, 3) - - // Verify tags in proto - require.NotNil( - t, - protoConfig.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "First tag should not be nil in proto", - ) - assert.Equal( - t, - "12345", - *protoConfig.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "First tag should be 12345 in proto", - ) - require.NotNil( - t, - protoConfig.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "Second tag should not be nil in proto", - ) - assert.Equal( - t, - "0", - *protoConfig.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "Second tag should be 0 in proto", - ) - require.NotNil( - t, - protoConfig.PacketHandler.Vs[0].AllowedSrcs[2].Tag, - "Third tag should not be nil in proto", - ) - assert.Equal( - t, - "99999", - *protoConfig.PacketHandler.Vs[0].AllowedSrcs[2].Tag, - "Third tag should be 99999 in proto", - ) - - // Convert back to FFI - convertedConfig, err := ProtoToManagerConfig(protoConfig) - require.NoError(t, err) - require.Len(t, convertedConfig.Balancer.Handler.VirtualServices, 1) - require.Len( - t, - convertedConfig.Balancer.Handler.VirtualServices[0].AllowedSources, - 3, - ) - - // Verify tags are preserved - assert.Equal( - t, - "12345", - convertedConfig.Balancer.Handler.VirtualServices[0].AllowedSources[0].Tag, - "First tag should be preserved as 12345", - ) - assert.Equal( - t, - "0", - convertedConfig.Balancer.Handler.VirtualServices[0].AllowedSources[1].Tag, - "Second tag should be preserved as 0", - ) - assert.Equal( - t, - "99999", - convertedConfig.Balancer.Handler.VirtualServices[0].AllowedSources[2].Tag, - "Third tag should be preserved as 99999", - ) -} diff --git a/modules/balancer/agent/go/duplicate_vs_test.go b/modules/balancer/agent/go/duplicate_vs_test.go deleted file mode 100644 index da48af52c..000000000 --- a/modules/balancer/agent/go/duplicate_vs_test.go +++ /dev/null @@ -1,434 +0,0 @@ -package balancer - -import ( - "fmt" - "net/netip" - "strings" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Helper function to check if a string contains any of the given substrings -func containsAny(s string, substrs ...string) bool { - for _, substr := range substrs { - if strings.Contains(s, substr) { - return true - } - } - return false -} - -// Helper function to create a basic VS config -func createVsConfig( - addr string, - port uint32, - proto balancerpb.TransportProto, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr(addr).AsSlice(), - }, - Port: port, - Proto: proto, - }, - Flags: &balancerpb.VsFlags{ - FixMss: true, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0").AsSlice(), - }, - Weight: 100, - }, - }, - } -} - -// Helper function to create a base balancer config -func createBaseConfig( - vs []*balancerpb.VirtualService, -) *balancerpb.BalancerConfig { - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: vs, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.13.11.215").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.75); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(2); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -func TestDuplicateVsRejected(t *testing.T) { - // Create mock Yanet instance - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 512 * datasize.MB, - DpMemory: 16 * datasize.MB, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: "eth0", - }, - }, - }) - require.NoError(t, err, "failed to initialize mock") - require.NotNil(t, m, "mock is nil") - defer m.Free() - - // Create logger for tests - log := zap.NewNop().Sugar() - - // Create balancer agent - agent, err := NewBalancerAgent(m.SharedMemory(), 256*datasize.MB, log) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - t.Run("DuplicateIPv4Vs_SamePortAndProto", func(t *testing.T) { - // Create config with two VS having same IPv4 addr:port:proto - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.1.100", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.1.100", - 80, - balancerpb.TransportProto_TCP, - ), // Duplicate - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_ipv4_dup", config) - require.Error(t, err, "expected error for duplicate IPv4 VS") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg, - ) - }) - - t.Run("DuplicateIPv6Vs_SamePortAndProto", func(t *testing.T) { - // Create config with two VS having same IPv6 addr:port:proto - vs := []*balancerpb.VirtualService{ - createVsConfig("2001:db8::100", 443, balancerpb.TransportProto_TCP), - createVsConfig( - "2001:db8::100", - 443, - balancerpb.TransportProto_TCP, - ), // Duplicate - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_ipv6_dup", config) - require.Error(t, err, "expected error for duplicate IPv6 VS") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("SameIP_DifferentPort_Allowed", func(t *testing.T) { - // Same IP but different port should be allowed - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.1.100", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.1.100", - 443, - balancerpb.TransportProto_TCP, - ), // Different port - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_diff_port", config) - require.NoError(t, err, "same IP with different port should be allowed") - }) - - t.Run("SameIP_SamePort_DifferentProto_Allowed", func(t *testing.T) { - // Same IP:port but different protocol should be allowed - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.1.100", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.1.100", - 80, - balancerpb.TransportProto_UDP, - ), // Different proto - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_diff_proto", config) - require.NoError( - t, - err, - "same IP:port with different protocol should be allowed", - ) - }) - - t.Run("LargeConfig_FewIPv4Duplicates", func(t *testing.T) { - // Create large config with many unique VS and a few duplicates - vs := []*balancerpb.VirtualService{} - - // Add 50 unique IPv4 VS - for i := 0; i < 50; i++ { - addr := fmt.Sprintf("10.0.%d.%d", i/256, i%256) - vs = append( - vs, - createVsConfig(addr, 80, balancerpb.TransportProto_TCP), - ) - } - - // Add a duplicate at position 25 - vs = append( - vs, - createVsConfig("10.0.0.25", 80, balancerpb.TransportProto_TCP), - ) - - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_large_ipv4_dup", config) - require.Error(t, err, "expected error for duplicate in large config") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("LargeConfig_FewIPv6Duplicates", func(t *testing.T) { - // Create large config with many unique VS and a few duplicates - vs := []*balancerpb.VirtualService{} - - // Add 50 unique IPv6 VS - for i := 0; i < 50; i++ { - addr := fmt.Sprintf("2001:db8::%x", i) - vs = append( - vs, - createVsConfig(addr, 443, balancerpb.TransportProto_TCP), - ) - } - - // Add a duplicate at position 30 - vs = append( - vs, - createVsConfig("2001:db8::1e", 443, balancerpb.TransportProto_TCP), - ) - - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_large_ipv6_dup", config) - require.Error(t, err, "expected error for duplicate in large config") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("ManyIPv4_NoDuplicates_WithIPv6Duplicates", func(t *testing.T) { - // Many unique IPv4 VS but with IPv6 duplicates - vs := []*balancerpb.VirtualService{} - - // Add 30 unique IPv4 VS - for i := 0; i < 30; i++ { - addr := fmt.Sprintf("10.1.%d.%d", i/256, i%256) - vs = append( - vs, - createVsConfig(addr, 80, balancerpb.TransportProto_TCP), - ) - } - - // Add IPv6 VS with duplicates - vs = append( - vs, - createVsConfig("2001:db8::a", 443, balancerpb.TransportProto_TCP), - ) - vs = append( - vs, - createVsConfig("2001:db8::b", 443, balancerpb.TransportProto_TCP), - ) - vs = append( - vs, - createVsConfig("2001:db8::a", 443, balancerpb.TransportProto_TCP), - ) // Duplicate - - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_ipv4_ok_ipv6_dup", config) - require.Error( - t, - err, - "expected error for IPv6 duplicate despite unique IPv4", - ) - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("ManyIPv6_NoDuplicates_WithIPv4Duplicates", func(t *testing.T) { - // Many unique IPv6 VS but with IPv4 duplicates - vs := []*balancerpb.VirtualService{} - - // Add 30 unique IPv6 VS - for i := range 30 { - addr := fmt.Sprintf("2001:db8::%x", i+100) - vs = append( - vs, - createVsConfig(addr, 443, balancerpb.TransportProto_TCP), - ) - } - - // Add IPv4 VS with duplicates - vs = append( - vs, - createVsConfig("192.168.10.1", 80, balancerpb.TransportProto_TCP), - ) - vs = append( - vs, - createVsConfig("192.168.10.2", 80, balancerpb.TransportProto_TCP), - ) - vs = append( - vs, - createVsConfig("192.168.10.1", 80, balancerpb.TransportProto_TCP), - ) // Duplicate - - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_ipv6_ok_ipv4_dup", config) - require.Error( - t, - err, - "expected error for IPv4 duplicate despite unique IPv6", - ) - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("NoDuplicates_OnCreate_DuplicateOnUpdate", func(t *testing.T) { - // Create manager with unique VS - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.2.1", 80, balancerpb.TransportProto_TCP), - createVsConfig("192.168.2.2", 80, balancerpb.TransportProto_TCP), - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_update_dup", config) - require.NoError(t, err, "initial config should be valid") - - // Get the manager - manager, err := agent.BalancerManager("test_update_dup") - require.NoError(t, err, "failed to get manager") - - // Try to update with duplicate VS - vsUpdate := []*balancerpb.VirtualService{ - createVsConfig("192.168.2.1", 80, balancerpb.TransportProto_TCP), - createVsConfig("192.168.2.2", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.2.1", - 80, - balancerpb.TransportProto_TCP, - ), // Duplicate - } - updateConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: vsUpdate, - }, - } - - _, err = manager.Update(updateConfig, m.CurrentTime()) - require.Error(t, err, "expected error when updating with duplicate VS") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("MultipleDuplicates_InSameConfig", func(t *testing.T) { - // Config with multiple different duplicates - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.3.1", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.3.1", - 80, - balancerpb.TransportProto_TCP, - ), // Duplicate 1 - createVsConfig("192.168.3.2", 443, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.3.2", - 443, - balancerpb.TransportProto_TCP, - ), // Duplicate 2 - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_multi_dup", config) - require.Error(t, err, "expected error for multiple duplicates") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) - - t.Run("TripleDuplicate_SameVS", func(t *testing.T) { - // Three instances of the same VS - vs := []*balancerpb.VirtualService{ - createVsConfig("192.168.4.1", 80, balancerpb.TransportProto_TCP), - createVsConfig( - "192.168.4.1", - 80, - balancerpb.TransportProto_TCP, - ), // Duplicate 1 - createVsConfig( - "192.168.4.1", - 80, - balancerpb.TransportProto_TCP, - ), // Duplicate 2 - } - config := createBaseConfig(vs) - - err := agent.NewBalancerManager("test_triple_dup", config) - require.Error(t, err, "expected error for triple duplicate") - errMsg := err.Error() - assert.True(t, - containsAny(errMsg, "duplicate", "match"), - "error should mention 'duplicate' or 'match', got: %s", errMsg) - }) -} diff --git a/modules/balancer/agent/go/ffi/agent.go b/modules/balancer/agent/go/ffi/agent.go deleted file mode 100644 index 975940b04..000000000 --- a/modules/balancer/agent/go/ffi/agent.go +++ /dev/null @@ -1,115 +0,0 @@ -// Package ffi provides Foreign Function Interface (FFI) bindings to C code for balancer agent operations. -// This file implements the BalancerAgent FFI wrapper for creating and managing balancer managers. -package ffi - -/* -#cgo CFLAGS: -I../../ -I../../../../../ -#cgo LDFLAGS: -L../../../../../build/modules/balancer/agent -lbalancer_agent -L../../../../../build/modules/balancer/controlplane/api -lbalancer_cp -L../../../../../build/modules/balancer/controlplane/handler -lbalancer_packet_handler -L../../../../../build/modules/balancer/controlplane/state -lbalancer_state -lbalancer_packet_handler -lbalancer_state -#include "agent.h" -#include "modules/balancer/controlplane/api/inspect.h" -#include -*/ -import "C" - -import ( - "fmt" - "unsafe" - - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" -) - -// BalancerAgent wraps a C balancer_agent handle -type BalancerAgent struct { - handle *C.struct_balancer_agent -} - -// NewBalancerAgent creates a new balancer agent instance -func NewBalancerAgent( - shm *yanet.SharedMemory, - memory uint, -) (*BalancerAgent, error) { - if shm == nil { - return nil, fmt.Errorf("shared memory is nil") - } - - cShm := (*C.struct_yanet_shm)(shm.AsRawPtr()) - cMemory := C.size_t(memory) - - handle := C.balancer_agent(cShm, cMemory) - if handle == nil { - return nil, fmt.Errorf("failed to attach balancer agent") - } - - return &BalancerAgent{handle: handle}, nil -} - -// Managers retrieves all balancer managers registered with the agent -func (a *BalancerAgent) Managers() []BalancerManager { - var cManagers C.struct_balancer_managers - C.balancer_agent_managers(a.handle, &cManagers) - - if cManagers.count == 0 || cManagers.managers == nil { - return nil - } - - // Convert C array to Go slice - managers := make([]BalancerManager, cManagers.count) - cManagersSlice := unsafe.Slice(cManagers.managers, cManagers.count) - - for i := range managers { - managers[i] = BalancerManager{handle: cManagersSlice[i]} - } - - // Free the C-allocated array (the manager pointers themselves are owned by the agent) - C.free(unsafe.Pointer(cManagers.managers)) - - return managers -} - -// NewManager creates and registers a new balancer manager with the agent -func (a *BalancerAgent) NewManager( - name string, - config *BalancerManagerConfig, -) (*BalancerManager, error) { - if config == nil { - return nil, fmt.Errorf("config is nil") - } - - cName := C.CString(name) - defer C.free(unsafe.Pointer(cName)) - - // Convert Go config to C config - cConfig, err := goToCBalancerManagerConfig(config) - if err != nil { - return nil, fmt.Errorf("failed to convert config: %w", err) - } - defer freeCBalancerManagerConfig(cConfig) - - handle := C.balancer_agent_new_manager(a.handle, cName, cConfig) - if handle == nil { - // Get error message from agent - cErr := C.balancer_agent_take_error(a.handle) - if cErr != nil { - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return nil, fmt.Errorf("%s", errMsg) - } - return nil, fmt.Errorf("unknown error") - } - - return &BalancerManager{handle: handle}, nil -} - -// Inspect retrieves agent-level memory inspection -func (a *BalancerAgent) Inspect() *AgentInspect { - var cInspect C.struct_agent_inspect - C.balancer_agent_inspect(a.handle, &cInspect) - inspect := cToGoAgentInspect(&cInspect) - C.balancer_agent_inspect_free(&cInspect) - return inspect -} - -func (a *BalancerAgent) DPConfig() *yanet.DPConfig { - dpConfig := C.balancer_agent_dp_config(a.handle) - return yanet.NewDPConfigFromRaw(unsafe.Pointer(dpConfig)) -} diff --git a/modules/balancer/agent/go/ffi/agent_test.go b/modules/balancer/agent/go/ffi/agent_test.go deleted file mode 100644 index e7d51228f..000000000 --- a/modules/balancer/agent/go/ffi/agent_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package ffi - -import ( - "net/netip" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xnetip" - mock "github.com/yanet-platform/yanet2/mock/go" -) - -func TestAgent(t *testing.T) { - // Create mock - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 28, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: "eth0", - }, - }, - }) - require.NoError(t, err, "failed to initialize mock") - require.NotNil(t, m, "mock is nil") - defer m.Free() - - // Create balancer agent - - agent, err := NewBalancerAgent(m.SharedMemory(), 1<<27) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - managers := agent.Managers() - assert.Empty(t, managers) - - firstManagerConfig := BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 11, - Default: 19, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.12.13.213"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.12.13.213"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.1.1.1/24"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.12.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("12.1.1.3"), - netip.MustParseAddr("12.1.1.4"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::2"), - netip.MustParseAddr("2001:db8::3"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.12.13.213"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("10.13.11.215"), - netip.MustParseAddr("10.14.11.214"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::3"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - State: StateConfig{ - TableCapacity: 1000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1024, - Vs: []uint32{0, 1}, - }, - RefreshPeriod: time.Millisecond * 10, - MaxLoadFactor: 0.75, - } - - t.Run("First_Manager", func(t *testing.T) { - _, err := agent.NewManager("balancer0", &firstManagerConfig) - require.NoError(t, err, "failed to create manager") - }) - - t.Run("Managers", func(t *testing.T) { - managers := agent.Managers() - assert.Len(t, managers, 1) - assert.Equal(t, "balancer0", managers[0].Name()) - assert.Equal(t, &firstManagerConfig, managers[0].Config()) - }) - - secondManagerConfig := BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 15, - TCPSyn: 25, - TCPFin: 20, - TCP: 69, - UDP: 15, - Default: 25, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.20.30.40"), - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{OPS: true, GRE: true}, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.20.30.40"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.17.0.0/24"), - ), - Weight: 150, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.20.30.41"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.17.1.0/24"), - ), - Weight: 200, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.2.2.0/24"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.13.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("12.2.2.3"), - netip.MustParseAddr("12.2.2.4"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::4"), - netip.MustParseAddr("2001:db8::5"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.20.30.40"), - SourceV6: netip.MustParseAddr("2001:db8::10"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("10.15.12.216"), - netip.MustParseAddr("10.16.12.215"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::5"), - netip.MustParseAddr("2001:db8::4"), - }, - }, - State: StateConfig{ - TableCapacity: 2000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 15, - MaxRealWeight: 512, - Vs: []uint32{}, - }, - RefreshPeriod: time.Millisecond * 20, - MaxLoadFactor: 0.85, - } - - t.Run("Second_Manager", func(t *testing.T) { - _, err := agent.NewManager("balancer1", &secondManagerConfig) - require.NoError(t, err, "failed to create manager") - }) - - t.Run("Managers2", func(t *testing.T) { - managers := agent.Managers() - assert.Len(t, managers, 2) - - assert.Equal(t, managers[0].Name(), "balancer0") - assert.Equal(t, managers[0].Config(), &firstManagerConfig) - - assert.Equal(t, managers[1].Name(), "balancer1") - assert.Equal(t, managers[1].Config(), &secondManagerConfig) - }) - - t.Run("Create_Existing_Manager", func(t *testing.T) { - _, err := agent.NewManager("balancer0", &firstManagerConfig) - require.Error(t, err, "created existent manager") - }) - - t.Run("Reattach", func(t *testing.T) { - agent1, err := NewBalancerAgent(m.SharedMemory(), 1<<22) - require.NoError(t, err, "failed to create agent") - - managers := agent1.Managers() - assert.Len(t, managers, 2) - - assert.Equal(t, managers[0].Name(), "balancer0") - assert.Equal(t, managers[0].Config(), &firstManagerConfig) - - assert.Equal(t, managers[1].Name(), "balancer1") - assert.Equal(t, managers[1].Config(), &secondManagerConfig) - }) -} diff --git a/modules/balancer/agent/go/ffi/conversion_test.go b/modules/balancer/agent/go/ffi/conversion_test.go deleted file mode 100644 index e7500f535..000000000 --- a/modules/balancer/agent/go/ffi/conversion_test.go +++ /dev/null @@ -1,1622 +0,0 @@ -package ffi - -import ( - "fmt" - "net/netip" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/yanet-platform/yanet2/common/go/xnetip" -) - -// Test helper to compare netip.Addr -func compareAddr(a, b netip.Addr) bool { - return a.Compare(b) == 0 -} - -// Test helper to compare netip.Prefix -func comparePrefix(a, b netip.Prefix) bool { - return a.Addr().Compare(b.Addr()) == 0 && a.Bits() == b.Bits() -} - -// Test helper to compare xnetip.NetWithMask -func compareNetWithMask(a, b xnetip.NetWithMask) bool { - if !compareAddr(a.Addr, b.Addr) { - return false - } - if len(a.Mask) != len(b.Mask) { - return false - } - for i := range a.Mask { - if a.Mask[i] != b.Mask[i] { - return false - } - } - return true -} - -// TestNetAddrConversion tests round-trip conversion of network addresses -func TestNetAddrConversion(t *testing.T) { - tests := []struct { - name string - addr netip.Addr - isV4 bool - }{ - { - name: "IPv4 localhost", - addr: netip.MustParseAddr("127.0.0.1"), - isV4: true, - }, - { - name: "IPv4 zero", - addr: netip.MustParseAddr("0.0.0.0"), - isV4: true, - }, - { - name: "IPv4 broadcast", - addr: netip.MustParseAddr("255.255.255.255"), - isV4: true, - }, - { - name: "IPv4 typical", - addr: netip.MustParseAddr("192.168.1.100"), - isV4: true, - }, - { - name: "IPv6 localhost", - addr: netip.MustParseAddr("::1"), - isV4: false, - }, - { - name: "IPv6 zero", - addr: netip.MustParseAddr("::"), - isV4: false, - }, - { - name: "IPv6 typical", - addr: netip.MustParseAddr("2001:db8::1"), - isV4: false, - }, - { - name: "IPv6 full", - addr: netip.MustParseAddr( - "2001:0db8:85a3:0000:0000:8a2e:0370:7334", - ), - isV4: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cAddr := goToCNetAddr(tt.addr) - - // Convert C -> Go - result := cToGoNetAddr(cAddr, tt.isV4) - - // Compare - if !compareAddr(tt.addr, result) { - t.Errorf( - "Round-trip conversion failed: got %v, want %v", - result, - tt.addr, - ) - } - }) - } -} - -// TestNetConversion tests round-trip conversion of network prefixes -func TestNetConversion(t *testing.T) { - tests := []struct { - name string - net xnetip.NetWithMask - isV4 bool - }{ - { - name: "IPv4 /32", - net: xnetip.FromPrefix(netip.MustParsePrefix("192.168.1.1/32")), - isV4: true, - }, - { - name: "IPv4 /24", - net: xnetip.FromPrefix(netip.MustParsePrefix("192.168.1.0/24")), - isV4: true, - }, - { - name: "IPv4 /17", - net: xnetip.FromPrefix(netip.MustParsePrefix("192.168.0.0/17")), - isV4: true, - }, - { - name: "IPv4 /16", - net: xnetip.FromPrefix(netip.MustParsePrefix("192.168.0.0/16")), - isV4: true, - }, - { - name: "IPv4 /11", - net: xnetip.FromPrefix(netip.MustParsePrefix("10.0.0.0/11")), - isV4: true, - }, - { - name: "IPv4 /8", - net: xnetip.FromPrefix(netip.MustParsePrefix("10.0.0.0/8")), - isV4: true, - }, - { - name: "IPv4 /0", - net: xnetip.FromPrefix(netip.MustParsePrefix("0.0.0.0/0")), - isV4: true, - }, - { - name: "IPv6 /128", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::1/128")), - isV4: false, - }, - { - name: "IPv6 /73", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/73")), - isV4: false, - }, - { - name: "IPv6 /64", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/64")), - isV4: false, - }, - { - name: "IPv6 /49", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/49")), - isV4: false, - }, - { - name: "IPv6 /48", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/48")), - isV4: false, - }, - { - name: "IPv6 /37", - net: xnetip.FromPrefix(netip.MustParsePrefix("2001:db8::/37")), - isV4: false, - }, - { - name: "IPv6 /0", - net: xnetip.FromPrefix(netip.MustParsePrefix("::/0")), - isV4: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cNet := goToCNet(tt.net) - - // Convert C -> Go - result := cToGoNet(cNet, tt.isV4) - - // Compare - if !compareNetWithMask(tt.net, result) { - t.Errorf( - "Round-trip conversion failed: got %v, want %v", - result, - tt.net, - ) - } - }) - } -} - -// TestVsIdentifierConversion tests round-trip conversion of VS identifiers -func TestVsIdentifierConversion(t *testing.T) { - tests := []struct { - name string - id VsIdentifier - }{ - { - name: "IPv4 TCP VS", - id: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - name: "IPv4 UDP VS", - id: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 53, - TransportProto: VsTransportProtoUDP, - }, - }, - { - name: "IPv6 TCP VS", - id: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - }, - { - name: "IPv6 UDP VS", - id: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 53, - TransportProto: VsTransportProtoUDP, - }, - }, - { - name: "Port zero (PureL3)", - id: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - - Port: 0, - TransportProto: VsTransportProtoTCP, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cID := goToCVsIdentifier(tt.id) - - // Convert C -> Go - result := cToGoVsIdentifier(cID) - - // Compare - if diff := cmp.Diff(tt.id, result, cmp.Comparer(compareAddr)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestRelativeRealIdentifierConversion tests round-trip conversion of relative real identifiers -func TestRelativeRealIdentifierConversion(t *testing.T) { - tests := []struct { - name string - id RelativeRealIdentifier - }{ - { - name: "IPv4 real", - id: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - - Port: 8080, - }, - }, - { - name: "IPv6 real", - id: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - - Port: 8080, - }, - }, - { - name: "Port zero", - id: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - - Port: 0, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cID := goToCRelativeRealIdentifier(tt.id) - - // Convert C -> Go - result := cToGoRelativeRealIdentifier(cID) - - // Compare - if diff := cmp.Diff(tt.id, result, cmp.Comparer(compareAddr)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestRealIdentifierConversion tests round-trip conversion of real identifiers -func TestRealIdentifierConversion(t *testing.T) { - tests := []struct { - name string - id RealIdentifier - }{ - { - name: "Complete IPv4 real", - id: RealIdentifier{ - VsIdentifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Relative: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - - Port: 8080, - }, - }, - }, - { - name: "Complete IPv6 real", - id: RealIdentifier{ - VsIdentifier: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 443, - TransportProto: VsTransportProtoUDP, - }, - Relative: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - - Port: 8443, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cID := goToCRealIdentifier(tt.id) - - // Convert C -> Go - result := cToGoRealIdentifier(cID) - - // Compare - if diff := cmp.Diff(tt.id, result, cmp.Comparer(compareAddr)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestTimestampConversion tests round-trip conversion of timestamps -func TestTimestampConversion(t *testing.T) { - tests := []struct { - name string - ts time.Time - }{ - { - name: "Unix epoch", - ts: time.Unix(0, 0), - }, - { - name: "Current time", - ts: time.Unix(1700000000, 0), - }, - { - name: "Future time", - ts: time.Unix(2000000000, 0), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cTS := goToCTimestamp(tt.ts) - - // Convert C -> Go - result := cToGoTimestamp(cTS) - - // Compare (only seconds precision) - if result.Unix() != tt.ts.Unix() { - t.Errorf( - "Round-trip conversion failed: got %v, want %v", - result.Unix(), - tt.ts.Unix(), - ) - } - }) - } -} - -// TestVsConfigConversion tests round-trip conversion of VS configuration -func TestVsConfigConversion(t *testing.T) { - tests := []struct { - name string - config VsConfig - }{ - { - name: "Simple IPv4 VS with one real", - config: VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: false, - FixMSS: true, - GRE: false, - OPS: false, - }, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - PeersV6: []netip.Addr{}, - }, - }, - { - name: "IPv6 VS with multiple reals", - config: VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: false, - FixMSS: false, - GRE: true, - OPS: false, - }, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("2001:db8:1::/64"), - ), - Weight: 50, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::101"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("2001:db8:2::/64"), - ), - Weight: 150, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("2001:db8::/32"), - )}, - }, - }, - PeersV4: []netip.Addr{}, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::2"), - }, - }, - }, - { - name: "VS with PureL3 flag and multiple reals", - config: VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - - Port: 0, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: true, - FixMSS: false, - GRE: false, - OPS: false, - }, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.10"), - Port: 0, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.11"), - Port: 0, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.1.0/24"), - ), - Weight: 150, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.12"), - Port: 0, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.2.0/24"), - ), - Weight: 200, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("0.0.0.0/0"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("10.0.0.2"), - netip.MustParseAddr("10.0.0.3"), - netip.MustParseAddr("10.0.0.4"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - }, - { - name: "VS with port ranges in AllowedSrc", - config: VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: false, - FixMSS: true, - GRE: false, - OPS: false, - }, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - PortRanges: []PortRange{ - {From: 1024, To: 65535}, - }, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - )}, - PortRanges: []PortRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - {From: 8000, To: 9000}, - }, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - }, - PeersV6: []netip.Addr{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cConfig, err := goToCVsConfig(&tt.config) - if err != nil { - t.Fatalf("Failed to convert to C: %v", err) - } - defer freeCVsConfig(cConfig) - - // Convert C -> Go - result := cToGoVsConfig(cConfig) - - // Compare - if diff := cmp.Diff(&tt.config, result, cmp.Comparer(compareAddr), cmp.Comparer(comparePrefix), cmp.Comparer(compareNetWithMask)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestPacketHandlerConfigConversion tests round-trip conversion of packet handler configuration -func TestPacketHandlerConfigConversion(t *testing.T) { - tests := []struct { - name string - config PacketHandlerConfig - }{ - { - name: "Complete configuration", - config: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.1.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - }, - }, - }, - { - name: "Configuration with multiple VS", - config: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.1.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.1.2"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.1.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{GRE: true}, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 8443, - }, - Src: xnetip.FromPrefix(netip.MustParsePrefix( - "2001:db8:1::/64", - )), - Weight: 50, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("2001:db8::/32"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::2"), - netip.MustParseAddr("2001:db8::3"), - netip.MustParseAddr("2001:db8::4"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - netip.MustParseAddr("2001:db8::101"), - netip.MustParseAddr("2001:db8::102"), - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cConfig, err := goToCPacketHandlerConfig(&tt.config) - if err != nil { - t.Fatalf("Failed to convert to C: %v", err) - } - defer freeCPacketHandlerConfig(cConfig) - - // Convert C -> Go - result := cToGoPacketHandlerConfig(cConfig) - - // Compare - if diff := cmp.Diff(&tt.config, result, cmp.Comparer(compareAddr), cmp.Comparer(comparePrefix), cmp.Comparer(compareNetWithMask)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestBalancerConfigConversion tests round-trip conversion of balancer configuration -func TestBalancerConfigConversion(t *testing.T) { - tests := []struct { - name string - config BalancerConfig - }{ - { - name: "Complete balancer config", - config: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.1.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "172.16.0.0/24", - ), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.1.2"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "172.16.1.0/24", - ), - ), - Weight: 150, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix( - "192.168.0.0/16", - ), - ), - }, - }, - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - ), - }, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - netip.MustParseAddr("192.168.1.3"), - netip.MustParseAddr("192.168.1.4"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - netip.MustParseAddr("2001:db8::101"), - netip.MustParseAddr("2001:db8::102"), - netip.MustParseAddr("2001:db8::103"), - netip.MustParseAddr("2001:db8::104"), - }, - }, - State: StateConfig{ - TableCapacity: 10000, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cConfig, err := goToCBalancerConfig(&tt.config) - if err != nil { - t.Fatalf("Failed to convert to C: %v", err) - } - defer freeCBalancerConfig(cConfig) - - // Convert C -> Go - result := cToGoBalancerConfig(cConfig) - - // Compare - if diff := cmp.Diff(&tt.config, result, cmp.Comparer(compareAddr), cmp.Comparer(comparePrefix), cmp.Comparer(compareNetWithMask)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestBalancerManagerConfigConversion tests round-trip conversion of manager configuration -func TestBalancerManagerConfigConversion(t *testing.T) { - tests := []struct { - name string - config BalancerManagerConfig - }{ - { - name: "Complete manager config", - config: BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "10.0.1.1", - ), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "172.16.0.0/24", - ), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix( - "192.168.0.0/16", - ), - ), - }, - }, - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix( - "10.0.0.0/8", - ), - ), - }, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - netip.MustParseAddr("192.168.1.3"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - netip.MustParseAddr("192.168.1.3"), - netip.MustParseAddr("192.168.1.4"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - netip.MustParseAddr("2001:db8::101"), - netip.MustParseAddr("2001:db8::102"), - netip.MustParseAddr("2001:db8::103"), - netip.MustParseAddr("2001:db8::104"), - }, - }, - State: StateConfig{ - TableCapacity: 10000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - Vs: []uint32{1, 2, 3, 4, 5, 6, 7}, - }, - RefreshPeriod: 5 * time.Second, - MaxLoadFactor: 0.75, - }, - }, - { - name: "Manager config with different WLC settings", - config: BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "10.0.1.1", - ), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "172.16.0.0/24", - ), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix( - "192.168.0.0/16", - ), - ), - }, - }, - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix( - "10.0.0.0/8", - ), - ), - }, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - netip.MustParseAddr("192.168.1.3"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("::1"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - netip.MustParseAddr("192.168.1.3"), - netip.MustParseAddr("192.168.1.4"), - netip.MustParseAddr("192.168.1.5"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - netip.MustParseAddr("2001:db8::101"), - netip.MustParseAddr("2001:db8::102"), - }, - }, - State: StateConfig{ - TableCapacity: 5000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 1, - MaxRealWeight: 500, - Vs: []uint32{10, 20, 30}, - }, - RefreshPeriod: 10 * time.Second, - MaxLoadFactor: 0.9, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Convert Go -> C - cConfig, err := goToCBalancerManagerConfig(&tt.config) - if err != nil { - t.Fatalf("Failed to convert to C: %v", err) - } - defer freeCBalancerManagerConfig(cConfig) - - // Convert C -> Go - result := cToGoBalancerManagerConfig(cConfig) - - // Compare - if diff := cmp.Diff(&tt.config, result, cmp.Comparer(compareAddr), cmp.Comparer(comparePrefix), cmp.Comparer(compareNetWithMask)); diff != "" { - t.Errorf( - "Round-trip conversion mismatch (-want +got):\n%s", - diff, - ) - } - }) - } -} - -// TestEdgeCases tests edge cases and boundary conditions -func TestEdgeCases(t *testing.T) { - t.Run("Empty arrays in VsConfig", func(t *testing.T) { - config := VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{}, - AllowedSources: []AllowedSources{}, - PeersV4: []netip.Addr{}, - PeersV6: []netip.Addr{}, - } - - cConfig, err := goToCVsConfig(&config) - if err != nil { - t.Fatalf("Failed to convert: %v", err) - } - defer freeCVsConfig(cConfig) - - result := cToGoVsConfig(cConfig) - - if len(result.Reals) != 0 { - t.Errorf("Expected empty Reals, got %d items", len(result.Reals)) - } - if len(result.AllowedSources) != 0 { - t.Errorf( - "Expected empty AllowedSrc, got %d items", - len(result.AllowedSources), - ) - } - }) - - t.Run("Zero port (PureL3 mode)", func(t *testing.T) { - id := VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 0, - TransportProto: 6, - } - - cID := goToCVsIdentifier(id) - result := cToGoVsIdentifier(cID) - - if result.Port != 0 { - t.Errorf("Port should be 0, got %v", result.Port) - } - }) - - t.Run("IPv4 vs IPv6 distinction", func(t *testing.T) { - v4Addr := netip.MustParseAddr("192.168.1.1") - v6Addr := netip.MustParseAddr("2001:db8::1") - - cV4 := goToCNetAddr(v4Addr) - cV6 := goToCNetAddr(v6Addr) - - resultV4 := cToGoNetAddr(cV4, true) - resultV6 := cToGoNetAddr(cV6, false) - - if !resultV4.Is4() { - t.Errorf("Expected IPv4 address, got %v", resultV4) - } - if !resultV6.Is6() { - t.Errorf("Expected IPv6 address, got %v", resultV6) - } - }) - - t.Run("Nil config pointer", func(t *testing.T) { - result := cToGoBalancerConfig(nil) - if result != nil { - t.Errorf("Expected nil result for nil input, got %v", result) - } - }) -} - -// TestComplexScenario tests a realistic complex scenario -func TestComplexScenario(t *testing.T) { - config := BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("192.168.1.100"), - - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: false, - FixMSS: true, - GRE: false, - OPS: false, - }, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.1.0/24"), - ), - Weight: 150, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("10.0.0.0/8"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - PeersV6: []netip.Addr{}, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - PureL3: false, - FixMSS: false, - GRE: true, - OPS: false, - }, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 8443, - }, - Src: xnetip.FromPrefix(netip.MustParsePrefix( - "2001:db8:1::/64", - )), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("2001:db8::/32"), - )}, - }, - }, - PeersV4: []netip.Addr{}, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::2"), - netip.MustParseAddr("2001:db8::3"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.1.1"), - netip.MustParseAddr("192.168.1.2"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::100"), - netip.MustParseAddr("2001:db8::101"), - }, - }, - State: StateConfig{ - TableCapacity: 100000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - Vs: []uint32{1, 2, 3, 4, 5}, - }, - RefreshPeriod: 5 * time.Second, - MaxLoadFactor: 0.75, - } - - // Convert Go -> C - cConfig, err := goToCBalancerManagerConfig(&config) - if err != nil { - t.Fatalf("Failed to convert to C: %v", err) - } - defer freeCBalancerManagerConfig(cConfig) - - // Convert C -> Go - result := cToGoBalancerManagerConfig(cConfig) - - // Compare - if diff := cmp.Diff(&config, result, cmp.Comparer(compareAddr), cmp.Comparer(comparePrefix), cmp.Comparer(compareNetWithMask)); diff != "" { - t.Errorf("Round-trip conversion mismatch (-want +got):\n%s", diff) - } -} - -// TestLargeScaleConversion tests conversion performance with a large configuration -func TestLargeScaleConversion(t *testing.T) { - // Build a large configuration with 100 virtual services - config := PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 60, - TCPSyn: 30, - TCPFin: 30, - TCP: 300, - UDP: 120, - Default: 60, - }, - VirtualServices: make([]VsConfig, 100), - SourceV4: netip.MustParseAddr("10.0.0.1"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: make([]netip.Addr, 100), - DecapV6: make([]netip.Addr, 100), - } - - // Generate 100 DecapV4 addresses - for i := 0; i < 100; i++ { - config.DecapV4[i] = netip.MustParseAddr( - fmt.Sprintf("192.168.%d.%d", i/256, i%256), - ) - } - - // Generate 100 DecapV6 addresses - for i := 0; i < 100; i++ { - config.DecapV6[i] = netip.MustParseAddr( - fmt.Sprintf("2001:db8::%x", i+1), - ) - } - - // Generate 100 virtual services - for vsIdx := 0; vsIdx < 100; vsIdx++ { - vs := VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr( - fmt.Sprintf("10.0.%d.%d", vsIdx/256, vsIdx%256), - ), - - Port: uint16(8000 + vsIdx), - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{ - FixMSS: vsIdx%2 == 0, - GRE: vsIdx%3 == 0, - }, - Scheduler: VsScheduler(vsIdx % 2), - Reals: make([]RealConfig, 100), - } - - // Generate 100 reals for each VS - for realIdx := 0; realIdx < 100; realIdx++ { - vs.Reals[realIdx] = RealConfig{ - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - fmt.Sprintf("172.16.%d.%d", realIdx/256, realIdx%256), - ), - Port: uint16(9000 + realIdx), - }, - Src: xnetip.FromPrefix(netip.MustParsePrefix( - fmt.Sprintf("172.16.%d.0/24", realIdx/256)), - ), - Weight: uint16(100 + realIdx%900), - } - } - - // Generate 10-100 allowed sources (varies per VS) - allowedCount := 10 + (vsIdx % 91) - vs.AllowedSources = make([]AllowedSources, allowedCount) - for i := 0; i < allowedCount; i++ { - vs.AllowedSources[i] = AllowedSources{ - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix(netip.MustParsePrefix( - fmt.Sprintf("10.%d.%d.0/24", i/256, i%256), - )), - }, - } - } - - // Generate 10-100 IPv4 peers (varies per VS) - peersV4Count := 10 + (vsIdx % 91) - vs.PeersV4 = make([]netip.Addr, peersV4Count) - for i := 0; i < peersV4Count; i++ { - vs.PeersV4[i] = netip.MustParseAddr( - fmt.Sprintf("192.168.%d.%d", i/256, i%256), - ) - } - - // Generate 10-100 IPv6 peers (varies per VS) - peersV6Count := 10 + ((vsIdx + 50) % 91) - vs.PeersV6 = make([]netip.Addr, peersV6Count) - for i := 0; i < peersV6Count; i++ { - vs.PeersV6[i] = netip.MustParseAddr( - fmt.Sprintf("2001:db8::%x", i+1), - ) - } - - config.VirtualServices[vsIdx] = vs - } - - // Measure Go -> C conversion time - startGoToC := time.Now() - cConfig, err := goToCPacketHandlerConfig(&config) - goToCDuration := time.Since(startGoToC) - if err != nil { - t.Fatalf("Failed to convert Go -> C: %v", err) - } - defer freeCPacketHandlerConfig(cConfig) - - // Measure C -> Go conversion time - startCToGo := time.Now() - result := cToGoPacketHandlerConfig(cConfig) - cToGoDuration := time.Since(startCToGo) - - // Measure round-trip time - totalDuration := goToCDuration + cToGoDuration - - // Output timing results - t.Logf("Large-scale conversion performance:") - t.Logf( - " Configuration: 100 VS × 100 reals each + varying peers/allowed sources", - ) - t.Logf(" Go -> C conversion: %v", goToCDuration) - t.Logf(" C -> Go conversion: %v", cToGoDuration) - t.Logf(" Total round-trip: %v", totalDuration) - - // Basic validation - if len(result.VirtualServices) != 100 { - t.Errorf( - "Expected 100 virtual services, got %d", - len(result.VirtualServices), - ) - } - if len(result.DecapV4) != 100 { - t.Errorf("Expected 100 DecapV4 addresses, got %d", len(result.DecapV4)) - } - if len(result.DecapV6) != 100 { - t.Errorf("Expected 100 DecapV6 addresses, got %d", len(result.DecapV6)) - } - - // Validate first VS has 100 reals - if len(result.VirtualServices) > 0 && - len(result.VirtualServices[0].Reals) != 100 { - t.Errorf( - "Expected first VS to have 100 reals, got %d", - len(result.VirtualServices[0].Reals), - ) - } -} diff --git a/modules/balancer/agent/go/ffi/conversions.go b/modules/balancer/agent/go/ffi/conversions.go deleted file mode 100644 index 9fe74e4aa..000000000 --- a/modules/balancer/agent/go/ffi/conversions.go +++ /dev/null @@ -1,1550 +0,0 @@ -package ffi - -// Low-level conversion functions between Go and C types for balancer FFI operations. -// Handles memory-safe transformations of network addresses, configurations, statistics, -// and complex nested structures with proper memory management and cleanup. - -/* -#cgo CFLAGS: -I../../ -I../../../../../ -#cgo LDFLAGS: -L../../../../../build/modules/balancer/agent -lbalancer_agent -L../../../../../build/modules/balancer/controlplane/api -lbalancer_cp -L../../../../../build/modules/balancer/controlplane/handler -lbalancer_packet_handler -L../../../../../build/modules/balancer/controlplane/state -lbalancer_state -lbalancer_packet_handler -lbalancer_state -#include "agent.h" -#include "manager.h" -#include "modules/balancer/controlplane/api/graph.h" -#include "modules/balancer/controlplane/api/vs.h" -#include "modules/balancer/controlplane/api/real.h" -#include "modules/balancer/controlplane/api/inspect.h" -#include -#include -*/ -import "C" - -import ( - "fmt" - "net/netip" - "time" - "unsafe" - - "github.com/yanet-platform/yanet2/common/go/xnetip" -) - -func goToCNetAddr(addr netip.Addr) C.struct_net_addr { - var cAddr C.struct_net_addr - // Zero-initialize the entire union to avoid padding issues - ptr := unsafe.Pointer(&cAddr) - size := unsafe.Sizeof(cAddr) - slice := unsafe.Slice((*byte)(ptr), size) - for i := range slice { - slice[i] = 0 - } - - if addr.Is4() { - v4 := addr.As4() - // Access union field through unsafe pointer cast - pv4 := (*C.struct_net4_addr)(unsafe.Pointer(&cAddr)) - C.memcpy(unsafe.Pointer(&pv4.bytes[0]), unsafe.Pointer(&v4[0]), 4) - } else { - v6 := addr.As16() - // Access union field through unsafe pointer cast - pv6 := (*C.struct_net6_addr)(unsafe.Pointer(&cAddr)) - C.memcpy(unsafe.Pointer(&pv6.bytes[0]), unsafe.Pointer(&v6[0]), 16) - } - return cAddr -} - -func cToGoNetAddr(cAddr C.struct_net_addr, isV4 bool) netip.Addr { - if isV4 { - var v4 [4]byte - pv4 := (*C.struct_net4_addr)(unsafe.Pointer(&cAddr)) - C.memcpy(unsafe.Pointer(&v4[0]), unsafe.Pointer(&pv4.bytes[0]), 4) - return netip.AddrFrom4(v4) - } - var v6 [16]byte - pv6 := (*C.struct_net6_addr)(unsafe.Pointer(&cAddr)) - C.memcpy(unsafe.Pointer(&v6[0]), unsafe.Pointer(&pv6.bytes[0]), 16) - return netip.AddrFrom16(v6) -} - -func goToCNet(net xnetip.NetWithMask) C.struct_net { - var cNet C.struct_net - addr := net.Addr - mask := net.MaskBytes() - - // Zero-initialize the entire union to avoid garbage data - ptr := unsafe.Pointer(&cNet) - size := unsafe.Sizeof(cNet) - slice := unsafe.Slice((*byte)(ptr), size) - for i := range slice { - slice[i] = 0 - } - - if addr.Is4() { - v4 := addr.As4() - // For IPv4, the struct net4 layout is: - // - addr[4] at offset 0 - // - mask[4] at offset 4 - // Copy addr to bytes 0-3 - for i := range 4 { - slice[i] = v4[i] - } - // Copy mask to bytes 4-7 - for i := 0; i < 4; i++ { - slice[4+i] = mask[i] - } - } else { - v6 := addr.As16() - // For IPv6, the struct net6 layout is: - // - addr[16] at offset 0 - // - mask[16] at offset 16 - // Copy addr to bytes 0-15 - for i := 0; i < 16; i++ { - slice[i] = v6[i] - } - // Copy mask to bytes 16-31 - for i := 0; i < 16; i++ { - slice[16+i] = mask[i] - } - } - return cNet -} - -func cToGoNet(cNet C.struct_net, isV4 bool) xnetip.NetWithMask { - if isV4 { - var addr [4]byte - var mask [4]byte - // Copy from union bytes - C.memcpy(unsafe.Pointer(&addr[0]), unsafe.Pointer(&cNet), 4) - C.memcpy( - unsafe.Pointer(&mask[0]), - unsafe.Pointer(uintptr(unsafe.Pointer(&cNet))+4), - 4, - ) - - return xnetip.NetWithMask{ - Addr: netip.AddrFrom4(addr), - Mask: mask[:], - } - } - - var addr [16]byte - var mask [16]byte - // Copy from union bytes - C.memcpy(unsafe.Pointer(&addr[0]), unsafe.Pointer(&cNet), 16) - C.memcpy( - unsafe.Pointer(&mask[0]), - unsafe.Pointer(uintptr(unsafe.Pointer(&cNet))+16), - 16, - ) - - return xnetip.NetWithMask{ - Addr: netip.AddrFrom16(addr), - Mask: mask[:], - } -} - -// VS type conversions - -func goToCVsIdentifier(id VsIdentifier) C.struct_vs_identifier { - var cID C.struct_vs_identifier - // Zero-initialize the entire structure to avoid padding issues - ptr := unsafe.Pointer(&cID) - size := unsafe.Sizeof(cID) - slice := unsafe.Slice((*byte)(ptr), size) - for i := range slice { - slice[i] = 0 - } - - cID.addr = goToCNetAddr(id.Addr) - // Derive ip_proto from the address type - if id.Addr.Is4() { - cID.ip_proto = 0 // IPPROTO_IP (IPv4) - } else { - cID.ip_proto = 41 // IPPROTO_IPV6 - } - cID.port = C.uint16_t(id.Port) - // Convert Go enum (0=TCP, 1=UDP) to C constants (6=IPPROTO_TCP, 17=IPPROTO_UDP) - if id.TransportProto == VsTransportProtoTCP { - cID.transport_proto = C.IPPROTO_TCP // 6 - } else { - cID.transport_proto = C.IPPROTO_UDP // 17 - } - return cID -} - -func cToGoVsIdentifier(cID C.struct_vs_identifier) VsIdentifier { - // Determine if IPv4 or IPv6 based on ip_proto - isV4 := cID.ip_proto == 0 // IPPROTO_IP (IPv4) - return VsIdentifier{ - Addr: cToGoNetAddr(cID.addr, isV4), - Port: uint16(cID.port), - // Convert C constants (6=IPPROTO_TCP, 17=IPPROTO_UDP) to Go enum (0=TCP, 1=UDP) - TransportProto: func() VsTransportProto { - if cID.transport_proto == C.IPPROTO_TCP { // 6 - return VsTransportProtoTCP // 0 - } - return VsTransportProtoUDP // 1 - }(), - } -} - -// Real type conversions - -func goToCRelativeRealIdentifier( - id RelativeRealIdentifier, -) C.struct_relative_real_identifier { - var cID C.struct_relative_real_identifier - // Zero-initialize the entire structure to avoid padding issues - ptr := unsafe.Pointer(&cID) - size := unsafe.Sizeof(cID) - slice := unsafe.Slice((*byte)(ptr), size) - for i := range slice { - slice[i] = 0 - } - - cID.addr = goToCNetAddr(id.Addr) - // Derive ip_proto from the address type - if id.Addr.Is4() { - cID.ip_proto = 0 // IPPROTO_IP (IPv4) - } else { - cID.ip_proto = 41 // IPPROTO_IPV6 - } - cID.port = C.uint16_t(id.Port) - return cID -} - -func cToGoRelativeRealIdentifier( - cID C.struct_relative_real_identifier, -) RelativeRealIdentifier { - isV4 := cID.ip_proto == 0 // IPPROTO_IP (IPv4) - return RelativeRealIdentifier{ - Addr: cToGoNetAddr(cID.addr, isV4), - Port: uint16(cID.port), - } -} - -func goToCRealIdentifier(id RealIdentifier) C.struct_real_identifier { - var cID C.struct_real_identifier - // Zero-initialize the entire structure to avoid padding issues - ptr := unsafe.Pointer(&cID) - size := unsafe.Sizeof(cID) - slice := unsafe.Slice((*byte)(ptr), size) - for i := range slice { - slice[i] = 0 - } - - cID.vs_identifier = goToCVsIdentifier(id.VsIdentifier) - cID.relative = goToCRelativeRealIdentifier(id.Relative) - return cID -} - -func cToGoRealIdentifier(cID C.struct_real_identifier) RealIdentifier { - return RealIdentifier{ - VsIdentifier: cToGoVsIdentifier(cID.vs_identifier), - Relative: cToGoRelativeRealIdentifier(cID.relative), - } -} - -// Time conversions (uint32 monotonic timestamp to time.Time) -func cToGoTimestamp(ts uint32) time.Time { - return time.Unix(int64(ts), 0) -} - -func goToCTimestamp(t time.Time) uint32 { - return uint32(t.Unix()) -} - -// RealUpdate conversions - -func goToCRealUpdate(update RealUpdate) C.struct_real_update { - var cUpdate C.struct_real_update - cUpdate.identifier = goToCRealIdentifier(update.Identifier) - cUpdate.weight = C.uint16_t(update.Weight) - cUpdate.enabled = C.uint8_t(update.Enabled) - return cUpdate -} - -// PacketHandlerRef conversions - -func goToCPacketHandlerRef( - ref *PacketHandlerRef, -) *C.struct_packet_handler_ref { - if ref == nil { - return nil - } - - cRef := (*C.struct_packet_handler_ref)( - C.malloc(C.sizeof_struct_packet_handler_ref), - ) - - if ref.Device != nil { - cRef.device = C.CString(*ref.Device) - } else { - cRef.device = nil - } - - if ref.Pipeline != nil { - cRef.pipeline = C.CString(*ref.Pipeline) - } else { - cRef.pipeline = nil - } - - if ref.Function != nil { - cRef.function = C.CString(*ref.Function) - } else { - cRef.function = nil - } - - if ref.Chain != nil { - cRef.chain = C.CString(*ref.Chain) - } else { - cRef.chain = nil - } - - return cRef -} - -func freeCPacketHandlerRef(cRef *C.struct_packet_handler_ref) { - if cRef == nil { - return - } - - if cRef.device != nil { - C.free(unsafe.Pointer(cRef.device)) - } - if cRef.pipeline != nil { - C.free(unsafe.Pointer(cRef.pipeline)) - } - if cRef.function != nil { - C.free(unsafe.Pointer(cRef.function)) - } - if cRef.chain != nil { - C.free(unsafe.Pointer(cRef.chain)) - } - - C.free(unsafe.Pointer(cRef)) -} - -// BalancerManagerConfig conversions - -func goToCBalancerManagerConfig( - config *BalancerManagerConfig, -) (*C.struct_balancer_manager_config, error) { - if config == nil { - return nil, fmt.Errorf("config is nil") - } - - cConfig := (*C.struct_balancer_manager_config)( - C.malloc(C.sizeof_struct_balancer_manager_config), - ) - - // Convert balancer config directly into the embedded struct - err := goToCBalancerConfigInPlace(&config.Balancer, &cConfig.balancer) - if err != nil { - C.free(unsafe.Pointer(cConfig)) - return nil, err - } - - // Convert WLC config - cConfig.wlc.power = C.size_t(config.Wlc.Power) - cConfig.wlc.max_real_weight = C.size_t(config.Wlc.MaxRealWeight) - cConfig.wlc.vs_count = C.size_t(len(config.Wlc.Vs)) - - if len(config.Wlc.Vs) > 0 { - cConfig.wlc.vs = (*C.uint32_t)( - C.malloc(C.size_t(len(config.Wlc.Vs)) * C.sizeof_uint32_t), - ) - cVsSlice := unsafe.Slice(cConfig.wlc.vs, len(config.Wlc.Vs)) - for i, vs := range config.Wlc.Vs { - cVsSlice[i] = C.uint32_t(vs) - } - } else { - cConfig.wlc.vs = nil - } - - cConfig.refresh_period = C.uint32_t(config.RefreshPeriod.Milliseconds()) - cConfig.max_load_factor = C.float(config.MaxLoadFactor) - - return cConfig, nil -} - -func freeCBalancerManagerConfig(cConfig *C.struct_balancer_manager_config) { - if cConfig == nil { - return - } - - // Free balancer config internals - freeCBalancerConfig(&cConfig.balancer) - - // Free WLC VS array - if cConfig.wlc.vs != nil { - C.free(unsafe.Pointer(cConfig.wlc.vs)) - } - - C.free(unsafe.Pointer(cConfig)) -} - -func cToGoBalancerManagerConfig( - cConfig *C.struct_balancer_manager_config, -) *BalancerManagerConfig { - if cConfig == nil { - return nil - } - - config := &BalancerManagerConfig{ - Balancer: *cToGoBalancerConfig(&cConfig.balancer), - RefreshPeriod: time.Duration(cConfig.refresh_period) * time.Millisecond, - MaxLoadFactor: float32(cConfig.max_load_factor), - } - - // Convert WLC config - config.Wlc.Power = uint(cConfig.wlc.power) - config.Wlc.MaxRealWeight = uint(cConfig.wlc.max_real_weight) - - if cConfig.wlc.vs_count > 0 && cConfig.wlc.vs != nil { - cVsSlice := unsafe.Slice(cConfig.wlc.vs, cConfig.wlc.vs_count) - config.Wlc.Vs = make([]uint32, cConfig.wlc.vs_count) - for i := range config.Wlc.Vs { - config.Wlc.Vs[i] = uint32(cVsSlice[i]) - } - } else { - config.Wlc.Vs = []uint32{} - } - - return config -} - -// BalancerConfig conversions - -func goToCBalancerConfig( - config *BalancerConfig, -) (*C.struct_balancer_config, error) { - cConfig := (*C.struct_balancer_config)( - C.malloc(C.sizeof_struct_balancer_config), - ) - err := goToCBalancerConfigInPlace(config, cConfig) - if err != nil { - C.free(unsafe.Pointer(cConfig)) - return nil, err - } - return cConfig, nil -} - -func goToCBalancerConfigInPlace( - config *BalancerConfig, - cConfig *C.struct_balancer_config, -) error { - // Convert handler config directly into the embedded struct - err := goToCPacketHandlerConfigInPlace(&config.Handler, &cConfig.handler) - if err != nil { - return err - } - - // Convert state config - cConfig.state.table_capacity = C.size_t(config.State.TableCapacity) - - return nil -} - -func freeCBalancerConfig(cConfig *C.struct_balancer_config) { - if cConfig == nil { - return - } - - freeCPacketHandlerConfig(&cConfig.handler) -} - -func cToGoBalancerConfig(cConfig *C.struct_balancer_config) *BalancerConfig { - if cConfig == nil { - return nil - } - - return &BalancerConfig{ - Handler: *cToGoPacketHandlerConfig(&cConfig.handler), - State: StateConfig{ - TableCapacity: uint(cConfig.state.table_capacity), - }, - } -} - -// PacketHandlerConfig conversions - -func goToCPacketHandlerConfig( - config *PacketHandlerConfig, -) (*C.struct_packet_handler_config, error) { - cConfig := (*C.struct_packet_handler_config)( - C.malloc(C.sizeof_struct_packet_handler_config), - ) - err := goToCPacketHandlerConfigInPlace(config, cConfig) - if err != nil { - C.free(unsafe.Pointer(cConfig)) - return nil, err - } - return cConfig, nil -} - -func goToCPacketHandlerConfigInPlace( - config *PacketHandlerConfig, - cConfig *C.struct_packet_handler_config, -) error { - if config == nil { - return fmt.Errorf("config is nil") - } - if len(config.SourceV4.AsSlice()) == 0 { - return fmt.Errorf("IPv4 source address is empty") - } - if len(config.SourceV6.AsSlice()) == 0 { - return fmt.Errorf("IPv6 source address is empty") - } - - // Convert sessions timeouts - cConfig.sessions_timeouts.tcp_syn_ack = C.uint32_t( - config.SessionsTimeouts.TCPSynAck, - ) - cConfig.sessions_timeouts.tcp_syn = C.uint32_t( - config.SessionsTimeouts.TCPSyn, - ) - cConfig.sessions_timeouts.tcp_fin = C.uint32_t( - config.SessionsTimeouts.TCPFin, - ) - cConfig.sessions_timeouts.tcp = C.uint32_t(config.SessionsTimeouts.TCP) - cConfig.sessions_timeouts.udp = C.uint32_t(config.SessionsTimeouts.UDP) - cConfig.sessions_timeouts.def = C.uint32_t(config.SessionsTimeouts.Default) - - // Convert source addresses (need to cast from net_addr to net4_addr/net6_addr) - v4 := config.SourceV4.As4() - C.memcpy( - unsafe.Pointer(&cConfig.source_v4.bytes[0]), - unsafe.Pointer(&v4[0]), - 4, - ) - - v6 := config.SourceV6.As16() - C.memcpy( - unsafe.Pointer(&cConfig.source_v6.bytes[0]), - unsafe.Pointer(&v6[0]), - 16, - ) - - // Convert decap addresses - cConfig.decap_v4_count = C.size_t(len(config.DecapV4)) - cConfig.decap_v6_count = C.size_t(len(config.DecapV6)) - - if len(config.DecapV4) > 0 { - cConfig.decap_v4 = (*C.struct_net4_addr)( - C.malloc(C.size_t(len(config.DecapV4)) * C.sizeof_struct_net4_addr), - ) - cDecapV4Slice := unsafe.Slice(cConfig.decap_v4, len(config.DecapV4)) - for i, addr := range config.DecapV4 { - v4 := addr.As4() - C.memcpy( - unsafe.Pointer(&cDecapV4Slice[i].bytes[0]), - unsafe.Pointer(&v4[0]), - 4, - ) - } - } else { - cConfig.decap_v4 = nil - } - - if len(config.DecapV6) > 0 { - cConfig.decap_v6 = (*C.struct_net6_addr)( - C.malloc(C.size_t(len(config.DecapV6)) * C.sizeof_struct_net6_addr), - ) - cDecapV6Slice := unsafe.Slice(cConfig.decap_v6, len(config.DecapV6)) - for i, addr := range config.DecapV6 { - v6 := addr.As16() - C.memcpy( - unsafe.Pointer(&cDecapV6Slice[i].bytes[0]), - unsafe.Pointer(&v6[0]), - 16, - ) - } - } else { - cConfig.decap_v6 = nil - } - - // Convert virtual services - cConfig.vs_count = C.size_t(len(config.VirtualServices)) - if len(config.VirtualServices) > 0 { - cConfig.vs = (*C.struct_named_vs_config)( - C.malloc( - C.size_t( - len(config.VirtualServices), - ) * C.sizeof_struct_named_vs_config, - ), - ) - cVsSlice := unsafe.Slice(cConfig.vs, len(config.VirtualServices)) - for i, vs := range config.VirtualServices { - err := goToCVsConfigInPlace(&vs, &cVsSlice[i]) - if err != nil { - // Cleanup on error - freeCPacketHandlerConfig(cConfig) - return err - } - } - } else { - cConfig.vs = nil - } - - return nil -} - -func freeCPacketHandlerConfig(cConfig *C.struct_packet_handler_config) { - if cConfig == nil { - return - } - - if cConfig.decap_v4 != nil { - C.free(unsafe.Pointer(cConfig.decap_v4)) - } - if cConfig.decap_v6 != nil { - C.free(unsafe.Pointer(cConfig.decap_v6)) - } - - if cConfig.vs != nil { - cVsSlice := unsafe.Slice(cConfig.vs, cConfig.vs_count) - for i := range cVsSlice { - // Free only the internal allocations, not the struct itself - // since it's part of an array - if cVsSlice[i].config.reals != nil { - C.free(unsafe.Pointer(cVsSlice[i].config.reals)) - } - if cVsSlice[i].config.allowed_src != nil { - // Free port ranges and nets within each allowed_src - cAllowedSlice := unsafe.Slice( - cVsSlice[i].config.allowed_src, - cVsSlice[i].config.allowed_src_count, - ) - for j := range cAllowedSlice { - if cAllowedSlice[j].nets != nil { - C.free(unsafe.Pointer(cAllowedSlice[j].nets)) - } - if cAllowedSlice[j].port_ranges != nil { - C.free(unsafe.Pointer(cAllowedSlice[j].port_ranges)) - } - // Free tag string if allocated - if cAllowedSlice[j].tag != nil { - C.free(unsafe.Pointer(cAllowedSlice[j].tag)) - } - } - C.free(unsafe.Pointer(cVsSlice[i].config.allowed_src)) - } - if cVsSlice[i].config.peers_v4 != nil { - C.free(unsafe.Pointer(cVsSlice[i].config.peers_v4)) - } - if cVsSlice[i].config.peers_v6 != nil { - C.free(unsafe.Pointer(cVsSlice[i].config.peers_v6)) - } - } - C.free(unsafe.Pointer(cConfig.vs)) - } -} - -func cToGoPacketHandlerConfig( - cConfig *C.struct_packet_handler_config, -) *PacketHandlerConfig { - if cConfig == nil { - return nil - } - - config := &PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: uint32(cConfig.sessions_timeouts.tcp_syn_ack), - TCPSyn: uint32(cConfig.sessions_timeouts.tcp_syn), - TCPFin: uint32(cConfig.sessions_timeouts.tcp_fin), - TCP: uint32(cConfig.sessions_timeouts.tcp), - UDP: uint32(cConfig.sessions_timeouts.udp), - Default: uint32(cConfig.sessions_timeouts.def), - }, - SourceV4: func() netip.Addr { - var v4 [4]byte - C.memcpy( - unsafe.Pointer(&v4[0]), - unsafe.Pointer(&cConfig.source_v4.bytes[0]), - 4, - ) - return netip.AddrFrom4(v4) - }(), - SourceV6: func() netip.Addr { - var v6 [16]byte - C.memcpy( - unsafe.Pointer(&v6[0]), - unsafe.Pointer(&cConfig.source_v6.bytes[0]), - 16, - ) - return netip.AddrFrom16(v6) - }(), - } - - // Convert decap addresses - if cConfig.decap_v4_count > 0 && cConfig.decap_v4 != nil { - cDecapV4Slice := unsafe.Slice(cConfig.decap_v4, cConfig.decap_v4_count) - config.DecapV4 = make([]netip.Addr, cConfig.decap_v4_count) - for i := range config.DecapV4 { - var v4 [4]byte - C.memcpy( - unsafe.Pointer(&v4[0]), - unsafe.Pointer(&cDecapV4Slice[i].bytes[0]), - 4, - ) - config.DecapV4[i] = netip.AddrFrom4(v4) - } - } else { - config.DecapV4 = []netip.Addr{} - } - - if cConfig.decap_v6_count > 0 && cConfig.decap_v6 != nil { - cDecapV6Slice := unsafe.Slice(cConfig.decap_v6, cConfig.decap_v6_count) - config.DecapV6 = make([]netip.Addr, cConfig.decap_v6_count) - for i := range config.DecapV6 { - var v6 [16]byte - C.memcpy( - unsafe.Pointer(&v6[0]), - unsafe.Pointer(&cDecapV6Slice[i].bytes[0]), - 16, - ) - config.DecapV6[i] = netip.AddrFrom16(v6) - } - } else { - config.DecapV6 = []netip.Addr{} - } - - // Convert virtual services - if cConfig.vs_count > 0 && cConfig.vs != nil { - cVsSlice := unsafe.Slice(cConfig.vs, cConfig.vs_count) - config.VirtualServices = make([]VsConfig, cConfig.vs_count) - for i := range config.VirtualServices { - config.VirtualServices[i] = *cToGoVsConfig(&cVsSlice[i]) - } - } else { - config.VirtualServices = []VsConfig{} - } - - return config -} - -// VsConfig conversions - -func goToCVsConfig(config *VsConfig) (*C.struct_named_vs_config, error) { - if config == nil { - return nil, fmt.Errorf("config is nil") - } - - cConfig := (*C.struct_named_vs_config)( - C.malloc(C.sizeof_struct_named_vs_config), - ) - err := goToCVsConfigInPlace(config, cConfig) - if err != nil { - C.free(unsafe.Pointer(cConfig)) - return nil, err - } - return cConfig, nil -} - -func goToCVsConfigInPlace( - config *VsConfig, - cConfig *C.struct_named_vs_config, -) error { - if config == nil { - return fmt.Errorf("config is nil") - } - - // Convert identifier - cConfig.identifier = goToCVsIdentifier(config.Identifier) - - // Convert flags to C bitfield (using constants from vs.h) - var flags C.uint8_t - if config.Flags.PureL3 { - flags |= C.VS_PURE_L3_FLAG - } - if config.Flags.FixMSS { - flags |= C.VS_FIX_MSS_FLAG - } - if config.Flags.GRE { - flags |= C.VS_GRE_FLAG - } - if config.Flags.OPS { - flags |= C.VS_OPS_FLAG - } - cConfig.config.flags = flags - - // Convert scheduler (enum vs_scheduler from vs.h: source_hash=0, round_robin=1) - // Cast through int to match C enum type - cConfig.config.scheduler = C.enum_vs_scheduler(C.int(config.Scheduler)) - - // Convert reals array - cConfig.config.real_count = C.size_t(len(config.Reals)) - if len(config.Reals) > 0 { - cConfig.config.reals = (*C.struct_named_real_config)( - C.malloc( - C.size_t( - len(config.Reals), - ) * C.size_t( - unsafe.Sizeof(C.struct_named_real_config{}), - ), - ), - ) - cRealsSlice := unsafe.Slice(cConfig.config.reals, len(config.Reals)) - for i, real := range config.Reals { - cRealsSlice[i].real = goToCRelativeRealIdentifier(real.Identifier) - cRealsSlice[i].config.src = goToCNet(real.Src) - cRealsSlice[i].config.weight = C.uint16_t(real.Weight) - } - } else { - cConfig.config.reals = nil - } - - // Convert allowed sources - cConfig.config.allowed_src_count = C.size_t(len(config.AllowedSources)) - if len(config.AllowedSources) > 0 { - cConfig.config.allowed_src = (*C.struct_allowed_sources)( - C.malloc( - C.size_t( - len(config.AllowedSources), - ) * C.size_t( - unsafe.Sizeof(C.struct_allowed_sources{}), - ), - ), - ) - cAllowedSlice := unsafe.Slice( - cConfig.config.allowed_src, - len(config.AllowedSources), - ) - for i, allowedSrc := range config.AllowedSources { - // Convert networks array - cAllowedSlice[i].nets_count = C.size_t(len(allowedSrc.Nets)) - if len(allowedSrc.Nets) > 0 { - cAllowedSlice[i].nets = (*C.struct_net)( - C.malloc( - C.size_t( - len(allowedSrc.Nets), - ) * C.size_t( - unsafe.Sizeof(C.struct_net{}), - ), - ), - ) - cNetsSlice := unsafe.Slice( - cAllowedSlice[i].nets, - len(allowedSrc.Nets), - ) - for j, net := range allowedSrc.Nets { - cNetsSlice[j] = goToCNet(net) - } - } else { - cAllowedSlice[i].nets = nil - } - - // Convert port ranges - cAllowedSlice[i].port_ranges_count = C.size_t( - len(allowedSrc.PortRanges), - ) - if len(allowedSrc.PortRanges) > 0 { - cAllowedSlice[i].port_ranges = (*C.struct_ports_range)( - C.malloc( - C.size_t( - len(allowedSrc.PortRanges), - ) * C.size_t( - unsafe.Sizeof(C.struct_ports_range{}), - ), - ), - ) - cPortRangesSlice := unsafe.Slice( - cAllowedSlice[i].port_ranges, - len(allowedSrc.PortRanges), - ) - for j, portRange := range allowedSrc.PortRanges { - cPortRangesSlice[j].from = C.uint16_t(portRange.From) - cPortRangesSlice[j].to = C.uint16_t(portRange.To) - } - } else { - cAllowedSlice[i].port_ranges = nil - } - - // Set tag field - convert Go string to C *char - if allowedSrc.Tag != "" { - cAllowedSlice[i].tag = C.CString(allowedSrc.Tag) - } else { - cAllowedSlice[i].tag = nil - } - } - } else { - cConfig.config.allowed_src = nil - } - - // Convert IPv4 peers - cConfig.config.peers_v4_count = C.size_t(len(config.PeersV4)) - if len(config.PeersV4) > 0 { - cConfig.config.peers_v4 = (*C.struct_net4_addr)( - C.malloc(C.size_t(len(config.PeersV4)) * C.sizeof_struct_net4_addr), - ) - cPeersV4Slice := unsafe.Slice( - cConfig.config.peers_v4, - len(config.PeersV4), - ) - for i, peer := range config.PeersV4 { - v4 := peer.As4() - C.memcpy( - unsafe.Pointer(&cPeersV4Slice[i].bytes[0]), - unsafe.Pointer(&v4[0]), - 4, - ) - } - } else { - cConfig.config.peers_v4 = nil - } - - // Convert IPv6 peers - cConfig.config.peers_v6_count = C.size_t(len(config.PeersV6)) - if len(config.PeersV6) > 0 { - cConfig.config.peers_v6 = (*C.struct_net6_addr)( - C.malloc(C.size_t(len(config.PeersV6)) * C.sizeof_struct_net6_addr), - ) - cPeersV6Slice := unsafe.Slice( - cConfig.config.peers_v6, - len(config.PeersV6), - ) - for i, peer := range config.PeersV6 { - v6 := peer.As16() - C.memcpy( - unsafe.Pointer(&cPeersV6Slice[i].bytes[0]), - unsafe.Pointer(&v6[0]), - 16, - ) - } - } else { - cConfig.config.peers_v6 = nil - } - - return nil -} - -func freeCVsConfig(cConfig *C.struct_named_vs_config) { - if cConfig == nil { - return - } - - if cConfig.config.reals != nil { - C.free(unsafe.Pointer(cConfig.config.reals)) - } - if cConfig.config.allowed_src != nil { - // Free port ranges and nets within each allowed_src - cAllowedSlice := unsafe.Slice( - cConfig.config.allowed_src, - cConfig.config.allowed_src_count, - ) - for i := range cAllowedSlice { - if cAllowedSlice[i].nets != nil { - C.free(unsafe.Pointer(cAllowedSlice[i].nets)) - } - if cAllowedSlice[i].port_ranges != nil { - C.free(unsafe.Pointer(cAllowedSlice[i].port_ranges)) - } - // Free tag string if allocated - if cAllowedSlice[i].tag != nil { - C.free(unsafe.Pointer(cAllowedSlice[i].tag)) - } - } - C.free(unsafe.Pointer(cConfig.config.allowed_src)) - } - if cConfig.config.peers_v4 != nil { - C.free(unsafe.Pointer(cConfig.config.peers_v4)) - } - if cConfig.config.peers_v6 != nil { - C.free(unsafe.Pointer(cConfig.config.peers_v6)) - } -} - -func cToGoVsConfig(cConfig *C.struct_named_vs_config) *VsConfig { - if cConfig == nil { - return nil - } - - config := &VsConfig{ - Identifier: cToGoVsIdentifier(cConfig.identifier), - Scheduler: VsScheduler(cConfig.config.scheduler), - } - - // Convert flags from C bitfield - config.Flags.PureL3 = (cConfig.config.flags & C.VS_PURE_L3_FLAG) != 0 - config.Flags.FixMSS = (cConfig.config.flags & C.VS_FIX_MSS_FLAG) != 0 - config.Flags.GRE = (cConfig.config.flags & C.VS_GRE_FLAG) != 0 - config.Flags.OPS = (cConfig.config.flags & C.VS_OPS_FLAG) != 0 - - // Convert reals array - if cConfig.config.real_count > 0 && cConfig.config.reals != nil { - cRealsSlice := unsafe.Slice( - cConfig.config.reals, - cConfig.config.real_count, - ) - config.Reals = make([]RealConfig, cConfig.config.real_count) - for i := range config.Reals { - relative := cToGoRelativeRealIdentifier(cRealsSlice[i].real) - isV4 := relative.Addr.Is4() - config.Reals[i] = RealConfig{ - Identifier: relative, - Src: cToGoNet(cRealsSlice[i].config.src, isV4), - Weight: uint16(cRealsSlice[i].config.weight), - } - } - } else { - config.Reals = []RealConfig{} - } - - // Convert allowed sources - if cConfig.config.allowed_src_count > 0 && - cConfig.config.allowed_src != nil { - cAllowedSlice := unsafe.Slice( - cConfig.config.allowed_src, - cConfig.config.allowed_src_count, - ) - config.AllowedSources = make( - []AllowedSources, - cConfig.config.allowed_src_count, - ) - for i := range config.AllowedSources { - // Determine if IPv4 or IPv6 from the VS identifier address - isV4 := config.Identifier.Addr.Is4() - - // Convert networks array - if cAllowedSlice[i].nets_count > 0 && cAllowedSlice[i].nets != nil { - cNetsSlice := unsafe.Slice( - cAllowedSlice[i].nets, - cAllowedSlice[i].nets_count, - ) - config.AllowedSources[i].Nets = make( - []xnetip.NetWithMask, - cAllowedSlice[i].nets_count, - ) - for j := range config.AllowedSources[i].Nets { - net := cToGoNet(cNetsSlice[j], isV4) - config.AllowedSources[i].Nets[j] = net - } - } else { - config.AllowedSources[i].Nets = []xnetip.NetWithMask{} - } - - // Convert port ranges - if cAllowedSlice[i].port_ranges_count > 0 && - cAllowedSlice[i].port_ranges != nil { - cPortRangesSlice := unsafe.Slice( - cAllowedSlice[i].port_ranges, - cAllowedSlice[i].port_ranges_count, - ) - config.AllowedSources[i].PortRanges = make( - []PortRange, - cAllowedSlice[i].port_ranges_count, - ) - for j := range config.AllowedSources[i].PortRanges { - config.AllowedSources[i].PortRanges[j] = PortRange{ - From: uint16(cPortRangesSlice[j].from), - To: uint16(cPortRangesSlice[j].to), - } - } - } - - // Get tag field - convert C *char to Go string - if cAllowedSlice[i].tag != nil { - config.AllowedSources[i].Tag = C.GoString(cAllowedSlice[i].tag) - } else { - config.AllowedSources[i].Tag = "" - } - } - } else { - config.AllowedSources = []AllowedSources{} - } - - // Convert IPv4 peers - if cConfig.config.peers_v4_count > 0 && cConfig.config.peers_v4 != nil { - cPeersV4Slice := unsafe.Slice( - cConfig.config.peers_v4, - cConfig.config.peers_v4_count, - ) - config.PeersV4 = make([]netip.Addr, cConfig.config.peers_v4_count) - for i := range config.PeersV4 { - var v4 [4]byte - C.memcpy( - unsafe.Pointer(&v4[0]), - unsafe.Pointer(&cPeersV4Slice[i].bytes[0]), - 4, - ) - config.PeersV4[i] = netip.AddrFrom4(v4) - } - } else { - config.PeersV4 = []netip.Addr{} - } - - // Convert IPv6 peers - if cConfig.config.peers_v6_count > 0 && cConfig.config.peers_v6 != nil { - cPeersV6Slice := unsafe.Slice( - cConfig.config.peers_v6, - cConfig.config.peers_v6_count, - ) - config.PeersV6 = make([]netip.Addr, cConfig.config.peers_v6_count) - for i := range config.PeersV6 { - var v6 [16]byte - C.memcpy( - unsafe.Pointer(&v6[0]), - unsafe.Pointer(&cPeersV6Slice[i].bytes[0]), - 16, - ) - config.PeersV6[i] = netip.AddrFrom16(v6) - } - } else { - config.PeersV6 = []netip.Addr{} - } - - return config -} - -// BalancerInfo conversions - -func cToGoBalancerInfo(cInfo *C.struct_balancer_info) *BalancerInfo { - info := &BalancerInfo{ - ActiveSessions: uint64(cInfo.active_sessions), - LastPacketTimestamp: time.Unix(int64(cInfo.last_packet_timestamp), 0), - } - - // Convert VS info array - if cInfo.vs_count > 0 && cInfo.vs != nil { - cVsSlice := unsafe.Slice(cInfo.vs, cInfo.vs_count) - info.Vs = make([]VsInfo, cInfo.vs_count) - for i := range info.Vs { - info.Vs[i] = *cToGoVsInfo(&cVsSlice[i]) - } - } - - return info -} - -func cToGoVsInfo(cInfo *C.struct_named_vs_info) *VsInfo { - info := &VsInfo{ - Identifier: cToGoVsIdentifier(cInfo.identifier), - LastPacketTimestamp: time.Unix(int64(cInfo.last_packet_timestamp), 0), - ActiveSessions: uint64(cInfo.active_sessions), - } - - // Convert reals array - if cInfo.reals_count > 0 && cInfo.reals != nil { - cRealsSlice := unsafe.Slice(cInfo.reals, cInfo.reals_count) - info.Reals = make([]RealInfo, cInfo.reals_count) - for i := range info.Reals { - relative := cToGoRelativeRealIdentifier(cRealsSlice[i].real) - info.Reals[i] = RealInfo{ - Dst: relative.Addr, - LastPacketTimestamp: time.Unix( - int64(cRealsSlice[i].last_packet_timestamp), - 0, - ), - ActiveSessions: uint64(cRealsSlice[i].active_sessions), - } - } - } - - return info -} - -// Sessions conversions - -func cToGoSessions(cSessions *C.struct_sessions) *Sessions { - if cSessions == nil { - return nil - } - - sessions := &Sessions{} - - // Convert sessions array - if cSessions.sessions_count > 0 && cSessions.sessions != nil { - cSessionsSlice := unsafe.Slice( - cSessions.sessions, - cSessions.sessions_count, - ) - sessions.Sessions = make([]struct { - Identifier SessionIdentifier - Info SessionInfo - }, cSessions.sessions_count) - - for i := range sessions.Sessions { - sessions.Sessions[i].Identifier = cToGoSessionIdentifier( - &cSessionsSlice[i].identifier, - ) - sessions.Sessions[i].Info = cToGoSessionInfo( - &cSessionsSlice[i].info, - ) - } - } - - return sessions -} - -func cToGoSessionIdentifier( - cID *C.struct_session_identifier, -) SessionIdentifier { - realID := cToGoRealIdentifier(cID.real) - return SessionIdentifier{ - ClientIP: cToGoNetAddr( - cID.client_ip, - realID.VsIdentifier.Addr.Is4(), - ), - ClientPort: uint16(cID.client_port), - Real: realID, - } -} - -func cToGoSessionInfo(cInfo *C.struct_session_info) SessionInfo { - return SessionInfo{ - CreateTimestamp: time.Unix(int64(cInfo.create_timestamp), 0), - LastPacketTimestamp: time.Unix(int64(cInfo.last_packet_timestamp), 0), - Timeout: time.Duration(cInfo.timeout) * time.Second, - } -} - -// BalancerStats conversions - -func cToGoBalancerStats(cStats *C.struct_balancer_stats) *BalancerStats { - if cStats == nil { - return nil - } - - stats := &BalancerStats{ - L4: cToGoL4Stats(&cStats.l4), - IcmpIpv4: cToGoIcmpStats(&cStats.icmp_ipv4), - IcmpIpv6: cToGoIcmpStats(&cStats.icmp_ipv6), - Common: cToGoCommonStats(&cStats.common), - } - - // Convert VS stats array - if cStats.vs_count > 0 && cStats.vs != nil { - cVsSlice := unsafe.Slice(cStats.vs, cStats.vs_count) - stats.Vs = make([]NamedVsStats, cStats.vs_count) - for i := range stats.Vs { - stats.Vs[i] = *cToGoNamedVsStats(&cVsSlice[i]) - } - } - - return stats -} - -func cToGoL4Stats(cStats *C.struct_balancer_l4_stats) L4Stats { - return L4Stats{ - IncomingPackets: uint64(cStats.incoming_packets), - SelectVsFailed: uint64(cStats.select_vs_failed), - InvalidPackets: uint64(cStats.invalid_packets), - SelectRealFailed: uint64(cStats.select_real_failed), - OutgoingPackets: uint64(cStats.outgoing_packets), - } -} - -func cToGoIcmpStats(cStats *C.struct_balancer_icmp_stats) IcmpStats { - return IcmpStats{ - IncomingPackets: uint64(cStats.incoming_packets), - SrcNotAllowed: uint64(cStats.src_not_allowed), - EchoResponses: uint64(cStats.echo_responses), - PayloadTooShortIP: uint64(cStats.payload_too_short_ip), - UnmatchingSrcFromOriginal: uint64(cStats.unmatching_src_from_original), - PayloadTooShortPort: uint64(cStats.payload_too_short_port), - UnexpectedTransport: uint64(cStats.unexpected_transport), - UnrecognizedVs: uint64(cStats.unrecognized_vs), - ForwardedPackets: uint64(cStats.forwarded_packets), - BroadcastedPackets: uint64(cStats.broadcasted_packets), - PacketClonesSent: uint64(cStats.packet_clones_sent), - PacketClonesReceived: uint64(cStats.packet_clones_received), - PacketCloneFailures: uint64(cStats.packet_clone_failures), - } -} - -func cToGoCommonStats(cStats *C.struct_balancer_common_stats) CommonStats { - return CommonStats{ - IncomingPackets: uint64(cStats.incoming_packets), - IncomingBytes: uint64(cStats.incoming_bytes), - UnexpectedNetworkProto: uint64(cStats.unexpected_network_proto), - DecapSuccessful: uint64(cStats.decap_successful), - DecapFailed: uint64(cStats.decap_failed), - OutgoingPackets: uint64(cStats.outgoing_packets), - OutgoingBytes: uint64(cStats.outgoing_bytes), - } -} - -func cToGoNamedVsStats(cStats *C.struct_named_vs_stats) *NamedVsStats { - if cStats == nil { - return nil - } - - stats := &NamedVsStats{ - Identifier: cToGoVsIdentifier(cStats.identifier), - Stats: VsStats{ - IncomingPackets: uint64(cStats.stats.incoming_packets), - IncomingBytes: uint64(cStats.stats.incoming_bytes), - PacketSrcNotAllowed: uint64(cStats.stats.packet_src_not_allowed), - NoReals: uint64(cStats.stats.no_reals), - OpsPackets: uint64(cStats.stats.ops_packets), - SessionTableOverflow: uint64(cStats.stats.session_table_overflow), - EchoIcmpPackets: uint64(cStats.stats.echo_icmp_packets), - ErrorIcmpPackets: uint64(cStats.stats.error_icmp_packets), - RealIsDisabled: uint64(cStats.stats.real_is_disabled), - RealIsRemoved: uint64(cStats.stats.real_is_removed), - NotRescheduledPackets: uint64( - cStats.stats.not_rescheduled_packets, - ), - BroadcastedIcmpPackets: uint64( - cStats.stats.broadcasted_icmp_packets, - ), - CreatedSessions: uint64(cStats.stats.created_sessions), - OutgoingPackets: uint64(cStats.stats.outgoing_packets), - OutgoingBytes: uint64(cStats.stats.outgoing_bytes), - }, - } - - // Convert reals stats array - if cStats.reals_count > 0 { - cRealsSlice := unsafe.Slice(cStats.reals, cStats.reals_count) - stats.Reals = make([]struct { - Dst netip.Addr - Stats RealStats - }, cStats.reals_count) - - for i := range stats.Reals { - relative := cToGoRelativeRealIdentifier(cRealsSlice[i].real) - stats.Reals[i].Dst = relative.Addr - stats.Reals[i].Stats = RealStats{ - PacketsRealDisabled: uint64( - cRealsSlice[i].stats.packets_real_disabled, - ), - OpsPackets: uint64(cRealsSlice[i].stats.ops_packets), - ErrorIcmpPackets: uint64( - cRealsSlice[i].stats.error_icmp_packets, - ), - CreatedSessions: uint64( - cRealsSlice[i].stats.created_sessions, - ), - Packets: uint64(cRealsSlice[i].stats.packets), - Bytes: uint64(cRealsSlice[i].stats.bytes), - } - } - } - - // Convert allowed sources stats array - if cStats.allowed_sources_count > 0 && cStats.allowed_sources != nil { - cAllowedSourcesSlice := unsafe.Slice( - cStats.allowed_sources, - cStats.allowed_sources_count, - ) - stats.AllowedSources = make([]struct { - Tag string - Passes uint64 - }, cStats.allowed_sources_count) - - for i := range stats.AllowedSources { - // Convert C *char to Go string - if cAllowedSourcesSlice[i].tag != nil { - stats.AllowedSources[i].Tag = C.GoString(cAllowedSourcesSlice[i].tag) - } else { - stats.AllowedSources[i].Tag = "" - } - stats.AllowedSources[i].Passes = uint64( - cAllowedSourcesSlice[i].passes, - ) - } - } - - return stats -} - -// BalancerGraph conversions - -func cToGoBalancerGraph(cGraph *C.struct_balancer_graph) *BalancerGraph { - if cGraph == nil { - return nil - } - - graph := &BalancerGraph{} - - // Convert VS array - if cGraph.vs_count > 0 && cGraph.vs != nil { - cVsSlice := unsafe.Slice(cGraph.vs, cGraph.vs_count) - graph.VirtualServices = make([]GraphVs, cGraph.vs_count) - for i := range graph.VirtualServices { - graph.VirtualServices[i] = *cToGoGraphVs(&cVsSlice[i]) - } - } - - return graph -} - -func cToGoGraphVs(cVs *C.struct_graph_vs) *GraphVs { - vs := &GraphVs{ - Identifier: cToGoVsIdentifier(cVs.identifier), - } - - // Convert reals array - if cVs.real_count > 0 && cVs.reals != nil { - cRealsSlice := unsafe.Slice(cVs.reals, cVs.real_count) - vs.Reals = make([]GraphReal, cVs.real_count) - for i := range vs.Reals { - vs.Reals[i] = cToGoGraphReal(&cRealsSlice[i]) - } - } - - return vs -} - -func cToGoGraphReal(cReal *C.struct_graph_real) GraphReal { - return GraphReal{ - Identifier: cToGoRelativeRealIdentifier(cReal.identifier), - Weight: uint16(cReal.weight), - Enabled: bool(cReal.enabled), - } -} - -// UpdateInfo conversions - -func cToGoUpdateInfo(cInfo *C.struct_balancer_update_info) *UpdateInfo { - if cInfo == nil { - return nil - } - - info := &UpdateInfo{ - VsIpv4MatcherReused: cInfo.vs_ipv4_matcher_reused != 0, - VsIpv6MatcherReused: cInfo.vs_ipv6_matcher_reused != 0, - } - - // Convert ACL reused VS identifiers array - if cInfo.vs_acl_reused_count > 0 && cInfo.vs_acl_reused != nil { - cVsSlice := unsafe.Slice(cInfo.vs_acl_reused, cInfo.vs_acl_reused_count) - info.ACLReusedVs = make([]VsIdentifier, cInfo.vs_acl_reused_count) - for i := range info.ACLReusedVs { - info.ACLReusedVs[i] = cToGoVsIdentifier(cVsSlice[i]) - } - } else { - info.ACLReusedVs = []VsIdentifier{} - } - - return info -} - -// AgentInspect conversions - -func cToGoAgentInspect(cInspect *C.struct_agent_inspect) *AgentInspect { - if cInspect == nil { - return nil - } - - inspect := &AgentInspect{ - MemoryLimit: uint64(cInspect.memory_limit), - MemoryUsage: uint64(cInspect.memory_usage), - } - - // Convert balancers array - if cInspect.balancer_count > 0 && cInspect.balancers != nil { - cBalancersSlice := unsafe.Slice( - cInspect.balancers, - cInspect.balancer_count, - ) - inspect.Balancers = make( - []NamedBalancerInspect, - cInspect.balancer_count, - ) - for i := range inspect.Balancers { - inspect.Balancers[i] = *cToGoNamedBalancerInspect( - &cBalancersSlice[i], - ) - } - } - - return inspect -} - -func cToGoNamedBalancerInspect( - cInspect *C.struct_named_balancer_inspect, -) *NamedBalancerInspect { - return &NamedBalancerInspect{ - Name: C.GoString(cInspect.name), - Inspect: *cToGoBalancerInspect(&cInspect.inspect), - } -} - -func cToGoBalancerInspect( - cInspect *C.struct_balancer_inspect, -) *BalancerInspect { - return &BalancerInspect{ - PacketHandler: *cToGoPacketHandlerInspect( - &cInspect.packet_handler_inspect, - ), - State: *cToGoStateInspect(&cInspect.state_inspect), - OtherUsage: uint64(cInspect.other_usage), - TotalUsage: uint64(cInspect.total_usage), - } -} - -func cToGoPacketHandlerInspect( - cInspect *C.struct_packet_handler_inspect, -) *PacketHandlerInspect { - return &PacketHandlerInspect{ - VsIpv4Inspect: *cToGoPacketHandlerVsInspect( - &cInspect.vs_ipv4_inspect, - ), - VsIpv6Inspect: *cToGoPacketHandlerVsInspect( - &cInspect.vs_ipv6_inspect, - ), - SummaryVsUsage: uint64(cInspect.summary_vs_usage), - VsIndexUsage: uint64(cInspect.vs_index_usage), - RealsIndexUsage: uint64(cInspect.reals_index_usage), - CountersUsage: uint64(cInspect.counters_usage), - DecapUsage: uint64(cInspect.decap_usage), - TotalUsage: uint64(cInspect.total_usage), - } -} - -func cToGoPacketHandlerVsInspect( - cInspect *C.struct_packet_handler_vs_inspect, -) *PacketHandlerVsInspect { - inspect := &PacketHandlerVsInspect{ - MatcherUsage: uint64(cInspect.matcher_usage), - SummaryVsUsage: uint64(cInspect.summary_vs_usage), - AnnounceUsage: uint64(cInspect.announce_usage), - IndexUsage: uint64(cInspect.index_usage), - TotalUsage: uint64(cInspect.total_usage), - } - - // Convert VS inspects array - if cInspect.vs_count > 0 && cInspect.vs_inspects != nil { - cVsSlice := unsafe.Slice(cInspect.vs_inspects, cInspect.vs_count) - inspect.VsInspects = make([]NamedVsInspect, cInspect.vs_count) - for i := range inspect.VsInspects { - inspect.VsInspects[i] = *cToGoNamedVsInspect(&cVsSlice[i]) - } - } - - return inspect -} - -func cToGoNamedVsInspect(cInspect *C.struct_named_vs_inspect) *NamedVsInspect { - return &NamedVsInspect{ - Identifier: cToGoVsIdentifier(cInspect.identifier), - Inspect: *cToGoVsInspect(&cInspect.inspect), - } -} - -func cToGoVsInspect(cInspect *C.struct_vs_inspect) *VsInspect { - return &VsInspect{ - ACLUsage: uint64(cInspect.acl_usage), - RingUsage: uint64(cInspect.ring_usage), - CountersUsage: uint64(cInspect.counters_usage), - RealsUsage: *cToGoRealsUsage(&cInspect.reals_usage), - OtherUsage: uint64(cInspect.other_usage), - TotalUsage: uint64(cInspect.total_usage), - } -} - -func cToGoRealsUsage(cUsage *C.struct_reals_usage) *RealsUsage { - return &RealsUsage{ - CountersUsage: uint64(cUsage.counters_usage), - DataUsage: uint64(cUsage.data_usage), - TotalUsage: uint64(cUsage.total_usage), - } -} - -func cToGoStateInspect(cInspect *C.struct_state_inspect) *StateInspect { - return &StateInspect{ - SessionTableUsage: uint64(cInspect.session_table_usage), - TotalUsage: uint64(cInspect.total_usage), - } -} diff --git a/modules/balancer/agent/go/ffi/manager.go b/modules/balancer/agent/go/ffi/manager.go deleted file mode 100644 index 132b40bde..000000000 --- a/modules/balancer/agent/go/ffi/manager.go +++ /dev/null @@ -1,243 +0,0 @@ -package ffi - -// BalancerManager FFI wrapper providing C interop for manager operations including -// configuration updates, real server management, session table resizing, statistics -// collection, and graph topology retrieval with proper error handling. - -/* -#cgo CFLAGS: -I../../ -I../../../../../ -#cgo LDFLAGS: -L../../../../../build/modules/balancer/agent -lbalancer_agent -L../../../../../build/modules/balancer/controlplane/api -lbalancer_cp -L../../../../../build/modules/balancer/controlplane/handler -lbalancer_packet_handler -L../../../../../build/modules/balancer/controlplane/state -lbalancer_state -lbalancer_packet_handler -L../../../../../build/filter -lfilter_compiler -#include "manager.h" -#include "modules/balancer/controlplane/api/balancer.h" -#include -*/ -import "C" - -import ( - "fmt" - "time" - "unsafe" -) - -var ( - DontUpdateRealWeight uint16 = uint16(C.DONT_UPDATE_REAL_WEIGHT) - DontUpdateRealEnabled uint8 = uint8(C.DONT_UPDATE_REAL_ENABLED) - MaxRealWeight uint16 = uint16(C.MAX_REAL_WEIGHT) -) - -// BalancerManager wraps a C balancer_manager handle -type BalancerManager struct { - handle *C.struct_balancer_manager -} - -// Name returns the name of the balancer manager -func (m *BalancerManager) Name() string { - cName := C.balancer_manager_name(m.handle) - return C.GoString(cName) -} - -// Config retrieves the current configuration of the manager -func (m *BalancerManager) Config() *BalancerManagerConfig { - var cConfig C.struct_balancer_manager_config - C.balancer_manager_config(m.handle, &cConfig) - return cToGoBalancerManagerConfig(&cConfig) -} - -// Update updates the manager's configuration and returns update metadata -func (m *BalancerManager) Update( - config *BalancerManagerConfig, - now time.Time, -) (*UpdateInfo, error) { - if config == nil { - return nil, fmt.Errorf("config is nil") - } - - cConfig, err := goToCBalancerManagerConfig(config) - if err != nil { - return nil, fmt.Errorf("failed to convert config: %w", err) - } - defer freeCBalancerManagerConfig(cConfig) - - // Allocate C update_info structure - var cUpdateInfo C.struct_balancer_update_info - - cNow := C.uint32_t(now.Unix()) - - if C.balancer_manager_update(m.handle, cConfig, &cUpdateInfo, cNow) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return nil, fmt.Errorf("failed to perform update: %s", errMsg) - } - - // Convert C update_info to Go, copying all data - updateInfo := cToGoUpdateInfo(&cUpdateInfo) - - // Free C allocations from update_info - C.balancer_update_info_free(&cUpdateInfo) - - return updateInfo, nil -} - -// UpdateReals applies a batch of real server updates -func (m *BalancerManager) UpdateReals(updates []RealUpdate) error { - if len(updates) == 0 { - return nil - } - - // Convert Go updates to C updates - cUpdates := make([]C.struct_real_update, len(updates)) - for i, update := range updates { - cUpdates[i] = goToCRealUpdate(update) - } - - if C.balancer_manager_update_reals( - m.handle, - C.size_t(len(updates)), - &cUpdates[0], - ) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return fmt.Errorf("%s", errMsg) - } - - return nil -} - -// UpdateRealsWlc applies a batch of real server weight updates for WLC algorithm -// This method only updates state weights, not config weights, preserving the -// original static weights for WLC calculations -func (m *BalancerManager) UpdateRealsWlc(updates []RealUpdate) error { - if len(updates) == 0 { - return nil - } - - // Convert Go updates to C updates - cUpdates := make([]C.struct_real_update, len(updates)) - for i, update := range updates { - cUpdates[i] = goToCRealUpdate(update) - } - - if C.balancer_manager_update_reals_wlc( - m.handle, - C.size_t(len(updates)), - &cUpdates[0], - ) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return fmt.Errorf("%s", errMsg) - } - - return nil -} - -// ResizeSessionTable resizes the session table used by the manager's balancer -func (m *BalancerManager) ResizeSessionTable( - newSize uint, - now time.Time, -) error { - cNow := C.uint32_t(now.Unix()) - - if C.balancer_manager_resize_session_table( - m.handle, - C.size_t(newSize), - cNow, - ) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return fmt.Errorf("%s", errMsg) - } - - return nil -} - -// Info queries aggregated balancer information from the manager -func (m *BalancerManager) Info(now time.Time) (*BalancerInfo, error) { - var cInfo C.struct_balancer_info - cNow := C.uint32_t(now.Unix()) - - if C.balancer_manager_info(m.handle, &cInfo, cNow) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return nil, fmt.Errorf("%s", errMsg) - } - - // Convert C info to Go, copying all data - info := cToGoBalancerInfo(&cInfo) - - // Free C allocations - C.balancer_manager_info_free(&cInfo) - - return info, nil -} - -// ActiveSessions queries active session counters from the manager without a full info refresh. -func (m *BalancerManager) ActiveSessions() *BalancerInfo { - var cInfo C.struct_balancer_info - - C.balancer_manager_active_sessions(m.handle, &cInfo) - - return cToGoBalancerInfo(&cInfo) -} - -// Sessions enumerates active sessions tracked by the manager's balancer -func (m *BalancerManager) Sessions(now time.Time) *Sessions { - var cSessions C.struct_sessions - cNow := C.uint32_t(now.Unix()) - - C.balancer_manager_sessions(m.handle, &cSessions, cNow) - - // Convert C sessions to Go, copying all data - sessions := cToGoSessions(&cSessions) - - // Free C allocations - C.balancer_manager_sessions_free(&cSessions) - - return sessions -} - -// Stats reads balancer statistics from the manager -func (m *BalancerManager) Stats(ref *PacketHandlerRef) (*BalancerStats, error) { - if ref == nil { - return nil, fmt.Errorf("ref is nil") - } - - var cStats C.struct_balancer_stats - - cRef := goToCPacketHandlerRef(ref) - defer freeCPacketHandlerRef(cRef) - - if C.balancer_manager_stats(m.handle, &cStats, cRef) != 0 { - cErr := C.balancer_manager_take_error(m.handle) - errMsg := C.GoString(cErr) - C.free(unsafe.Pointer(cErr)) - return nil, fmt.Errorf("%s", errMsg) - } - - // Convert C stats to Go, copying all data - stats := cToGoBalancerStats(&cStats) - - // Free C allocations - C.balancer_manager_stats_free(&cStats) - - return stats, nil -} - -// Graph retrieves graph representation of the manager's balancer topology -func (m *BalancerManager) Graph() *BalancerGraph { - var cGraph C.struct_balancer_graph - - C.balancer_manager_graph(m.handle, &cGraph) - - // Convert C graph to Go, copying all data - graph := cToGoBalancerGraph(&cGraph) - - // Free C allocations - C.balancer_manager_graph_free(&cGraph) - - return graph -} diff --git a/modules/balancer/agent/go/ffi/manager_test.go b/modules/balancer/agent/go/ffi/manager_test.go deleted file mode 100644 index 7df5334d3..000000000 --- a/modules/balancer/agent/go/ffi/manager_test.go +++ /dev/null @@ -1,1247 +0,0 @@ -package ffi - -import ( - "net/netip" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xnetip" - "github.com/yanet-platform/yanet2/controlplane/ffi" - mock "github.com/yanet-platform/yanet2/mock/go" -) - -var ( - deviceName = "eth0" - pipelineName = "pipeline0" - functionName = "function0" - chainName = "chain0" - balancerName = "balancer0" -) - -func TestManager(t *testing.T) { - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 27, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: deviceName, - }, - }, - }) - require.NoError(t, err, "failed to create mock") - require.NotNil(t, m, "mock is nil") - - agent, err := NewBalancerAgent(m.SharedMemory(), 1<<25) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - managerConfig := BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 10, - TCPSyn: 20, - TCPFin: 15, - TCP: 100, - UDP: 11, - Default: 19, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.12.13.213"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: true}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.12.13.213"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.0.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.12.13.214"), - Port: 8080, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.1.0/24"), - ), - Weight: 150, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.12.13.215"), - Port: 8081, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.16.2.0/24"), - ), - Weight: 200, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.1.1.1/24"), - )}, - }, - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.12.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("12.1.1.3"), - netip.MustParseAddr("12.1.1.4"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::2"), - netip.MustParseAddr("2001:db8::3"), - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.20.30.40"), - Port: 443, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: false}, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.20.30.41"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.17.0.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.20.30.42"), - Port: 8443, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.17.1.0/24"), - ), - Weight: 100, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("192.168.0.0/16"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("12.2.2.3"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::10"), - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.50.60.70"), - Port: 53, - TransportProto: VsTransportProtoUDP, - }, - Flags: VsFlags{FixMSS: false}, - Scheduler: VsSchedulerSourceHash, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.50.60.71"), - Port: 5353, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.18.0.0/24"), - ), - Weight: 50, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.50.60.72"), - Port: 5353, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.18.1.0/24"), - ), - Weight: 75, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.50.60.73"), - Port: 5353, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.18.2.0/24"), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.50.60.74"), - Port: 5353, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.18.3.0/24"), - ), - Weight: 125, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr("10.50.60.75"), - Port: 5354, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix("172.18.4.0/24"), - ), - Weight: 150, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{xnetip.FromPrefix( - netip.MustParsePrefix("0.0.0.0/0"), - )}, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("12.3.3.3"), - netip.MustParseAddr("12.3.3.4"), - netip.MustParseAddr("12.3.3.5"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::20"), - netip.MustParseAddr("2001:db8::21"), - }, - }, - }, - SourceV4: netip.MustParseAddr("10.12.13.213"), - SourceV6: netip.MustParseAddr("2001:db8::1"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("10.13.11.215"), - netip.MustParseAddr("10.14.11.214"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::3"), - netip.MustParseAddr("2001:db8::2"), - }, - }, - State: StateConfig{ - TableCapacity: 1000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1024, - Vs: []uint32{0, 1, 2}, - }, - RefreshPeriod: time.Millisecond * 10, - MaxLoadFactor: 0.75, - } - - manager, err := agent.NewManager(balancerName, &managerConfig) - require.NoError(t, err, "failed to create balancer manager") - require.NotNil(t, manager, "balancer manager is nil") - - // Use mock's current time for all operations - now := m.CurrentTime() - - // Test 1: Verify manager was created successfully - t.Run("ManagerCreation", func(t *testing.T) { - require.NotNil(t, manager, "manager should not be nil") - - // Verify we can get the manager name - name := manager.Name() - require.Equal(t, balancerName, name, "manager name should match") - }) - - t.Run("SetupControlplane", func(t *testing.T) { - agent, err := m.SharedMemory().AgentReattach("bootstrap", 0, 1<<20) - require.NoError(t, err, "failed to attach bootstrap agent") - { - functionConfig := ffi.FunctionConfig{ - Name: functionName, - Chains: []ffi.FunctionChainConfig{ - { - Weight: 1, - Chain: ffi.ChainConfig{ - Name: chainName, - Modules: []ffi.ChainModuleConfig{ - { - Type: "balancer", - Name: balancerName, - }, - }, - }, - }, - }, - } - - if err := agent.UpdateFunction(functionConfig); err != nil { - t.Fatalf("failed to update functions: %v", err) - } - } - - // update pipelines - { - inputPipelineConfig := ffi.PipelineConfig{ - Name: pipelineName, - Functions: []string{functionName}, - } - - dummyPipelineConfig := ffi.PipelineConfig{ - Name: "dummy", - Functions: []string{}, - } - - if err := agent.UpdatePipeline(inputPipelineConfig); err != nil { - t.Fatalf("failed to update pipeline: %v", err) - } - - if err := agent.UpdatePipeline(dummyPipelineConfig); err != nil { - t.Fatalf("failed to update pipeline: %v", err) - } - } - - // update devices - { - deviceConfig := ffi.DeviceConfig{ - Name: deviceName, - Input: []ffi.DevicePipelineConfig{ - { - Name: pipelineName, - Weight: 1, - }, - }, - Output: []ffi.DevicePipelineConfig{ - { - Name: "dummy", - Weight: 1, - }, - }, - } - - if err := agent.UpdatePlainDevices([]ffi.DeviceConfig{deviceConfig}); err != nil { - t.Fatalf("failed to update pipelines: %v", err) - } - } - }) - - // Test 2: Get initial configuration - t.Run("GetInitialConfig", func(t *testing.T) { - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.Equal(t, managerConfig, *config, "config should match") - }) - - // Test 3: Get initial graph - t.Run("GetInitialGraph", func(t *testing.T) { - graph := manager.Graph() - require.NotNil(t, graph, "graph should not be nil") - require.Equal( - t, - 3, - len(graph.VirtualServices), - "graph should have 3 virtual services", - ) - - // Verify first VS has 3 reals - require.Equal( - t, - 3, - len(graph.VirtualServices[0].Reals), - "first VS should have 3 reals", - ) - // Verify second VS has 2 reals - require.Equal( - t, - 2, - len(graph.VirtualServices[1].Reals), - "second VS should have 2 reals", - ) - // Verify third VS has 5 reals - require.Equal( - t, - 5, - len(graph.VirtualServices[2].Reals), - "third VS should have 5 reals", - ) - - // Verify reals match config - match by identifier since order may differ - for _, configVs := range managerConfig.Balancer.Handler.VirtualServices { - // Find matching VS in graph by identifier - var graphVs *GraphVs - for i := range graph.VirtualServices { - if graph.VirtualServices[i].Identifier.Addr.Compare( - configVs.Identifier.Addr, - ) == 0 && - graph.VirtualServices[i].Identifier.Port == configVs.Identifier.Port && - graph.VirtualServices[i].Identifier.TransportProto == configVs.Identifier.TransportProto { - graphVs = &graph.VirtualServices[i] - break - } - } - require.NotNil(t, graphVs, "VS %s:%d should exist in graph", - configVs.Identifier.Addr, configVs.Identifier.Port) - - require.Equal( - t, - len(configVs.Reals), - len(graphVs.Reals), - "VS %s:%d should have same number of reals in graph as in config", - configVs.Identifier.Addr, - configVs.Identifier.Port, - ) - - for _, configReal := range configVs.Reals { - // Find matching real in graph by identifier - var graphReal *GraphReal - for i := range graphVs.Reals { - if graphVs.Reals[i].Identifier.Addr.Compare( - configReal.Identifier.Addr, - ) == 0 && - graphVs.Reals[i].Identifier.Port == configReal.Identifier.Port { - graphReal = &graphVs.Reals[i] - break - } - } - require.NotNil(t, graphReal, "Real %s:%d should exist in graph", - configReal.Identifier.Addr, configReal.Identifier.Port) - - // Check weight matches - require.Equal(t, configReal.Weight, graphReal.Weight, - "Real %s:%d weight should match config", - configReal.Identifier.Addr, configReal.Identifier.Port) - - // Check all reals are initially disabled - require.False(t, graphReal.Enabled, - "Real %s:%d should be initially disabled", - configReal.Identifier.Addr, configReal.Identifier.Port) - } - } - }) - - // Test 4: Get initial info - t.Run("GetInitialInfo", func(t *testing.T) { - info, err := manager.Info(now) - require.NoError(t, err, "failed to get info") - require.NotNil(t, info, "info should not be nil") - // Skip detailed checks if info is not properly populated (C code issue) - if len(info.Vs) == 0 { - t.Skip("Info not properly populated - skipping detailed checks") - return - } - require.Equal(t, 3, len(info.Vs), "info should have 3 virtual services") - - // Check info variables are zeroes initially - require.Equal( - t, - uint64(0), - info.ActiveSessions, - "active sessions should be zero initially", - ) - require.True( - t, - info.LastPacketTimestamp.IsZero() || - info.LastPacketTimestamp.Unix() == 0, - "last packet timestamp should be zero initially", - ) - - // Check info topology matches config topology - require.Equal( - t, - len(managerConfig.Balancer.Handler.VirtualServices), - len(info.Vs), - "info should have same number of virtual services as config", - ) - - for vsIdx, configVs := range managerConfig.Balancer.Handler.VirtualServices { - infoVs := info.Vs[vsIdx] - - // Check VS identifier matches - require.Equal(t, configVs.Identifier.Addr, infoVs.Identifier.Addr, - "VS %d address should match in info", vsIdx) - require.Equal(t, configVs.Identifier.Port, infoVs.Identifier.Port, - "VS %d port should match in info", vsIdx) - require.Equal( - t, - configVs.Identifier.TransportProto, - infoVs.Identifier.TransportProto, - "VS %d transport proto should match in info", - vsIdx, - ) - - // Check VS info variables are zeroes - require.Equal(t, uint64(0), infoVs.ActiveSessions, - "VS %d active sessions should be zero initially", vsIdx) - require.True( - t, - infoVs.LastPacketTimestamp.IsZero() || - infoVs.LastPacketTimestamp.Unix() == 0, - "VS %d last packet timestamp should be zero initially", - vsIdx, - ) - - // Check reals topology matches - require.Equal( - t, - len(configVs.Reals), - len(infoVs.Reals), - "VS %d should have same number of reals in info as in config", - vsIdx, - ) - - for realIdx, configReal := range configVs.Reals { - infoReal := infoVs.Reals[realIdx] - - // Check real identifier matches - require.Equal( - t, - configReal.Identifier.Addr, - infoReal.Dst, - "VS %d Real %d address should match in info", - vsIdx, - realIdx, - ) - - // Check real info variables are zeroes - require.Equal( - t, - uint64(0), - infoReal.ActiveSessions, - "VS %d Real %d active sessions should be zero initially", - vsIdx, - realIdx, - ) - require.True( - t, - infoReal.LastPacketTimestamp.IsZero() || - infoReal.LastPacketTimestamp.Unix() == 0, - "VS %d Real %d last packet timestamp should be zero initially", - vsIdx, - realIdx, - ) - } - } - }) - - ref := PacketHandlerRef{ - Device: &deviceName, - Pipeline: &pipelineName, - Function: &functionName, - Chain: &chainName, - } - - // Test 5: Get initial stats - t.Run("GetInitialStats", func(t *testing.T) { - stats, err := manager.Stats(&ref) - require.NoError(t, err, "failed to get stats") - require.NotNil(t, stats, "stats should not be nil") - require.Equal( - t, - 3, - len(stats.Vs), - "stats should have 3 virtual services", - ) - - // Check common stats are zeroes - require.Equal( - t, - uint64(0), - stats.Common.IncomingPackets, - "incoming packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.IncomingBytes, - "incoming bytes should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.OutgoingPackets, - "outgoing packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.OutgoingBytes, - "outgoing bytes should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.UnexpectedNetworkProto, - "unexpected network proto should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.DecapSuccessful, - "decap successful should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.DecapFailed, - "decap failed should be zero", - ) - - // Check L4 stats are zeroes - require.Equal( - t, - uint64(0), - stats.L4.IncomingPackets, - "L4 incoming packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.L4.SelectVsFailed, - "L4 select VS failed should be zero", - ) - require.Equal( - t, - uint64(0), - stats.L4.InvalidPackets, - "L4 invalid packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.L4.SelectRealFailed, - "L4 select real failed should be zero", - ) - require.Equal( - t, - uint64(0), - stats.L4.OutgoingPackets, - "L4 outgoing packets should be zero", - ) - - // Check ICMP stats are zeroes - require.Equal( - t, - uint64(0), - stats.IcmpIpv4.IncomingPackets, - "ICMP IPv4 incoming packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.IcmpIpv6.IncomingPackets, - "ICMP IPv6 incoming packets should be zero", - ) - - // Check stats topology matches config topology - require.Equal( - t, - len(managerConfig.Balancer.Handler.VirtualServices), - len(stats.Vs), - "stats should have same number of virtual services as config", - ) - - for vsIdx, configVs := range managerConfig.Balancer.Handler.VirtualServices { - statsVs := stats.Vs[vsIdx] - - // Check VS identifier matches - require.Equal(t, configVs.Identifier.Addr, statsVs.Identifier.Addr, - "VS %d address should match in stats", vsIdx) - require.Equal(t, configVs.Identifier.Port, statsVs.Identifier.Port, - "VS %d port should match in stats", vsIdx) - require.Equal( - t, - configVs.Identifier.TransportProto, - statsVs.Identifier.TransportProto, - "VS %d transport proto should match in stats", - vsIdx, - ) - - // Check VS stats are zeroes (skip bytes check as it may have uninitialized data) - require.Equal(t, uint64(0), statsVs.Stats.IncomingPackets, - "VS %d incoming packets should be zero", vsIdx) - // Note: IncomingBytes may not be zero due to uninitialized memory or actual data - require.Equal(t, uint64(0), statsVs.Stats.OutgoingPackets, - "VS %d outgoing packets should be zero", vsIdx) - require.Equal(t, uint64(0), statsVs.Stats.OutgoingBytes, - "VS %d outgoing bytes should be zero", vsIdx) - require.Equal(t, uint64(0), statsVs.Stats.CreatedSessions, - "VS %d created sessions should be zero", vsIdx) - - // Check reals topology matches - require.Equal( - t, - len(configVs.Reals), - len(statsVs.Reals), - "VS %d should have same number of reals in stats as in config", - vsIdx, - ) - - for realIdx, configReal := range configVs.Reals { - statsReal := statsVs.Reals[realIdx] - - // Check real identifier matches - require.Equal( - t, - configReal.Identifier.Addr, - statsReal.Dst, - "VS %d Real %d address should match in stats", - vsIdx, - realIdx, - ) - - // Check real stats are zeroes - require.Equal(t, uint64(0), statsReal.Stats.Packets, - "VS %d Real %d packets should be zero", vsIdx, realIdx) - require.Equal(t, uint64(0), statsReal.Stats.Bytes, - "VS %d Real %d bytes should be zero", vsIdx, realIdx) - require.Equal( - t, - uint64(0), - statsReal.Stats.CreatedSessions, - "VS %d Real %d created sessions should be zero", - vsIdx, - realIdx, - ) - } - } - }) - - // Test 6: Get initial sessions - t.Run("GetInitialSessions", func(t *testing.T) { - sessions := manager.Sessions(now) - require.NotNil(t, sessions, "sessions should not be nil") - // Initially should have no sessions - require.Equal( - t, - 0, - len(sessions.Sessions), - "should have no sessions initially", - ) - }) - - // Test 7: Update individual reals using UpdateReals - t.Run("UpdateRealsMethod", func(t *testing.T) { - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[0].Identifier, - }, - Weight: 250, - Enabled: 1, - }, - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[1].Identifier, - }, - Weight: 300, - Enabled: 1, - }, - } - - err := manager.UpdateReals(updates) - require.NoError(t, err, "failed to update reals") - - // Verify the updates - graph := manager.Graph() - require.Equal( - t, - uint16(250), - graph.VirtualServices[0].Reals[0].Weight, - "first real weight should be 250", - ) - require.Equal( - t, - uint16(300), - graph.VirtualServices[0].Reals[1].Weight, - "second real weight should be 300", - ) - - // Test updating only weight (enabled unchanged) - t.Run("UpdateWeightOnly", func(t *testing.T) { - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[0].Identifier, - }, - Weight: 350, - Enabled: DontUpdateRealEnabled, // Don't change enabled status - }, - } - - err := manager.UpdateReals(updates) - require.NoError(t, err, "failed to update real weight only") - - graph := manager.Graph() - require.Equal( - t, - uint16(350), - graph.VirtualServices[0].Reals[0].Weight, - "weight should be updated to 350", - ) - require.True(t, graph.VirtualServices[0].Reals[0].Enabled, - "enabled status should remain true") - }) - - // Test updating only enabled status (weight unchanged) - t.Run("UpdateEnabledOnly", func(t *testing.T) { - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[1].Identifier, - }, - Weight: DontUpdateRealWeight, // Don't change weight - Enabled: 0, // Disable the real - }, - } - - err := manager.UpdateReals(updates) - require.NoError(t, err, "failed to update real enabled only") - - graph := manager.Graph() - require.Equal( - t, - uint16(300), - graph.VirtualServices[0].Reals[1].Weight, - "weight should remain 300", - ) - require.False(t, graph.VirtualServices[0].Reals[1].Enabled, - "enabled status should be false") - }) - - // Test updating both weight and enabled - t.Run("UpdateBoth", func(t *testing.T) { - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[2].Identifier, - }, - Weight: 400, - Enabled: 1, - }, - } - - err := manager.UpdateReals(updates) - require.NoError(t, err, "failed to update real weight and enabled") - - graph := manager.Graph() - require.Equal( - t, - uint16(400), - graph.VirtualServices[0].Reals[2].Weight, - "weight should be updated to 400", - ) - require.True(t, graph.VirtualServices[0].Reals[2].Enabled, - "enabled status should be true") - }) - - // Test DontUpdateRealWeight constant - t.Run("DontUpdateRealWeightConstant", func(t *testing.T) { - // Get current weight - graphBefore := manager.Graph() - weightBefore := graphBefore.VirtualServices[0].Reals[0].Weight - - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[0].Identifier, - }, - Weight: DontUpdateRealWeight, - Enabled: 0, // Disable - }, - } - - err := manager.UpdateReals(updates) - require.NoError( - t, - err, - "failed to update with DontUpdateRealWeight", - ) - - graphAfter := manager.Graph() - require.Equal( - t, - weightBefore, - graphAfter.VirtualServices[0].Reals[0].Weight, - "weight should not change when using DontUpdateRealWeight", - ) - require.False(t, graphAfter.VirtualServices[0].Reals[0].Enabled, - "enabled status should be updated") - }) - - // Test DontUpdateRealEnabled constant - t.Run("DontUpdateRealEnabledConstant", func(t *testing.T) { - // Get current enabled status - graphBefore := manager.Graph() - enabledBefore := graphBefore.VirtualServices[0].Reals[0].Enabled - - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[0].Identifier, - }, - Weight: 500, - Enabled: DontUpdateRealEnabled, - }, - } - - err := manager.UpdateReals(updates) - require.NoError( - t, - err, - "failed to update with DontUpdateRealEnabled", - ) - - graphAfter := manager.Graph() - require.Equal( - t, - uint16(500), - graphAfter.VirtualServices[0].Reals[0].Weight, - "weight should be updated", - ) - require.Equal( - t, - enabledBefore, - graphAfter.VirtualServices[0].Reals[0].Enabled, - "enabled status should not change when using DontUpdateRealEnabled", - ) - }) - }) - - // Test 8: Resize session table - t.Run("ResizeSessionTable", func(t *testing.T) { - err := manager.ResizeSessionTable(1050, now) - require.NoError(t, err, "failed to resize session table") - - // Verify the resize - config := manager.Config() - require.GreaterOrEqual( - t, - config.Balancer.State.TableCapacity, - uint(1050), - "table capacity should be at least 1050", - ) - }) - - // Test updating manager with completely new config - t.Run("UpdateWithNewConfig", func(t *testing.T) { - // Create a completely new configuration - newConfig := BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - SessionsTimeouts: SessionsTimeouts{ - TCPSynAck: 15, - TCPSyn: 25, - TCPFin: 20, - TCP: 100, - UDP: 15, - Default: 25, - }, - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.100", - ), - Port: 8080, - TransportProto: VsTransportProtoTCP, - }, - Flags: VsFlags{FixMSS: false}, - Scheduler: VsSchedulerRoundRobin, - Reals: []RealConfig{ - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.101", - ), - Port: 9090, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "10.0.0.0/24", - ), - ), - Weight: 100, - }, - { - Identifier: RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.102", - ), - Port: 9090, - }, - Src: xnetip.FromPrefix( - netip.MustParsePrefix( - "10.0.1.0/24", - ), - ), - Weight: 200, - }, - }, - AllowedSources: []AllowedSources{ - { - Nets: []xnetip.NetWithMask{ - xnetip.FromPrefix( - netip.MustParsePrefix("0.0.0.0/0"), - ), - }, - }, - }, - PeersV4: []netip.Addr{ - netip.MustParseAddr("192.168.2.1"), - }, - PeersV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::200"), - }, - }, - }, - SourceV4: netip.MustParseAddr("192.168.1.1"), - SourceV6: netip.MustParseAddr("2001:db8::100"), - DecapV4: []netip.Addr{ - netip.MustParseAddr("192.168.3.1"), - }, - DecapV6: []netip.Addr{ - netip.MustParseAddr("2001:db8::300"), - }, - }, - State: StateConfig{ - TableCapacity: 2000, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 15, - MaxRealWeight: 2048, - Vs: []uint32{0}, - }, - RefreshPeriod: time.Millisecond * 20, - MaxLoadFactor: 0.8, - } - - _, err := manager.Update(&newConfig, now) - require.NoError(t, err, "failed to update manager with new config") - - // Verify the new config is applied - updatedConfig := manager.Config() - require.NotNil(t, updatedConfig, "updated config should not be nil") - require.Equal(t, 1, len(updatedConfig.Balancer.Handler.VirtualServices), - "should have 1 virtual service after update") - require.Equal( - t, - newConfig.Balancer.Handler.VirtualServices[0].Identifier.Addr, - updatedConfig.Balancer.Handler.VirtualServices[0].Identifier.Addr, - "VS address should match new config", - ) - - // Verify graph reflects new config - graph := manager.Graph() - require.NotNil(t, graph, "graph should not be nil") - require.Equal( - t, - 1, - len(graph.VirtualServices), - "graph should have 1 virtual service", - ) - require.Equal( - t, - 2, - len(graph.VirtualServices[0].Reals), - "VS should have 2 reals", - ) - - // Verify all reals are initially disabled after config update - // Match by identifier since order may differ - for _, configReal := range newConfig.Balancer.Handler.VirtualServices[0].Reals { - var graphReal *GraphReal - for i := range graph.VirtualServices[0].Reals { - if graph.VirtualServices[0].Reals[i].Identifier.Addr.Compare( - configReal.Identifier.Addr, - ) == 0 && - graph.VirtualServices[0].Reals[i].Identifier.Port == configReal.Identifier.Port { - graphReal = &graph.VirtualServices[0].Reals[i] - break - } - } - require.NotNil(t, graphReal, "Real %s:%d should exist in graph", - configReal.Identifier.Addr, configReal.Identifier.Port) - require.False( - t, - graphReal.Enabled, - "Real %s:%d should be disabled after config update", - configReal.Identifier.Addr, - configReal.Identifier.Port, - ) - require.Equal(t, configReal.Weight, graphReal.Weight, - "Real %s:%d weight should match new config", - configReal.Identifier.Addr, configReal.Identifier.Port) - } - - // Verify info reflects new topology - info, err := manager.Info(now) - require.NoError(t, err, "failed to get info after update") - require.NotNil(t, info, "info should not be nil") - require.Equal(t, 1, len(info.Vs), "info should have 1 virtual service") - require.Equal( - t, - 2, - len(info.Vs[0].Reals), - "info VS should have 2 reals", - ) - - // Verify stats reflects new topology - stats, err := manager.Stats(&ref) - require.NoError(t, err, "failed to get stats after update") - require.NotNil(t, stats, "stats should not be nil") - require.Equal( - t, - 1, - len(stats.Vs), - "stats should have 1 virtual service", - ) - require.Equal( - t, - 2, - len(stats.Vs[0].Reals), - "stats VS should have 2 reals", - ) - - // Verify sessions (should still be empty or reset) - sessions := manager.Sessions(now) - require.NotNil(t, sessions, "sessions should not be nil") - }) - - // Test updating manager with same reals/VS but in different order - t.Run("UpdateWithReorderedConfig", func(t *testing.T) { - // First, restore the original config since previous test changed it - _, err = manager.Update(&managerConfig, now) - require.NoError(t, err, "failed to restore original config") - - // Now enable some reals - updates := []RealUpdate{ - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[0].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[0].Reals[0].Identifier, - }, - Weight: 100, - Enabled: 1, - }, - { - Identifier: RealIdentifier{ - VsIdentifier: managerConfig.Balancer.Handler.VirtualServices[1].Identifier, - Relative: managerConfig.Balancer.Handler.VirtualServices[1].Reals[0].Identifier, - }, - Weight: 100, - Enabled: 1, - }, - } - - err := manager.UpdateReals(updates) - require.NoError(t, err, "failed to enable reals") - - // Verify reals are enabled - graphBefore := manager.Graph() - require.True(t, graphBefore.VirtualServices[0].Reals[0].Enabled, - "first VS first real should be enabled") - require.True(t, graphBefore.VirtualServices[1].Reals[0].Enabled, - "second VS first real should be enabled") - - // Create a config with same VS and reals but in different order - reorderedConfig := managerConfig - // Swap the first two virtual services - reorderedConfig.Balancer.Handler.VirtualServices = []VsConfig{ - managerConfig.Balancer.Handler.VirtualServices[1], // Second VS first - managerConfig.Balancer.Handler.VirtualServices[0], // First VS second - managerConfig.Balancer.Handler.VirtualServices[2], // Third VS unchanged - } - // Also reverse the order of reals in the first VS (which is now at index 0) - reorderedConfig.Balancer.Handler.VirtualServices[0].Reals = []RealConfig{ - managerConfig.Balancer.Handler.VirtualServices[1].Reals[1], - managerConfig.Balancer.Handler.VirtualServices[1].Reals[0], - } - - _, err = manager.Update(&reorderedConfig, now) - require.NoError( - t, - err, - "failed to update manager with reordered config", - ) - - // Verify the config was updated - updatedConfig := manager.Config() - require.Equal(t, 3, len(updatedConfig.Balancer.Handler.VirtualServices), - "should still have 3 virtual services") - - // Get the graph after reordering - graphAfter := manager.Graph() - - // Find the reals that were enabled before and verify they're still enabled - // The real at 10.20.30.41:8443 (from original VS[1].Reals[0]) should still be enabled - // It's now at VS[0].Reals[1] in the reordered config - foundEnabledReal1 := false - for _, vs := range graphAfter.VirtualServices { - for _, real := range vs.Reals { - if real.Identifier.Addr.String() == "10.20.30.41" && - real.Identifier.Port == 8443 { - require.True( - t, - real.Enabled, - "previously enabled real 10.20.30.41:8443 should still be enabled after reordering", - ) - foundEnabledReal1 = true - } - } - } - require.True( - t, - foundEnabledReal1, - "should find the previously enabled real 10.20.30.41:8443", - ) - - // The real at 10.12.13.213:8080 (from original VS[0].Reals[0]) should still be enabled - foundEnabledReal2 := false - for _, vs := range graphAfter.VirtualServices { - for _, real := range vs.Reals { - if real.Identifier.Addr.String() == "10.12.13.213" && - real.Identifier.Port == 8080 { - require.True( - t, - real.Enabled, - "previously enabled real 10.12.13.213:8080 should still be enabled after reordering", - ) - foundEnabledReal2 = true - } - } - } - require.True( - t, - foundEnabledReal2, - "should find the previously enabled real 10.12.13.213:8080", - ) - }) -} diff --git a/modules/balancer/agent/go/ffi/types.go b/modules/balancer/agent/go/ffi/types.go deleted file mode 100644 index 3a7abdf84..000000000 --- a/modules/balancer/agent/go/ffi/types.go +++ /dev/null @@ -1,403 +0,0 @@ -package ffi - -// Go type definitions for balancer FFI operations, defining structures for virtual services, -// real servers, sessions, statistics, and configuration with support for IPv4/IPv6, -// TCP/UDP protocols, and various load balancing algorithms. - -import ( - "net/netip" - "time" - - "github.com/yanet-platform/yanet2/common/go/xnetip" -) - -// AddrRange represents a range of IP addresses -type AddrRange struct { - From netip.Addr - To netip.Addr -} - -// PortRange represents a range of port numbers -type PortRange struct { - From uint16 - To uint16 -} - -// AllowedSources represents an allowed source with network and optional port ranges -type AllowedSources struct { - Nets []xnetip.NetWithMask // Network with address and arbitrary mask - PortRanges []PortRange // Optional port ranges - Tag string // Tag for identification/filtering (empty string means no tracking) -} - -// VsScheduler represents the scheduling algorithm for a virtual service -type VsScheduler uint32 - -const ( - VsSchedulerSourceHash VsScheduler = 0 // source_hash - VsSchedulerRoundRobin VsScheduler = 1 // round_robin -) - -type VsTransportProto uint32 - -const ( - VsTransportProtoTCP VsTransportProto = 0 // IPPROTO_TCP - VsTransportProtoUDP VsTransportProto = 1 // IPPROTO_UDP -) - -func (proto *VsTransportProto) String() string { - switch *proto { - case VsTransportProtoTCP: - return "TCP" - case VsTransportProtoUDP: - return "UDP" - default: - return "unknown" - } -} - -// VsFlags represents flags for virtual service configuration -type VsFlags struct { - PureL3 bool // VS_PURE_L3_FLAG - serve all ports - FixMSS bool // VS_FIX_MSS_FLAG - fix MSS TCP option - GRE bool // VS_GRE_FLAG - use GRE tunneling - OPS bool // VS_OPS_FLAG - One Packet Scheduling (disable sessions) -} - -// VsIdentifier uniquely identifies a virtual service -type VsIdentifier struct { - Addr netip.Addr // Virtual service address - Port uint16 // Destination port (0 if PureL3) - TransportProto VsTransportProto // TCP or UDP -} - -func (id *VsIdentifier) String() string { - return netip.AddrPortFrom(id.Addr, id.Port).String() + "/" + id.TransportProto.String() -} - -// RelativeRealIdentifier identifies a real server relative to its VS -type RelativeRealIdentifier struct { - Addr netip.Addr // Real endpoint address - Port uint16 // Destination port on the real -} - -func (id *RelativeRealIdentifier) String() string { - return netip.AddrPortFrom(id.Addr, id.Port).String() -} - -// RealIdentifier uniquely identifies a real server within a virtual service -type RealIdentifier struct { - VsIdentifier VsIdentifier - Relative RelativeRealIdentifier -} - -// RealConfig contains static configuration for a real server -type RealConfig struct { - Identifier RelativeRealIdentifier // Relative identifier (within VS context) - Src xnetip.NetWithMask // Source network/addresses for this real (supports arbitrary masks) - Weight uint16 // Scheduler weight [0..MAX_REAL_WEIGHT] -} - -// RealUpdate represents a partial update for a real server -type RealUpdate struct { - Identifier RealIdentifier - Weight uint16 // New weight (DONT_UPDATE_REAL_WEIGHT to skip) - Enabled uint8 // 0=disabled, non-zero=enabled (DONT_UPDATE_REAL_ENABLED to skip) -} - -// RealStats contains statistics for a real server -type RealStats struct { - PacketsRealDisabled uint64 // Packets while real was disabled - OpsPackets uint64 // One-Packet Scheduling packets - ErrorIcmpPackets uint64 // ICMP error packets - CreatedSessions uint64 // Sessions created with this real - Packets uint64 // Total packets sent to real - Bytes uint64 // Total bytes sent to real -} - -// RealInfo contains runtime information about a real server -type RealInfo struct { - Dst netip.Addr // Real destination address - LastPacketTimestamp time.Time // Last packet time observed - ActiveSessions uint64 // Active sessions to this real -} - -// VsConfig contains static configuration for a virtual service -type VsConfig struct { - Identifier VsIdentifier - Flags VsFlags - Scheduler VsScheduler - Reals []RealConfig - AllowedSources []AllowedSources // Client source allowlist with networks and optional port ranges - PeersV4 []netip.Addr // IPv4 peer balancers for ICMP - PeersV6 []netip.Addr // IPv6 peer balancers for ICMP -} - -// VsStats contains per-virtual-service runtime counters -type VsStats struct { - IncomingPackets uint64 // Packets received for this VS - IncomingBytes uint64 // Bytes received for this VS - PacketSrcNotAllowed uint64 // Dropped due to disallowed source - NoReals uint64 // Failed real selection (all disabled) - OpsPackets uint64 // OPS packets sent without session - SessionTableOverflow uint64 // Failed to create session - EchoIcmpPackets uint64 // ICMP echo packets processed - ErrorIcmpPackets uint64 // ICMP error packets forwarded - RealIsDisabled uint64 // Session exists but real disabled - RealIsRemoved uint64 // Session exists but real removed - NotRescheduledPackets uint64 // No session and packet doesn't start one - BroadcastedIcmpPackets uint64 // ICMP broadcasted to peers - CreatedSessions uint64 // Sessions created for this VS - OutgoingPackets uint64 // Packets sent to selected real - OutgoingBytes uint64 // Bytes sent to selected real -} - -// VsInfo contains runtime information about a virtual service -type VsInfo struct { - Identifier VsIdentifier - LastPacketTimestamp time.Time - ActiveSessions uint64 - Reals []RealInfo -} - -// NamedVsStats pairs a VS identifier with its statistics -type NamedVsStats struct { - Identifier VsIdentifier - Stats VsStats - Reals []struct { - Dst netip.Addr - Stats RealStats - } - AllowedSources []struct { - Tag string - Passes uint64 - } -} - -// SessionsTimeouts contains timeout configuration per transport/state -type SessionsTimeouts struct { - TCPSynAck uint32 // Timeout for TCP SYN-ACK sessions (seconds) - TCPSyn uint32 // Timeout for TCP SYN sessions (seconds) - TCPFin uint32 // Timeout for TCP FIN sessions (seconds) - TCP uint32 // Default timeout for TCP packets (seconds) - UDP uint32 // Default timeout for UDP packets (seconds) - Default uint32 // Fallback timeout for other packets (seconds) -} - -// SessionIdentifier uniquely identifies a session -type SessionIdentifier struct { - ClientIP netip.Addr // Client source IP - ClientPort uint16 // Client source port - Real RealIdentifier // Selected real endpoint -} - -// SessionInfo contains runtime session metadata -type SessionInfo struct { - CreateTimestamp time.Time // Session creation time - LastPacketTimestamp time.Time // Last packet time observed - Timeout time.Duration // Current timeout applied (seconds) -} - -// Sessions contains a list of active sessions -type Sessions struct { - Sessions []struct { - Identifier SessionIdentifier - Info SessionInfo - } -} - -// PacketHandlerConfig defines runtime parameters for session handling -type PacketHandlerConfig struct { - SessionsTimeouts SessionsTimeouts - VirtualServices []VsConfig - SourceV4 netip.Addr // IPv4 source for generated packets - SourceV6 netip.Addr // IPv6 source for generated packets - DecapV4 []netip.Addr // IPv4 addresses to decapsulate - DecapV6 []netip.Addr // IPv6 addresses to decapsulate -} - -// PacketHandlerRef optionally narrows statistics to a specific handler -type PacketHandlerRef struct { - Device *string // Optional device name - Pipeline *string // Optional pipeline name - Function *string // Optional function name - Chain *string // Optional chain name -} - -// StateConfig contains session table sizing configuration -type StateConfig struct { - TableCapacity uint // Number of session table entries -} - -// BalancerConfig combines packet handler and state configuration -type BalancerConfig struct { - Handler PacketHandlerConfig - State StateConfig -} - -// L4Stats contains module counters for L4 packets -type L4Stats struct { - IncomingPackets uint64 // L4 packets received - SelectVsFailed uint64 // Failed to select virtual service - InvalidPackets uint64 // Invalid or malformed packets - SelectRealFailed uint64 // Failed to select a real - OutgoingPackets uint64 // Packets sent to selected real -} - -// IcmpStats contains counters for ICMP packets -type IcmpStats struct { - IncomingPackets uint64 // ICMP packets received - SrcNotAllowed uint64 // Source not allowed by VS policy - EchoResponses uint64 // Echo replies generated - PayloadTooShortIP uint64 // Payload too short for IP header - UnmatchingSrcFromOriginal uint64 // Original src doesn't match dst - PayloadTooShortPort uint64 // Payload too short for ports - UnexpectedTransport uint64 // Original transport not TCP/UDP - UnrecognizedVs uint64 // Destination not recognized as VS - ForwardedPackets uint64 // ICMP forwarded to real - BroadcastedPackets uint64 // ICMP broadcasts sent to peers - PacketClonesSent uint64 // Packet clones created/sent - PacketClonesReceived uint64 // Packet clones received - PacketCloneFailures uint64 // Failures creating packet clone -} - -// CommonStats contains total incoming/outgoing packet counts -type CommonStats struct { - IncomingPackets uint64 // Total incoming packets - IncomingBytes uint64 // Total incoming bytes - UnexpectedNetworkProto uint64 // Unsupported network protocol - DecapSuccessful uint64 // Packets successfully decapsulated - DecapFailed uint64 // Packets that failed decapsulation - OutgoingPackets uint64 // Total outgoing packets - OutgoingBytes uint64 // Total outgoing bytes -} - -// BalancerStats contains aggregated statistics for the balancer -type BalancerStats struct { - L4 L4Stats - IcmpIpv4 IcmpStats - IcmpIpv6 IcmpStats - Common CommonStats - Vs []NamedVsStats -} - -// BalancerInfo contains aggregated information about a balancer instance -type BalancerInfo struct { - ActiveSessions uint64 - LastPacketTimestamp time.Time - Vs []VsInfo -} - -// GraphReal represents a real server in the graph topology -type GraphReal struct { - Identifier RelativeRealIdentifier - Weight uint16 - Enabled bool -} - -// GraphVs represents a virtual service in the graph topology -type GraphVs struct { - Identifier VsIdentifier - Reals []GraphReal -} - -// BalancerGraph represents the topology of VS to Reals relationships -type BalancerGraph struct { - VirtualServices []GraphVs -} - -// BalancerManagerWlcConfig contains WLC algorithm configuration -type BalancerManagerWlcConfig struct { - Power uint // Power factor for weight calculations - MaxRealWeight uint // Maximum weight value for any real - Vs []uint32 // Array of virtual service IDs -} - -// BalancerManagerConfig contains complete manager configuration -type BalancerManagerConfig struct { - Balancer BalancerConfig - Wlc BalancerManagerWlcConfig - RefreshPeriod time.Duration // Refresh interval - MaxLoadFactor float32 // Maximum load factor (0.0 to 1.0) -} - -// UpdateInfo contains metadata about what was reused during a balancer update -type UpdateInfo struct { - VsIpv4MatcherReused bool // Whether IPv4 VS matcher was reused - VsIpv6MatcherReused bool // Whether IPv6 VS matcher was reused - ACLReusedVs []VsIdentifier // VS identifiers for which ACL was reused -} - -// RealsUsage contains memory usage for real servers within a VS -type RealsUsage struct { - CountersUsage uint64 - DataUsage uint64 - TotalUsage uint64 -} - -// VsInspect contains memory usage for a single virtual service -type VsInspect struct { - ACLUsage uint64 - RingUsage uint64 - CountersUsage uint64 - RealsUsage RealsUsage - OtherUsage uint64 - TotalUsage uint64 -} - -// NamedVsInspect pairs a VS identifier with its memory inspection -type NamedVsInspect struct { - Identifier VsIdentifier - Inspect VsInspect -} - -// PacketHandlerVsInspect contains memory usage for IPv4 or IPv6 packet handler VS section -type PacketHandlerVsInspect struct { - MatcherUsage uint64 - SummaryVsUsage uint64 - VsInspects []NamedVsInspect - AnnounceUsage uint64 - IndexUsage uint64 - TotalUsage uint64 -} - -// PacketHandlerInspect contains complete packet handler memory usage -type PacketHandlerInspect struct { - VsIpv4Inspect PacketHandlerVsInspect - VsIpv6Inspect PacketHandlerVsInspect - SummaryVsUsage uint64 - VsIndexUsage uint64 - RealsIndexUsage uint64 - CountersUsage uint64 - DecapUsage uint64 - TotalUsage uint64 -} - -// StateInspect contains state memory usage -type StateInspect struct { - VsRegistryUsage uint64 - RealsRegistryUsage uint64 - SessionTableUsage uint64 - TotalUsage uint64 -} - -// BalancerInspect contains per-balancer memory inspection -type BalancerInspect struct { - PacketHandler PacketHandlerInspect - State StateInspect - OtherUsage uint64 - TotalUsage uint64 -} - -// NamedBalancerInspect pairs a balancer name with its memory inspection -type NamedBalancerInspect struct { - Name string - Inspect BalancerInspect -} - -// AgentInspect contains agent-level memory inspection -type AgentInspect struct { - MemoryLimit uint64 - MemoryUsage uint64 - Balancers []NamedBalancerInspect -} diff --git a/modules/balancer/agent/go/ffi/update_vs_test.go b/modules/balancer/agent/go/ffi/update_vs_test.go deleted file mode 100644 index 68a8230a6..000000000 --- a/modules/balancer/agent/go/ffi/update_vs_test.go +++ /dev/null @@ -1,506 +0,0 @@ -package ffi - -import ( - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestUpdateVsWLCIndicesInConfig tests that WLC indices are correctly set in BalancerManagerConfig -func TestUpdateVsWLCIndicesInConfig(t *testing.T) { - t.Run("NoWLCEnabled", func(t *testing.T) { - config := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{}, // No WLC indices - }, - } - - assert.Empty( - t, - config.Wlc.Vs, - "WLC indices should be empty when no VS has WLC enabled", - ) - }) - - t.Run("SingleWLCEnabled", func(t *testing.T) { - config := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{1}, // VS at index 1 has WLC - }, - } - - require.Len(t, config.Wlc.Vs, 1, "WLC indices should contain 1 entry") - assert.Equal( - t, - uint32(1), - config.Wlc.Vs[0], - "WLC index should be 1 (second VS)", - ) - }) - - t.Run("MultipleWLCEnabled", func(t *testing.T) { - config := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.4"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{ - 0, - 2, - 3, - }, // VS at indices 0, 2, 3 have WLC - }, - } - - require.Len(t, config.Wlc.Vs, 3, "WLC indices should contain 3 entries") - assert.Equal( - t, - uint32(0), - config.Wlc.Vs[0], - "First WLC index should be 0", - ) - assert.Equal( - t, - uint32(2), - config.Wlc.Vs[1], - "Second WLC index should be 2", - ) - assert.Equal( - t, - uint32(3), - config.Wlc.Vs[2], - "Third WLC index should be 3", - ) - }) - - t.Run("AllWLCEnabled", func(t *testing.T) { - config := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 1, 2}, // All VS have WLC - }, - } - - require.Len(t, config.Wlc.Vs, 3, "WLC indices should contain 3 entries") - assert.Equal( - t, - uint32(0), - config.Wlc.Vs[0], - "First WLC index should be 0", - ) - assert.Equal( - t, - uint32(1), - config.Wlc.Vs[1], - "Second WLC index should be 1", - ) - assert.Equal( - t, - uint32(2), - config.Wlc.Vs[2], - "Third WLC index should be 2", - ) - }) -} - -// TestUpdateVsWLCIndicesRecalculationScenarios tests WLC index recalculation scenarios -func TestUpdateVsWLCIndicesRecalculationScenarios(t *testing.T) { - t.Run("AddWLCEnabledVS", func(t *testing.T) { - // Initial config: VS0 (WLC=true), VS1 (WLC=false) - initialConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, - }, - } - - // After adding VS2 with WLC=true - updatedConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: append( - initialConfig.Balancer.Handler.VirtualServices, - VsConfig{ - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - ), - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 2}, // VS0 and VS2 have WLC - }, - } - - require.Len( - t, - updatedConfig.Wlc.Vs, - 2, - "Should have 2 WLC indices after adding WLC-enabled VS", - ) - assert.Equal( - t, - uint32(0), - updatedConfig.Wlc.Vs[0], - "First WLC index should be 0", - ) - assert.Equal( - t, - uint32(2), - updatedConfig.Wlc.Vs[1], - "Second WLC index should be 2", - ) - }) - - t.Run("RemoveWLCEnabledVS", func(t *testing.T) { - // Initial config: VS0 (WLC=true), VS1 (WLC=true), VS2 (WLC=false) - initialConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 1}, - }, - } - - // After removing VS1 (middle VS with WLC=true) - // New list: VS0 (WLC=true), VS2 (WLC=false) -> indices shift - updatedConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - initialConfig.Balancer.Handler.VirtualServices[0], - initialConfig.Balancer.Handler.VirtualServices[2], - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{ - 0, - }, // Only VS0 has WLC now (at new index 0) - }, - } - - require.Len( - t, - updatedConfig.Wlc.Vs, - 1, - "Should have 1 WLC index after removing WLC-enabled VS", - ) - assert.Equal( - t, - uint32(0), - updatedConfig.Wlc.Vs[0], - "WLC index should be 0", - ) - }) - - t.Run("UpdateVSToEnableWLC", func(t *testing.T) { - // Initial config: VS0 (WLC=false), VS1 (WLC=false) - initialConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{}, - }, - } - - // After updating VS1 to enable WLC - updatedConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - initialConfig.Balancer.Handler.VirtualServices[0], - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{1}, // VS1 now has WLC - }, - } - - require.Len( - t, - updatedConfig.Wlc.Vs, - 1, - "Should have 1 WLC index after enabling WLC on VS", - ) - assert.Equal( - t, - uint32(1), - updatedConfig.Wlc.Vs[0], - "WLC index should be 1", - ) - }) - - t.Run("UpdateVSToDisableWLC", func(t *testing.T) { - // Initial config: VS0 (WLC=true), VS1 (WLC=true), VS2 (WLC=true) - initialConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 1, 2}, - }, - } - - // After updating VS1 to disable WLC - updatedConfig := &BalancerManagerConfig{ - Balancer: BalancerConfig{ - Handler: PacketHandlerConfig{ - VirtualServices: []VsConfig{ - initialConfig.Balancer.Handler.VirtualServices[0], - { - Identifier: VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 80, - TransportProto: VsTransportProtoTCP, - }, - }, - initialConfig.Balancer.Handler.VirtualServices[2], - }, - }, - }, - Wlc: BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 2}, // VS0 and VS2 have WLC - }, - } - - require.Len( - t, - updatedConfig.Wlc.Vs, - 2, - "Should have 2 WLC indices after disabling WLC on VS", - ) - assert.Equal( - t, - uint32(0), - updatedConfig.Wlc.Vs[0], - "First WLC index should be 0", - ) - assert.Equal( - t, - uint32(2), - updatedConfig.Wlc.Vs[1], - "Second WLC index should be 2", - ) - }) -} diff --git a/modules/balancer/agent/go/graph_test.go b/modules/balancer/agent/go/graph_test.go deleted file mode 100644 index a33c52c35..000000000 --- a/modules/balancer/agent/go/graph_test.go +++ /dev/null @@ -1,980 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" -) - -// TestBuildConfigWeightsMap tests the helper function that builds config weights lookup map -func TestBuildConfigWeightsMap(t *testing.T) { - tests := []struct { - name string - config *ffi.BalancerManagerConfig - expected map[vsRealKey]uint16 - }{ - { - name: "Empty virtual services", - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{}, - }, - }, - }, - expected: map[vsRealKey]uint16{}, - }, - { - name: "Single VS with single real", - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - }, - }, - }, - }, - expected: map[vsRealKey]uint16{ - { - vsAddr: "10.0.0.1", - vsPort: 80, - vsProto: ffi.VsTransportProtoTCP, - realAddr: "192.168.1.1", - realPort: 8080, - }: 100, - }, - }, - { - name: "Multiple VS with multiple reals", - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.2", - ), - Port: 8080, - }, - Weight: 200, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.2", - ), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.2.1", - ), - Port: 8443, - }, - Weight: 150, - }, - }, - }, - }, - }, - }, - }, - expected: map[vsRealKey]uint16{ - { - vsAddr: "10.0.0.1", - vsPort: 80, - vsProto: ffi.VsTransportProtoTCP, - realAddr: "192.168.1.1", - realPort: 8080, - }: 100, - { - vsAddr: "10.0.0.1", - vsPort: 80, - vsProto: ffi.VsTransportProtoTCP, - realAddr: "192.168.1.2", - realPort: 8080, - }: 200, - { - vsAddr: "10.0.0.2", - vsPort: 443, - vsProto: ffi.VsTransportProtoTCP, - realAddr: "192.168.2.1", - realPort: 8443, - }: 150, - }, - }, - { - name: "IPv6 addresses", - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::100", - ), - Port: 8080, - }, - Weight: 250, - }, - }, - }, - }, - }, - }, - }, - expected: map[vsRealKey]uint16{ - { - vsAddr: "2001:db8::1", - vsPort: 80, - vsProto: ffi.VsTransportProtoTCP, - realAddr: "2001:db8::100", - realPort: 8080, - }: 250, - }, - }, - { - name: "UDP transport protocol", - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 5353, - }, - Weight: 75, - }, - }, - }, - }, - }, - }, - }, - expected: map[vsRealKey]uint16{ - { - vsAddr: "10.0.0.1", - vsPort: 53, - vsProto: ffi.VsTransportProtoUDP, - realAddr: "192.168.1.1", - realPort: 5353, - }: 75, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := buildConfigWeightsMap(tt.config) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestConvertGraphToProtoWithConfig tests the main Graph conversion function -// that merges FFI graph with config to produce protobuf Graph with correct -// Weight (from config) and EffectiveWeight (from graph) -func TestConvertGraphToProtoWithConfig(t *testing.T) { - tests := []struct { - name string - graph *ffi.BalancerGraph - config *ffi.BalancerManagerConfig - verify func(t *testing.T, result *balancerpb.Graph) - }{ - { - name: "Empty graph returns empty VirtualServices", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{}, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{}, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - assert.NotNil(t, result) - assert.Empty(t, result.VirtualServices) - }, - }, - { - name: "Basic single VS with single real - Weight from config, EffectiveWeight from graph", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 150, // Effective weight (e.g., after WLC adjustment) - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, // Config weight - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - require.Len(t, vs.Reals, 1) - real := vs.Reals[0] - - // Weight should come from config - assert.Equal( - t, - uint32(100), - real.Weight, - "Weight should be from config", - ) - // EffectiveWeight should come from graph - assert.Equal( - t, - uint32(150), - real.EffectiveWeight, - "EffectiveWeight should be from graph", - ) - assert.True(t, real.Enabled) - }, - }, - { - name: "Multiple reals with different weights", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 120, // Effective weight - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 180, // Effective weight - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.3"), - Port: 8081, - }, - Weight: 75, // Effective weight - Enabled: false, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, // Config weight - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.2", - ), - Port: 8080, - }, - Weight: 200, // Config weight - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.3", - ), - Port: 8081, - }, - Weight: 50, // Config weight - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - require.Len(t, vs.Reals, 3) - - // Real 1: Config=100, Effective=120 - assert.Equal(t, uint32(100), vs.Reals[0].Weight) - assert.Equal(t, uint32(120), vs.Reals[0].EffectiveWeight) - assert.True(t, vs.Reals[0].Enabled) - - // Real 2: Config=200, Effective=180 - assert.Equal(t, uint32(200), vs.Reals[1].Weight) - assert.Equal(t, uint32(180), vs.Reals[1].EffectiveWeight) - assert.True(t, vs.Reals[1].Enabled) - - // Real 3: Config=50, Effective=75, Disabled - assert.Equal(t, uint32(50), vs.Reals[2].Weight) - assert.Equal(t, uint32(75), vs.Reals[2].EffectiveWeight) - assert.False(t, vs.Reals[2].Enabled) - }, - }, - { - name: "Multiple virtual services", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 150, - Enabled: true, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.2.1"), - Port: 8443, - }, - Weight: 250, - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.2.2"), - Port: 8443, - }, - Weight: 300, - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.2", - ), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.2.1", - ), - Port: 8443, - }, - Weight: 200, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.2.2", - ), - Port: 8443, - }, - Weight: 300, - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 2) - - // VS 1 - vs1 := result.VirtualServices[0] - require.Len(t, vs1.Reals, 1) - assert.Equal(t, uint32(100), vs1.Reals[0].Weight) - assert.Equal(t, uint32(150), vs1.Reals[0].EffectiveWeight) - - // VS 2 - vs2 := result.VirtualServices[1] - require.Len(t, vs2.Reals, 2) - assert.Equal(t, uint32(200), vs2.Reals[0].Weight) - assert.Equal(t, uint32(250), vs2.Reals[0].EffectiveWeight) - assert.Equal(t, uint32(300), vs2.Reals[1].Weight) - assert.Equal(t, uint32(300), vs2.Reals[1].EffectiveWeight) - }, - }, - { - name: "Real in graph but not in config - Weight defaults to 0", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 150, - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.99", - ), // Not in config - Port: 9999, - }, - Weight: 200, - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - // Note: 192.168.1.99:9999 is NOT in config - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - require.Len(t, vs.Reals, 2) - - // Real 1: In config - assert.Equal(t, uint32(100), vs.Reals[0].Weight) - assert.Equal(t, uint32(150), vs.Reals[0].EffectiveWeight) - - // Real 2: NOT in config - Weight should be 0 - assert.Equal( - t, - uint32(0), - vs.Reals[1].Weight, - "Weight should be 0 for real not in config", - ) - assert.Equal( - t, - uint32(200), - vs.Reals[1].EffectiveWeight, - "EffectiveWeight should still come from graph", - ) - }, - }, - { - name: "IPv6 addresses", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 8080, - }, - Weight: 180, - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::100", - ), - Port: 8080, - }, - Weight: 150, - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - - // Verify VS identifier is IPv6 - vsAddr, _ := netip.AddrFromSlice(vs.Identifier.Addr.Bytes) - assert.True(t, vsAddr.Is6()) - assert.Equal(t, "2001:db8::1", vsAddr.String()) - - require.Len(t, vs.Reals, 1) - real := vs.Reals[0] - - // Verify real identifier is IPv6 - realAddr, _ := netip.AddrFromSlice(real.Identifier.Ip.Bytes) - assert.True(t, realAddr.Is6()) - assert.Equal(t, "2001:db8::100", realAddr.String()) - - assert.Equal(t, uint32(150), real.Weight) - assert.Equal(t, uint32(180), real.EffectiveWeight) - }, - }, - { - name: "UDP transport protocol", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 5353, - }, - Weight: 90, - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 5353, - }, - Weight: 75, - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - - assert.Equal( - t, - balancerpb.TransportProto_UDP, - vs.Identifier.Proto, - ) - require.Len(t, vs.Reals, 1) - assert.Equal(t, uint32(75), vs.Reals[0].Weight) - assert.Equal(t, uint32(90), vs.Reals[0].EffectiveWeight) - }, - }, - { - name: "Disabled real with zero effective weight", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 0, // Disabled real may have 0 effective weight - Enabled: false, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, // Config weight is still set - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - require.Len(t, vs.Reals, 1) - real := vs.Reals[0] - - // Config weight should still be present - assert.Equal( - t, - uint32(100), - real.Weight, - "Config weight should be preserved for disabled real", - ) - // Effective weight is 0 because real is disabled - assert.Equal( - t, - uint32(0), - real.EffectiveWeight, - "EffectiveWeight should be 0 for disabled real", - ) - assert.False(t, real.Enabled) - }, - }, - { - name: "WLC adjusted weights - effective weight differs from config", - graph: &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 250, // WLC increased weight (less loaded) - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 50, // WLC decreased weight (more loaded) - Enabled: true, - }, - }, - }, - }, - }, - config: &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "10.0.0.1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, // Both have same config weight - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.2", - ), - Port: 8080, - }, - Weight: 100, // Both have same config weight - }, - }, - }, - }, - }, - }, - }, - verify: func(t *testing.T, result *balancerpb.Graph) { - require.Len(t, result.VirtualServices, 1) - vs := result.VirtualServices[0] - require.Len(t, vs.Reals, 2) - - // Both reals have same config weight but different effective weights - assert.Equal( - t, - uint32(100), - vs.Reals[0].Weight, - "Config weight should be same", - ) - assert.Equal( - t, - uint32(250), - vs.Reals[0].EffectiveWeight, - "WLC increased weight for less loaded server", - ) - - assert.Equal( - t, - uint32(100), - vs.Reals[1].Weight, - "Config weight should be same", - ) - assert.Equal( - t, - uint32(50), - vs.Reals[1].EffectiveWeight, - "WLC decreased weight for more loaded server", - ) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ConvertGraphToProtoWithConfig(tt.graph, tt.config) - tt.verify(t, result) - }) - } -} diff --git a/modules/balancer/agent/go/inspect_test.go b/modules/balancer/agent/go/inspect_test.go deleted file mode 100644 index 63962e936..000000000 --- a/modules/balancer/agent/go/inspect_test.go +++ /dev/null @@ -1,450 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/durationpb" -) - -func TestBalancerAgentInspect(t *testing.T) { - // Create mock Yanet instance - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 27, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: "eth0", - }, - }, - }) - require.NoError(t, err, "failed to initialize mock") - require.NotNil(t, m, "mock is nil") - defer m.Free() - - // Create logger for tests - log := zap.NewNop().Sugar() - - // Create balancer agent - agent, err := NewBalancerAgent(m.SharedMemory(), 32*datasize.MB, log) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - // Define first balancer configuration - firstBalancerConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{ - FixMss: true, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.1"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.2").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.2"). - AsSlice(), - }, - Port: 8443, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 50, - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: nil, - RefreshPeriod: durationpb.New(0), - Wlc: nil, - }, - } - - // Define second balancer configuration - secondBalancerConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 15, - TcpSyn: 25, - TcpFin: 20, - Tcp: 69, - Udp: 15, - Default: 25, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.0.0.1").AsSlice(), - }, - Port: 8080, - Proto: balancerpb.TransportProto_UDP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.0.1.1"). - AsSlice(), - }, - Port: 9090, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.17.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 200, - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("20.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::a").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(2000); return &v }(), - SessionTableMaxLoadFactor: nil, - RefreshPeriod: durationpb.New(0), - Wlc: nil, - }, - } - - // Create first balancer - err = agent.NewBalancerManager("balancer0", firstBalancerConfig) - require.NoError(t, err, "failed to create first balancer") - - // Create second balancer - err = agent.NewBalancerManager("balancer1", secondBalancerConfig) - require.NoError(t, err, "failed to create second balancer") - - t.Run("Inspect_AgentLevel", func(t *testing.T) { - // Get inspect data - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Verify agent-level fields - assert.Greater( - t, - inspect.MemoryLimit, - uint64(0), - "memory limit should be greater than 0", - ) - assert.GreaterOrEqual( - t, - inspect.MemoryUsage, - uint64(0), - "memory usage should be non-negative", - ) - assert.LessOrEqual( - t, - inspect.MemoryUsage, - inspect.MemoryLimit, - "memory usage should not exceed limit", - ) - - // Verify we have two balancers - require.Len( - t, - inspect.Balancers, - 2, - "expected two balancers in inspect", - ) - }) - - t.Run("Inspect_BalancerNames", func(t *testing.T) { - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Collect balancer names - balancerNames := make(map[string]bool) - for _, balancer := range inspect.Balancers { - balancerNames[balancer.Name] = true - } - - // Verify both balancer names are present - assert.True( - t, - balancerNames["balancer0"], - "balancer0 should be in inspect", - ) - assert.True( - t, - balancerNames["balancer1"], - "balancer1 should be in inspect", - ) - }) - - t.Run("Inspect_Balancer0_VirtualServices", func(t *testing.T) { - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Find balancer0 - var balancer0 *balancerpb.BalancerInspect - for _, b := range inspect.Balancers { - if b.Name == "balancer0" { - balancer0 = b - break - } - } - require.NotNil(t, balancer0, "balancer0 not found in inspect") - - // Verify packet handler inspect exists - require.NotNil( - t, - balancer0.PacketHandlerInspect, - "packet handler inspect is nil", - ) - - // Verify memory usage fields are present - assert.GreaterOrEqual( - t, - balancer0.TotalUsage, - uint64(0), - "total usage should be non-negative", - ) - - // Check IPv4 VS inspect - require.NotNil( - t, - balancer0.PacketHandlerInspect.VsIpv4Inspect, - "IPv4 VS inspect is nil", - ) - vsIpv4Inspects := balancer0.PacketHandlerInspect.VsIpv4Inspect.VsInspects - require.Len( - t, - vsIpv4Inspects, - 2, - "expected 2 IPv4 virtual services in balancer0", - ) - - // Verify VS identifiers - vsAddrs := make(map[string]bool) - for _, vsInspect := range vsIpv4Inspects { - require.NotNil(t, vsInspect.Identifier, "VS identifier is nil") - addr := netip.AddrFrom4([4]byte(vsInspect.Identifier.Addr.Bytes)) - vsAddrs[addr.String()] = true - } - - assert.True( - t, - vsAddrs["10.0.0.1"], - "VS 10.0.0.1:80 should be in balancer0", - ) - assert.True( - t, - vsAddrs["10.0.0.2"], - "VS 10.0.0.2:443 should be in balancer0", - ) - }) - - t.Run("Inspect_Balancer1_VirtualServices", func(t *testing.T) { - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Find balancer1 - var balancer1 *balancerpb.BalancerInspect - for _, b := range inspect.Balancers { - if b.Name == "balancer1" { - balancer1 = b - break - } - } - require.NotNil(t, balancer1, "balancer1 not found in inspect") - - // Verify packet handler inspect exists - require.NotNil( - t, - balancer1.PacketHandlerInspect, - "packet handler inspect is nil", - ) - - // Check IPv4 VS inspect - require.NotNil( - t, - balancer1.PacketHandlerInspect.VsIpv4Inspect, - "IPv4 VS inspect is nil", - ) - vsIpv4Inspects := balancer1.PacketHandlerInspect.VsIpv4Inspect.VsInspects - require.Len( - t, - vsIpv4Inspects, - 1, - "expected 1 IPv4 virtual service in balancer1", - ) - - // Verify VS identifier - vsInspect := vsIpv4Inspects[0] - require.NotNil(t, vsInspect.Identifier, "VS identifier is nil") - addr := netip.AddrFrom4([4]byte(vsInspect.Identifier.Addr.Bytes)) - assert.Equal( - t, - "20.0.0.1", - addr.String(), - "VS address should be 20.0.0.1", - ) - assert.Equal( - t, - uint32(8080), - vsInspect.Identifier.Port, - "VS port should be 8080", - ) - assert.Equal( - t, - balancerpb.TransportProto_UDP, - vsInspect.Identifier.Proto, - "VS proto should be UDP", - ) - }) - - t.Run("Inspect_StateMemory", func(t *testing.T) { - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Verify state inspect for both balancers - for _, balancer := range inspect.Balancers { - require.NotNil( - t, - balancer.StateInspect, - "state inspect is nil for %s", - balancer.Name, - ) - assert.GreaterOrEqual( - t, - balancer.StateInspect.TotalUsage, - uint64(0), - "state total usage should be non-negative", - ) - } - }) - - t.Run("Inspect_MemoryBreakdown", func(t *testing.T) { - inspect := agent.Inspect() - require.NotNil(t, inspect, "inspect data is nil") - - // Verify memory breakdown for each balancer - for _, balancer := range inspect.Balancers { - // Packet handler memory - ph := balancer.PacketHandlerInspect - require.NotNil(t, ph, "packet handler inspect is nil") - assert.GreaterOrEqual( - t, - ph.TotalUsage, - uint64(0), - "packet handler total usage should be non-negative", - ) - - // State memory - state := balancer.StateInspect - require.NotNil(t, state, "state inspect is nil") - assert.GreaterOrEqual( - t, - state.TotalUsage, - uint64(0), - "state total usage should be non-negative", - ) - - // Total balancer memory should be sum of components - expectedTotal := ph.TotalUsage + state.TotalUsage + balancer.OtherUsage - assert.Equal( - t, - expectedTotal, - balancer.TotalUsage, - "balancer total usage should equal sum of components", - ) - } - }) -} diff --git a/modules/balancer/agent/go/manager.go b/modules/balancer/agent/go/manager.go deleted file mode 100644 index 2ed7eb9a8..000000000 --- a/modules/balancer/agent/go/manager.go +++ /dev/null @@ -1,905 +0,0 @@ -package balancer - -// BalancerManager implementation providing lifecycle management, configuration updates, -// real server management, and WLC (Weighted Least Connection) scheduling with automatic -// session table resizing and periodic refresh tasks. - -import ( - "context" - "fmt" - "maps" - "net/netip" - "strconv" - "sync" - "time" - - "github.com/yanet-platform/yanet2/common/commonpb" - "github.com/yanet-platform/yanet2/common/go/metrics" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "go.uber.org/zap" -) - -type BalancerManager struct { - handle *ffi.BalancerManager - - realUpdateBuffer []ffi.RealUpdate - - // Background task management - cancel context.CancelFunc - - mu sync.Mutex - - // Logger - log *zap.SugaredLogger - - handlerMetrics handlersMetrics -} - -func NewBalancerManager( - handle *ffi.BalancerManager, - log *zap.SugaredLogger, -) *BalancerManager { - name := handle.Name() - manager := &BalancerManager{ - handle: handle, - realUpdateBuffer: []ffi.RealUpdate{}, - log: log.With("balancer", name), - handlerMetrics: newHandlersMetrics(), - } - manager.startBackgroundTasks() - return manager -} - -func (b *BalancerManager) newHandlerTracker(handle string, extraLabels ...metrics.Labels) *handlerMetricTracker { - labels := metrics.Labels{ - "config": b.Name(), - } - for _, extra := range extraLabels { - maps.Copy(labels, extra) - } - return newHandlerMetricTracker(handle, &b.handlerMetrics, defaultLatencyBoundsMS, labels) -} - -func (b *BalancerManager) Name() string { - return b.handle.Name() -} - -func (b *BalancerManager) Update( - config *balancerpb.BalancerConfig, - now time.Time, -) (*ffi.UpdateInfo, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("update") - defer tracker.Fix() - - b.log.Debugw("updating balancer configuration") - - // Merge new config with current config for UPDATE mode - mergedConfig, err := mergeBalancerConfig(config, b.handle.Config()) - if err != nil { - b.log.Errorw("failed to merge config", "error", err) - return nil, fmt.Errorf("failed to merge config: %w", err) - } - - // Convert merged protobuf to FFI config - ffiConfig, err := ProtoToFFIConfig(mergedConfig) - if err != nil { - b.log.Errorw("failed to convert config", "error", err) - return nil, fmt.Errorf("failed to convert config: %w", err) - } - - // Create WLC configuration with validation - wlcConfig, err := createWlcConfig(mergedConfig) - if err != nil { - b.log.Errorw("failed to create WLC config", "error", err) - return nil, fmt.Errorf("failed to create WLC config: %w", err) - } - - // Create manager config - managerConfig := &ffi.BalancerManagerConfig{ - Balancer: ffiConfig, - RefreshPeriod: mergedConfig.State.RefreshPeriod.AsDuration(), - MaxLoadFactor: *mergedConfig.State.SessionTableMaxLoadFactor, - Wlc: wlcConfig, - } - - // Update via FFI - updateInfo, err := b.handle.Update(managerConfig, now) - if err != nil { - b.log.Errorw("failed to update manager", "error", err) - return nil, fmt.Errorf("failed to update manager: %w", err) - } - - // Log update information - b.log.Infow("balancer configuration updated successfully", - "vs_ipv4_matcher_reused", updateInfo.VsIpv4MatcherReused, - "vs_ipv6_matcher_reused", updateInfo.VsIpv6MatcherReused, - "acl_reused_vs_count", len(updateInfo.ACLReusedVs)) - - if len(updateInfo.ACLReusedVs) > 0 { - b.log.Debugw("ACL filters reused for virtual services", - "count", len(updateInfo.ACLReusedVs), - "vs_identifiers", updateInfo.ACLReusedVs) - } - - // restart background tasks - b.startBackgroundTasks() - - return updateInfo, nil -} - -func (b *BalancerManager) UpdateReals( - updates []*balancerpb.RealUpdate, - buffer bool, -) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("update_reals", metrics.Labels{"buffer": strconv.FormatBool(buffer)}) - defer tracker.Fix() - - b.log.Debugw("updating reals", "count", len(updates), "buffer", buffer) - - // Convert protobuf updates to FFI updates - ffiUpdates := make([]ffi.RealUpdate, 0, len(updates)) - for i, update := range updates { - ffiUpdate, err := NewRealUpdateFromProto(update) - if err != nil { - b.log.Errorw("failed to convert update", "index", i, "error", err) - return 0, fmt.Errorf( - "failed to convert update at index %d: %w", - i, - err, - ) - } - ffiUpdates = append(ffiUpdates, *ffiUpdate) - } - - if buffer { - // Buffer the updates - b.realUpdateBuffer = append(b.realUpdateBuffer, ffiUpdates...) - b.log.Debugw( - "real updates buffered", - "count", - len(updates), - "total_buffered", - len(b.realUpdateBuffer), - ) - return len(updates), nil - } - - // Apply immediately - if err := b.handle.UpdateReals(ffiUpdates); err != nil { - b.log.Errorw("failed to update reals", "error", err) - return 0, err - } - - b.log.Infow("real updates applied", "count", len(updates)) - return len(updates), nil -} - -func (b *BalancerManager) FlushRealUpdates() (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("flush_real_updates") - defer tracker.Fix() - - count := len(b.realUpdateBuffer) - if count == 0 { - b.log.Debugw("no buffered updates to flush") - return 0, nil - } - - b.log.Debugw("flushing buffered real updates", "count", count) - - // Apply buffered updates - if err := b.handle.UpdateReals(b.realUpdateBuffer); err != nil { - b.log.Errorw("failed to flush real updates", "error", err) - return 0, err - } - - // Clear buffer - b.realUpdateBuffer = b.realUpdateBuffer[:0] - - b.log.Infow("buffered real updates flushed", "count", count) - return count, nil -} - -func (b *BalancerManager) Config() *balancerpb.BalancerConfig { - b.mu.Lock() - defer b.mu.Unlock() - - return ConvertBalancerConfigToProto(b.handle.Config()) -} - -func (b *BalancerManager) BufferedUpdates() []*balancerpb.RealUpdate { - b.mu.Lock() - defer b.mu.Unlock() - - updates := make([]*balancerpb.RealUpdate, len(b.realUpdateBuffer)) - for i := range b.realUpdateBuffer { - updates[i] = ConvertFFIRealUpdateToProto(&b.realUpdateBuffer[i]) - } - return updates -} - -func (b *BalancerManager) Graph() *balancerpb.Graph { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("graph") - defer tracker.Fix() - - cfg := b.handle.Config() - graph := b.handle.Graph() - - return ConvertGraphToProtoWithConfig(graph, cfg) -} - -func (b *BalancerManager) Info( - now time.Time, -) (*balancerpb.BalancerInfo, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("info") - defer tracker.Fix() - - ffiInfo, err := b.handle.Info(now) - if err != nil { - return nil, err - } - - return ConvertBalancerInfoToProto(ffiInfo), nil -} - -func (b *BalancerManager) ActiveSessions() *balancerpb.BalancerInfo { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("active_sessions") - defer tracker.Fix() - - return ConvertBalancerInfoToProto(b.handle.ActiveSessions()) -} - -func (b *BalancerManager) Stats( - ref *balancerpb.PacketHandlerRef, -) (*balancerpb.BalancerStats, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("stats") - defer tracker.Fix() - - // Convert protobuf ref to FFI ref - ffiRef := &ffi.PacketHandlerRef{ - Device: ref.Device, - Pipeline: ref.Pipeline, - Function: ref.Function, - Chain: ref.Chain, - } - - ffiStats, err := b.handle.Stats(ffiRef) - if err != nil { - return nil, err - } - - return ConvertBalancerStatsToProto(ffiStats), nil -} - -//////////////////////////////////////////////////////////////////////////////// - -func (b *BalancerManager) Metrics( - now time.Time, - ref *balancerpb.PacketHandlerRef, -) ([]*commonpb.Metric, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("metrics") - defer tracker.Fix() - - // Convert protobuf ref to FFI ref - ffiRef := &ffi.PacketHandlerRef{ - Device: ref.Device, - Pipeline: ref.Pipeline, - Function: ref.Function, - Chain: ref.Chain, - } - - ffiStats, err := b.handle.Stats(ffiRef) - if err != nil { - return nil, fmt.Errorf("failed to get stats: %s", err) - } - - info, err := b.handle.Info(now) - if err != nil { - return nil, fmt.Errorf("failed to get info: %s", err) - } - - config := b.handle.Config() - - refLabels := []*commonpb.Label{ - {Name: "device", Value: *ref.Device}, - {Name: "pipeline", Value: *ref.Pipeline}, - {Name: "function", Value: *ref.Function}, - {Name: "chain", Value: *ref.Chain}, - {Name: "config", Value: b.Name()}, - } - - makeCounter := func(name string, value uint64, extraLabels ...*commonpb.Label) *commonpb.Metric { - metric := commonpb.Metric{ - Name: name, - Labels: append(refLabels, extraLabels...), - Value: &commonpb.Metric_Counter{Counter: value}, - } - return &metric - } - - makeGauge := func(name string, value float64, extraLabels ...*commonpb.Label) *commonpb.Metric { - metric := commonpb.Metric{ - Name: name, - Labels: append(refLabels, extraLabels...), - Value: &commonpb.Metric_Gauge{Gauge: value}, - } - return &metric - } - - commonMetricsCount := len( - commonCounters, - ) + 2 // +2 for active sessions and session table capacity (from info and config) - - perVsMetrics := len( - vsCounters, - ) + 1 // +1 for active sessions (from info) - perRealMetrics := len( - realCounters, - ) + 1 // +1 for active sessions (from info) - - metricsCount := commonMetricsCount + perVsMetrics*len(ffiStats.Vs) - - for vsIdx := range ffiStats.Vs { - vs := &ffiStats.Vs[vsIdx] - metricsCount += perRealMetrics * len(vs.Reals) - } - - metrics := make([]*commonpb.Metric, 0, metricsCount) - - // make common metrics - { - // active sessions and session table capacity - metrics = append( - metrics, - makeGauge("active_sessions", float64(info.ActiveSessions)), - makeGauge( - "session_table_capacity", - float64(config.Balancer.State.TableCapacity), - ), - ) - - // counters - for _, counter := range commonCounters { - metrics = append( - metrics, - makeCounter(counter.name, counter.getter(ffiStats)), - ) - } - } - - // make vs metrics - for vsIdx := range ffiStats.Vs { - vs := &ffiStats.Vs[vsIdx] - vsInfo := &info.Vs[vsIdx] - labelsVS := []*commonpb.Label{ - {Name: "vip", Value: vs.Identifier.Addr.String()}, - {Name: "port", Value: strconv.Itoa(int(vs.Identifier.Port))}, - {Name: "protocol", Value: vs.Identifier.TransportProto.String()}, - } - - // active sessions - metrics = append( - metrics, - makeGauge( - "vs_active_sessions", - float64(vsInfo.ActiveSessions), - labelsVS..., - ), - ) - - // counters - for _, counter := range vsCounters { - metrics = append( - metrics, - makeCounter( - counter.name, - counter.getter(&vs.Stats), - labelsVS...), - ) - } - - // make real metrics - for realIdx := range vs.Reals { - real := &vs.Reals[realIdx] - realInfo := &vsInfo.Reals[realIdx] - labelsReal := append(labelsVS, &commonpb.Label{Name: "real_ip", Value: real.Dst.String()}) - - // active sessions - metrics = append( - metrics, - makeGauge( - "real_active_sessions", - float64(realInfo.ActiveSessions), - labelsReal..., - ), - ) - - // counters - for _, counter := range realCounters { - metrics = append( - metrics, - makeCounter( - counter.name, - counter.getter(&real.Stats), - labelsReal..., - ), - ) - } - } - - // make acl metrics - for aclIdx := range vs.AllowedSources { - acl := &vs.AllowedSources[aclIdx] - labelsACL := append(labelsVS, &commonpb.Label{Name: "acl_tag", Value: acl.Tag}) - - metrics = append( - metrics, - makeCounter( - "vs_acl_hits", - acl.Passes, - labelsACL..., - ), - ) - } - } - - calls := b.handlerMetrics.collect() - - return append(metrics, calls...), nil -} - -func (b *BalancerManager) Sessions( - now time.Time, -) ([]*balancerpb.SessionInfo, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("sessions") - defer tracker.Fix() - - ffiSessions := b.handle.Sessions(now) - - sessions := make([]*balancerpb.SessionInfo, 0, len(ffiSessions.Sessions)) - for i := range ffiSessions.Sessions { - sessions = append(sessions, ConvertSessionInfoToProto( - &ffiSessions.Sessions[i].Identifier, - &ffiSessions.Sessions[i].Info, - )) - } - - return sessions, nil -} - -func (b *BalancerManager) startBackgroundTasks() { - b.stopBackgroundTasks() - - if b.handle.Config().RefreshPeriod == 0 { - return - } - - var ctx context.Context - ctx, b.cancel = context.WithCancel(context.Background()) - - // Start background refresh task - go b.backgroundRefreshTask(ctx) -} - -// backgroundRefreshTask runs periodically to: -// 1. Get balancer info -// 2. Resize session table if load factor exceeds threshold -// 3. Adjust WLC weights based on active connections -// 4. Apply real updates if needed -func (b *BalancerManager) backgroundRefreshTask(ctx context.Context) { - for { - // Get current config to check refresh period - b.mu.Lock() - config := b.handle.Config() - refreshPeriod := config.RefreshPeriod - b.mu.Unlock() - - // If refresh period is zero, stop the task - if refreshPeriod == 0 { - b.log.Debugw( - "background refresh task stopped (refresh_period is zero)", - ) - return - } - - // Wait for refresh period or context cancellation - select { - case <-ctx.Done(): - b.log.Debugw("background refresh task stopped (context cancelled)") - return - case <-time.After(refreshPeriod): - // Continue with refresh - } - - now := time.Now() - - // Perform refresh with error handling - if err := b.Refresh(now); err != nil { - b.log.Errorw("background refresh failed", "error", err) - } - } -} - -// Refresh executes the refresh logic -func (b *BalancerManager) Refresh(now time.Time) error { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("refresh") - defer tracker.Fix() - - b.log.Info("refreshing state") - - // Get current config - config := b.handle.Config() - - // Get balancer info - info, err := b.handle.Info(now) - if err != nil { - return fmt.Errorf("failed to get info: %w", err) - } - - // Check if session table needs resizing - capacity := config.Balancer.State.TableCapacity - activeSessions := info.ActiveSessions - maxLoadFactor := config.MaxLoadFactor - - currentLoadFactor := float32(activeSessions) / float32(capacity) - - b.log.Debugw("fetched balancer info", - "current_capacity", capacity, - "active_sessions", activeSessions, - "current_load_factor", currentLoadFactor, - "max_load_factor", maxLoadFactor) - - if currentLoadFactor > maxLoadFactor { - newCapacity := capacity * 2 - b.log.Infow("resizing session table", - "current_capacity", capacity, - "new_capacity", newCapacity, - "active_sessions", activeSessions, - "current_load_factor", currentLoadFactor, - "max_load_factor", maxLoadFactor) - - if err := b.handle.ResizeSessionTable(newCapacity, now); err != nil { - b.log.Errorw("failed to resize session table", "error", err) - } else { - b.log.Infow("session table resized successfully", "new_capacity", newCapacity) - } - } - - // WLC real updates - use UpdateRealsWlc to preserve config weights - updates := WlcUpdates(b.handle.Config(), b.handle.Graph(), info) - - b.log.Debugw("calculated WLC updates", "count", len(updates)) - - if len(updates) > 0 { - b.log.Infow("applying WLC updates", "count", len(updates)) - if err := b.handle.UpdateRealsWlc(updates); err != nil { - b.log.Errorw("failed to apply WLC updates", "error", err) - } else { - b.log.Infow("WLC updates applied successfully", "count", len(updates)) - } - } - - return nil -} - -func (b *BalancerManager) stopBackgroundTasks() { - if b.cancel != nil { - b.cancel() - b.cancel = nil - } -} - -func (b *BalancerManager) Free() { - b.mu.Lock() - defer b.mu.Unlock() - - b.stopBackgroundTasks() -} - -// UpdateVS updates specific virtual services in the balancer configuration. -// -// This method takes the current FFI config, updates/adds the specified virtual -// services (provided in protobuf format), and calls the FFI update. The ACL -// reuse list in the returned UpdateInfo only contains virtual services from -// the update request. -// -// Behavior: -// - Virtual services in the request that already exist are replaced -// - Virtual services in the request that don't exist are added -// - Virtual services not in the request remain unchanged -func (b *BalancerManager) UpdateVS( - vsList []*balancerpb.VirtualService, - now time.Time, -) (*ffi.UpdateInfo, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("update_vs") - defer tracker.Fix() - - b.log.Debugw("updating virtual services", "vs_count", len(vsList)) - - // Convert protobuf VS list to FFI format - ffiVsList := make([]ffi.VsConfig, 0, len(vsList)) - for i, protoVs := range vsList { - ffiVs, err := protoToVsConfig(protoVs) - if err != nil { - b.log.Errorw("failed to convert VS", "index", i, "error", err) - return nil, fmt.Errorf( - "failed to convert VS at index %d: %w", - i, - err, - ) - } - ffiVsList = append(ffiVsList, ffiVs) - } - - // Get current config - currentConfig := b.handle.Config() - - // Build a map of VS identifiers from the request for quick lookup - requestedVsIds := make(map[ffi.VsIdentifier]bool) - for _, vs := range ffiVsList { - requestedVsIds[vs.Identifier] = true - } - - // Create new VS list: keep existing VS that are not in the request, - // then add/replace with VS from the request - newVsList := make([]ffi.VsConfig, 0) - - // Keep existing VS that are not being updated - for _, existingVs := range currentConfig.Balancer.Handler.VirtualServices { - if !requestedVsIds[existingVs.Identifier] { - newVsList = append(newVsList, existingVs) - } - } - - // Add VS from the request (these are new or updated) - newVsList = append(newVsList, ffiVsList...) - - // Update WLC config - recalculate which VS indices have WLC enabled - newWlcVs := make([]uint32, 0) - wlcEnabledOld := make(map[ffi.VsIdentifier]bool) - for _, vsIdx := range currentConfig.Wlc.Vs { - if int(vsIdx) < len(currentConfig.Balancer.Handler.VirtualServices) { - wlcEnabledOld[currentConfig.Balancer.Handler.VirtualServices[vsIdx].Identifier] = true - } - } - // Check which VS in the request have WLC enabled - wlcEnabledNew := make(map[ffi.VsIdentifier]bool) - for _, protoVs := range vsList { - if protoVs.Flags != nil && protoVs.Flags.Wlc { - ffiVs, _ := protoToVsConfig(protoVs) - wlcEnabledNew[ffiVs.Identifier] = true - } - } - // For new VS list, find indices of VS that have WLC enabled - for i, vs := range newVsList { - // If VS is in the request, use the new WLC flag; otherwise use the old one - if requestedVsIds[vs.Identifier] { - if wlcEnabledNew[vs.Identifier] { - newWlcVs = append(newWlcVs, uint32(i)) - } - } else if wlcEnabledOld[vs.Identifier] { - newWlcVs = append(newWlcVs, uint32(i)) - } - } - - // Create updated manager config - updatedConfig := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - SessionsTimeouts: currentConfig.Balancer.Handler.SessionsTimeouts, - VirtualServices: newVsList, - SourceV4: currentConfig.Balancer.Handler.SourceV4, - SourceV6: currentConfig.Balancer.Handler.SourceV6, - DecapV4: currentConfig.Balancer.Handler.DecapV4, - DecapV6: currentConfig.Balancer.Handler.DecapV6, - }, - State: currentConfig.Balancer.State, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: currentConfig.Wlc.Power, - MaxRealWeight: currentConfig.Wlc.MaxRealWeight, - Vs: newWlcVs, - }, - RefreshPeriod: currentConfig.RefreshPeriod, - MaxLoadFactor: currentConfig.MaxLoadFactor, - } - - // Update via FFI - updateInfo, err := b.handle.Update(updatedConfig, now) - if err != nil { - b.log.Errorw("failed to update manager", "error", err) - return nil, fmt.Errorf("failed to update manager: %w", err) - } - - // Filter ACL reuse list to only include VS from the update request - filteredUpdateInfo := filterACLReusesForRequestedVs( - updateInfo, - requestedVsIds, - ) - - b.log.Infow("virtual services updated successfully", - "vs_count", len(vsList), - "vs_ipv4_matcher_reused", filteredUpdateInfo.VsIpv4MatcherReused, - "vs_ipv6_matcher_reused", filteredUpdateInfo.VsIpv6MatcherReused, - "acl_reused_vs_count", len(filteredUpdateInfo.ACLReusedVs)) - - return filteredUpdateInfo, nil -} - -// DeleteVS deletes specific virtual services from the balancer configuration. -// -// This method takes the current FFI config, removes the specified virtual -// services (identified by protobuf VS list), and calls the FFI update. The ACL -// reuse list in the returned UpdateInfo is always empty since deleted VSs don't -// have ACL filters to reuse. -// -// Behavior: -// - Virtual services matching the identifiers in the request are removed -// - Virtual services not in the request remain unchanged -// - Deleting a non-existent VS is not an error (idempotent) -func (b *BalancerManager) DeleteVS( - vsList []*balancerpb.VirtualService, - now time.Time, -) (*ffi.UpdateInfo, error) { - b.mu.Lock() - defer b.mu.Unlock() - - tracker := b.newHandlerTracker("delete_vs") - defer tracker.Fix() - - b.log.Debugw("deleting virtual services", "vs_count", len(vsList)) - - // Get current config - currentConfig := b.handle.Config() - - // Build a map of VS identifiers to delete for quick lookup - vsToDelete := make(map[ffi.VsIdentifier]bool) - for _, protoVs := range vsList { - if protoVs.Id != nil { - vsID := protoVsIdentifierToFFI(protoVs.Id) - vsToDelete[vsID] = true - } - } - - // Create new VS list without the deleted VS - newVsList := make([]ffi.VsConfig, 0) - - for _, existingVs := range currentConfig.Balancer.Handler.VirtualServices { - if !vsToDelete[existingVs.Identifier] { - newVsList = append(newVsList, existingVs) - } - } - - // Update WLC config - recalculate which VS indices have WLC enabled - newWlcVs := make([]uint32, 0) - wlcEnabledOld := make(map[ffi.VsIdentifier]bool) - for _, vsIdx := range currentConfig.Wlc.Vs { - if int(vsIdx) < len(currentConfig.Balancer.Handler.VirtualServices) { - wlcEnabledOld[currentConfig.Balancer.Handler.VirtualServices[vsIdx].Identifier] = true - } - } - // For new VS list, find indices of VS that had WLC enabled - for i, vs := range newVsList { - if wlcEnabledOld[vs.Identifier] { - newWlcVs = append(newWlcVs, uint32(i)) - } - } - - // Create updated manager config - updatedConfig := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - SessionsTimeouts: currentConfig.Balancer.Handler.SessionsTimeouts, - VirtualServices: newVsList, - SourceV4: currentConfig.Balancer.Handler.SourceV4, - SourceV6: currentConfig.Balancer.Handler.SourceV6, - DecapV4: currentConfig.Balancer.Handler.DecapV4, - DecapV6: currentConfig.Balancer.Handler.DecapV6, - }, - State: currentConfig.Balancer.State, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: currentConfig.Wlc.Power, - MaxRealWeight: currentConfig.Wlc.MaxRealWeight, - Vs: newWlcVs, - }, - RefreshPeriod: currentConfig.RefreshPeriod, - MaxLoadFactor: currentConfig.MaxLoadFactor, - } - - // Update via FFI - updateInfo, err := b.handle.Update(updatedConfig, now) - if err != nil { - b.log.Errorw("failed to update manager", "error", err) - return nil, fmt.Errorf("failed to update manager: %w", err) - } - - // For delete, ACL reuse list should be empty - updateInfo.ACLReusedVs = []ffi.VsIdentifier{} - - b.log.Infow("virtual services deleted successfully", - "vs_count", len(vsList), - "vs_ipv4_matcher_reused", updateInfo.VsIpv4MatcherReused, - "vs_ipv6_matcher_reused", updateInfo.VsIpv6MatcherReused) - - return updateInfo, nil -} - -// protoVsIdentifierToFFI converts a protobuf VS identifier to FFI format -func protoVsIdentifierToFFI(id *balancerpb.VsIdentifier) ffi.VsIdentifier { - var addr netip.Addr - if id.Addr != nil { - addr, _ = netip.AddrFromSlice(id.Addr.Bytes) - } - - proto := ffi.VsTransportProtoUDP - if id.Proto == balancerpb.TransportProto_TCP { - proto = ffi.VsTransportProtoTCP - } - - return ffi.VsIdentifier{ - Addr: addr, - Port: uint16(id.Port), - TransportProto: proto, - } -} - -// filterACLReusesForRequestedVs filters the ACL reuse list to only include -// VS identifiers that were in the original update request. -// This ensures the response only reports reuse status for VSs that were -// actually part of the update operation. -func filterACLReusesForRequestedVs( - info *ffi.UpdateInfo, - requestedVsIds map[ffi.VsIdentifier]bool, -) *ffi.UpdateInfo { - if info == nil { - return nil - } - - filtered := &ffi.UpdateInfo{ - VsIpv4MatcherReused: info.VsIpv4MatcherReused, - VsIpv6MatcherReused: info.VsIpv6MatcherReused, - ACLReusedVs: make([]ffi.VsIdentifier, 0), - } - - for _, vsID := range info.ACLReusedVs { - if requestedVsIds[vsID] { - filtered.ACLReusedVs = append(filtered.ACLReusedVs, vsID) - } - } - - return filtered -} diff --git a/modules/balancer/agent/go/manager_test.go b/modules/balancer/agent/go/manager_test.go deleted file mode 100644 index 402e6d900..000000000 --- a/modules/balancer/agent/go/manager_test.go +++ /dev/null @@ -1,1628 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - yanet2 "github.com/yanet-platform/yanet2/controlplane/ffi" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/durationpb" -) - -var ( - deviceName string = "eth0" - pipelineName string = "pipeline0" - functionName string = "function0" - chainName string = "chain0" - balancerName string = "balancer0" -) - -func TestManager(t *testing.T) { - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 28, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: deviceName, - }, - }, - }) - require.NoError(t, err, "failed to create mock") - require.NotNil(t, m, "mock is nil") - - // Create balancer agent - agent, err := ffi.NewBalancerAgent(m.SharedMemory(), 1<<25) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - // Create logger - logger, err := zap.NewDevelopment() - require.NoError(t, err, "failed to create logger") - sugaredLogger := logger.Sugar() - - // Create protobuf config with zero refresh_period to prevent background tasks - capacity := uint64(1000) - maxLoadFactor := float32(0.75) - power := uint64(10) - maxWeight := uint32(1024) - - protoConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: true}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.214"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.1.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 150, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.215"). - AsSlice(), - }, - Port: 8081, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.2.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 200, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.1.1.1"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.12.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.1.1.3").AsSlice()}, - {Bytes: netip.MustParseAddr("12.1.1.4").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::2").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::3").AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.40").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: false}, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.41"). - AsSlice(), - }, - Port: 8443, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.17.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.42"). - AsSlice(), - }, - Port: 8443, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.17.1.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.2.2.3").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::10").AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.70").AsSlice(), - }, - Port: 53, - Proto: balancerpb.TransportProto_UDP, - }, - Flags: &balancerpb.VsFlags{FixMss: false}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.71"). - AsSlice(), - }, - Port: 5353, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 50, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.72"). - AsSlice(), - }, - Port: 5353, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.1.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 75, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.73"). - AsSlice(), - }, - Port: 5353, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.2.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.74"). - AsSlice(), - }, - Port: 5353, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.3.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 125, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.75"). - AsSlice(), - }, - Port: 5354, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.4.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 150, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("12.3.3.3").AsSlice()}, - {Bytes: netip.MustParseAddr("12.3.3.4").AsSlice()}, - {Bytes: netip.MustParseAddr("12.3.3.5").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::20").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::21").AsSlice()}, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.13.11.215").AsSlice()}, - {Bytes: netip.MustParseAddr("10.14.11.214").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::3").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::2").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: &maxLoadFactor, - RefreshPeriod: durationpb.New( - 0, - ), // Zero to prevent background tasks - Wlc: &balancerpb.WlcConfig{ - Power: &power, - MaxWeight: &maxWeight, - }, - }, - } - - // Convert to FFI config - managerConfig, err := ProtoToManagerConfig(protoConfig) - require.NoError(t, err, "failed to convert config") - - // Create manager via FFI - managerHandle, err := agent.NewManager(balancerName, managerConfig) - require.NoError(t, err, "failed to create balancer manager") - require.NotNil(t, managerHandle, "balancer manager handle is nil") - - // Create Go wrapper - manager := NewBalancerManager(managerHandle, sugaredLogger) - require.NotNil(t, manager, "manager should not be nil") - - // Use mock's current time for all operations - now := m.CurrentTime() - - // Test 1: Verify manager was created successfully - t.Run("ManagerCreation", func(t *testing.T) { - require.NotNil(t, manager, "manager should not be nil") - - // Verify we can get the manager name - name := manager.Name() - require.Equal(t, balancerName, name, "manager name should match") - }) - - t.Run("SetupControlplane", func(t *testing.T) { - cpAgent, err := m.SharedMemory().AgentReattach("bootstrap", 0, 1<<20) - require.NoError(t, err, "failed to attach bootstrap agent") - { - functionConfig := yanet2.FunctionConfig{ - Name: functionName, - Chains: []yanet2.FunctionChainConfig{ - { - Weight: 1, - Chain: yanet2.ChainConfig{ - Name: chainName, - Modules: []yanet2.ChainModuleConfig{ - { - Type: "balancer", - Name: balancerName, - }, - }, - }, - }, - }, - } - - if err := cpAgent.UpdateFunction(functionConfig); err != nil { - t.Fatalf("failed to update functions: %v", err) - } - } - - // update pipelines - { - inputPipelineConfig := yanet2.PipelineConfig{ - Name: pipelineName, - Functions: []string{functionName}, - } - - dummyPipelineConfig := yanet2.PipelineConfig{ - Name: "dummy", - Functions: []string{}, - } - - if err := cpAgent.UpdatePipeline(inputPipelineConfig); err != nil { - t.Fatalf("failed to update pipeline: %v", err) - } - - if err := cpAgent.UpdatePipeline(dummyPipelineConfig); err != nil { - t.Fatalf("failed to update pipeline: %v", err) - } - } - - // update devices - { - deviceConfig := yanet2.DeviceConfig{ - Name: deviceName, - Input: []yanet2.DevicePipelineConfig{ - { - Name: pipelineName, - Weight: 1, - }, - }, - Output: []yanet2.DevicePipelineConfig{ - { - Name: "dummy", - Weight: 1, - }, - }, - } - - if err := cpAgent.UpdatePlainDevices([]yanet2.DeviceConfig{deviceConfig}); err != nil { - t.Fatalf("failed to update pipelines: %v", err) - } - } - }) - - // Test 2: Get initial configuration - t.Run("GetInitialConfig", func(t *testing.T) { - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.Equal( - t, - 3, - len(config.PacketHandler.Vs), - "should have 3 virtual services", - ) - require.Equal( - t, - uint64(1000), - *config.State.SessionTableCapacity, - "capacity should match", - ) - }) - - // Test 3: Get initial graph - t.Run("GetInitialGraph", func(t *testing.T) { - graph := manager.Graph() - require.NotNil(t, graph, "graph should not be nil") - require.Equal( - t, - 3, - len(graph.VirtualServices), - "graph should have 3 virtual services", - ) - - // Verify first VS has 3 reals - require.Equal( - t, - 3, - len(graph.VirtualServices[0].Reals), - "first VS should have 3 reals", - ) - // Verify second VS has 2 reals - require.Equal( - t, - 2, - len(graph.VirtualServices[1].Reals), - "second VS should have 2 reals", - ) - // Verify third VS has 5 reals - require.Equal( - t, - 5, - len(graph.VirtualServices[2].Reals), - "third VS should have 5 reals", - ) - - // Verify all reals are initially disabled - for vsIdx, vs := range graph.VirtualServices { - for realIdx, real := range vs.Reals { - require.False( - t, - real.Enabled, - "VS %d Real %d should be initially disabled", - vsIdx, - realIdx, - ) - } - } - }) - - // Test 4: Get initial info - t.Run("GetInitialInfo", func(t *testing.T) { - info, err := manager.Info(now) - require.NoError(t, err, "failed to get info") - require.NotNil(t, info, "info should not be nil") - - // Check info variables are zeroes initially - require.Equal( - t, - uint64(0), - info.ActiveSessions, - "active sessions should be zero initially", - ) - - // Check info topology matches config topology - require.Equal(t, 3, len(info.Vs), "info should have 3 virtual services") - }) - - ref := &balancerpb.PacketHandlerRef{ - Device: &deviceName, - Pipeline: &pipelineName, - Function: &functionName, - Chain: &chainName, - } - - // Test 5: Get initial stats - t.Run("GetInitialStats", func(t *testing.T) { - stats, err := manager.Stats(ref) - require.NoError(t, err, "failed to get stats") - require.NotNil(t, stats, "stats should not be nil") - require.Equal( - t, - 3, - len(stats.Vs), - "stats should have 3 virtual services", - ) - - // Check common stats are zeroes - require.Equal( - t, - uint64(0), - stats.Common.IncomingPackets, - "incoming packets should be zero", - ) - require.Equal( - t, - uint64(0), - stats.Common.OutgoingPackets, - "outgoing packets should be zero", - ) - }) - - // Test 6: Get initial sessions - t.Run("GetInitialSessions", func(t *testing.T) { - sessions, err := manager.Sessions(now) - require.NoError(t, err, "failed to get sessions") - require.NotNil(t, sessions, "sessions should not be nil") - // Initially should have no sessions - require.Equal(t, 0, len(sessions), "should have no sessions initially") - }) - - // Test 7: Update reals with buffering - t.Run("UpdateRealsWithBuffering", func(t *testing.T) { - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - }, - Weight: func() *uint32 { w := uint32(250); return &w }(), - Enable: func() *bool { e := true; return &e }(), - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.214"). - AsSlice(), - }, - Port: 8080, - }, - }, - Weight: func() *uint32 { w := uint32(300); return &w }(), - Enable: func() *bool { e := true; return &e }(), - }, - } - - // Buffer the updates - count, err := manager.UpdateReals(updates, true) - require.NoError(t, err, "failed to buffer real updates") - require.Equal(t, 2, count, "should buffer 2 updates") - - // Verify updates are buffered - buffered := manager.BufferedUpdates() - require.Equal(t, 2, len(buffered), "should have 2 buffered updates") - - // Verify graph hasn't changed yet - graph := manager.Graph() - require.False(t, graph.VirtualServices[0].Reals[0].Enabled, - "real should still be disabled before flush") - - // Flush the updates - flushedCount, err := manager.FlushRealUpdates() - require.NoError(t, err, "failed to flush real updates") - require.Equal(t, 2, flushedCount, "should flush 2 updates") - - // Verify buffer is empty - buffered = manager.BufferedUpdates() - require.Equal(t, 0, len(buffered), "buffer should be empty after flush") - - // Verify graph has changed - graph = manager.Graph() - // Find the updated reals by identifier - foundFirst := false - foundSecond := false - for _, vs := range graph.VirtualServices { - vsAddr, _ := netip.AddrFromSlice(vs.Identifier.Addr.Bytes) - if vsAddr.String() == "10.12.13.213" && vs.Identifier.Port == 80 { - for _, real := range vs.Reals { - realAddr, _ := netip.AddrFromSlice(real.Identifier.Ip.Bytes) - if realAddr.String() == "10.12.13.213" && - real.Identifier.Port == 8080 { - require.True( - t, - real.Enabled, - "first real should be enabled after flush", - ) - require.Equal( - t, - uint32(250), - real.Weight, - "first real config weight should be 250", - ) - require.Equal( - t, - uint32(250), - real.EffectiveWeight, - "first real effective weight should be 250", - ) - foundFirst = true - } - if realAddr.String() == "10.12.13.214" && - real.Identifier.Port == 8080 { - require.True( - t, - real.Enabled, - "second real should be enabled after flush", - ) - require.Equal( - t, - uint32(300), - real.Weight, - "second real config weight should be 300", - ) - require.Equal( - t, - uint32(300), - real.EffectiveWeight, - "second real effective weight should be 300", - ) - foundSecond = true - } - } - } - } - require.True(t, foundFirst, "should find first updated real") - require.True(t, foundSecond, "should find second updated real") - }) - - // Test 8: Update reals without buffering - t.Run("UpdateRealsWithoutBuffering", func(t *testing.T) { - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.40").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.41").AsSlice(), - }, - Port: 8443, - }, - }, - Weight: func() *uint32 { w := uint32(200); return &w }(), - Enable: func() *bool { e := true; return &e }(), - }, - } - - // Apply immediately without buffering - count, err := manager.UpdateReals(updates, false) - require.NoError(t, err, "failed to update reals immediately") - require.Equal(t, 1, count, "should apply 1 update") - - // Verify buffer is still empty - buffered := manager.BufferedUpdates() - require.Equal(t, 0, len(buffered), "buffer should be empty") - - // Verify graph has changed immediately - graph := manager.Graph() - found := false - for _, vs := range graph.VirtualServices { - vsAddr, _ := netip.AddrFromSlice(vs.Identifier.Addr.Bytes) - if vsAddr.String() == "10.20.30.40" && vs.Identifier.Port == 443 { - for _, real := range vs.Reals { - realAddr, _ := netip.AddrFromSlice(real.Identifier.Ip.Bytes) - if realAddr.String() == "10.20.30.41" && - real.Identifier.Port == 8443 { - require.True( - t, - real.Enabled, - "real should be enabled immediately", - ) - require.Equal( - t, - uint32(200), - real.Weight, - "real config weight should be 200", - ) - require.Equal( - t, - uint32(200), - real.EffectiveWeight, - "real effective weight should be 200", - ) - found = true - } - } - } - } - require.True(t, found, "should find updated real") - }) - - // Test 9: Update manager configuration - t.Run("UpdateManagerConfig", func(t *testing.T) { - // Create a new config with different values - newCapacity := uint64(2000) - newMaxLoadFactor := float32(0.8) - - newConfig := &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - SessionTableCapacity: &newCapacity, - SessionTableMaxLoadFactor: &newMaxLoadFactor, - RefreshPeriod: durationpb.New(0), // Keep zero - }, - } - - _, err := manager.Update(newConfig, now) - require.NoError(t, err, "failed to update manager config") - - // Verify the new config is applied - updatedConfig := manager.Config() - require.NotNil(t, updatedConfig, "updated config should not be nil") - // Note: The capacity might not change immediately due to how the update works - // Just verify the config is returned successfully - require.NotNil( - t, - updatedConfig.State.SessionTableCapacity, - "capacity should not be nil", - ) - }) - - // Test 10: BufferedUpdates when empty - t.Run("BufferedUpdatesEmpty", func(t *testing.T) { - // Ensure buffer is empty - _, err := manager.FlushRealUpdates() - require.NoError(t, err, "failed to flush") - - buffered := manager.BufferedUpdates() - require.Equal(t, 0, len(buffered), "buffer should be empty") - }) - - // Test 11: FlushRealUpdates when empty - t.Run("FlushRealUpdatesEmpty", func(t *testing.T) { - count, err := manager.FlushRealUpdates() - require.NoError(t, err, "flushing empty buffer should not error") - require.Equal(t, 0, count, "should flush 0 updates") - }) - - // Test 12: Verify Name method - t.Run("NameMethod", func(t *testing.T) { - name := manager.Name() - require.Equal(t, balancerName, name, "name should match balancer name") - }) -} - -// TestMergeBalancerConfigRecursive tests recursive merging of balancer configuration -func TestMergeBalancerConfigRecursive(t *testing.T) { - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 28, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: deviceName, - }, - }, - }) - require.NoError(t, err, "failed to create mock") - require.NotNil(t, m, "mock is nil") - - // Create balancer agent - agent, err := ffi.NewBalancerAgent(m.SharedMemory(), 1<<26) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - // Create logger - logger, err := zap.NewDevelopment() - require.NoError(t, err, "failed to create logger") - sugaredLogger := logger.Sugar() - - // Create initial full config - capacity := uint64(1000) - maxLoadFactor := float32(0.75) - power := uint64(10) - maxWeight := uint32(1024) - - initialConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: true}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.13.11.215").AsSlice()}, - {Bytes: netip.MustParseAddr("2001:db8::3").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: &maxLoadFactor, - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: &power, - MaxWeight: &maxWeight, - }, - }, - } - - // Convert to FFI config and create manager - managerConfig, err := ProtoToManagerConfig(initialConfig) - require.NoError(t, err, "failed to convert config") - - managerHandle, err := agent.NewManager("test_merge_balancer", managerConfig) - require.NoError(t, err, "failed to create balancer manager") - require.NotNil(t, managerHandle, "balancer manager handle is nil") - - manager := NewBalancerManager(managerHandle, sugaredLogger) - require.NotNil(t, manager, "manager should not be nil") - - now := m.CurrentTime() - - // Test 1: Partial PacketHandler update - only source_address_v4 - t.Run("PartialPacketHandlerUpdate_OnlySourceV4", func(t *testing.T) { - newSourceV4 := &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.1").AsSlice(), - } - - updateConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: newSourceV4, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify source_v4 changed - config := manager.Config() - require.NotNil(t, config.PacketHandler) - require.Equal( - t, - newSourceV4.Bytes, - config.PacketHandler.SourceAddressV4.Bytes, - "source_address_v4 should be updated", - ) - - // Verify other fields preserved - require.NotNil(t, config.PacketHandler.SourceAddressV6, - "source_address_v6 should be preserved") - require.Equal(t, netip.MustParseAddr("2001:db8::1").AsSlice(), - config.PacketHandler.SourceAddressV6.Bytes, - "source_address_v6 should match original") - - require.NotNil(t, config.PacketHandler.SessionsTimeouts, - "sessions_timeouts should be preserved") - require.Equal( - t, - uint32(10), - config.PacketHandler.SessionsTimeouts.TcpSynAck, - "tcp_syn_ack timeout should match original", - ) - - require.NotNil(t, config.PacketHandler.Vs, - "virtual services should be preserved") - require.Equal(t, 1, len(config.PacketHandler.Vs), - "should have 1 virtual service") - - require.NotNil(t, config.PacketHandler.DecapAddresses, - "decap_addresses should be preserved") - require.Equal(t, 2, len(config.PacketHandler.DecapAddresses), - "should have 2 decap addresses") - }) - - // Test 2: Partial PacketHandler update - only virtual services - t.Run("PartialPacketHandlerUpdate_OnlyVirtualServices", func(t *testing.T) { - newVs := []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.40").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: false}, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.30.41"). - AsSlice(), - }, - Port: 8443, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.17.0.0").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - }, - } - - updateConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: newVs, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify VS changed - config := manager.Config() - require.NotNil(t, config.PacketHandler) - require.NotNil(t, config.PacketHandler.Vs) - require.Equal(t, 1, len(config.PacketHandler.Vs), - "should have 1 virtual service") - require.Equal(t, uint32(443), config.PacketHandler.Vs[0].Id.Port, - "VS port should be updated") - - // Verify other fields preserved (should be from previous update) - require.NotNil(t, config.PacketHandler.SourceAddressV4) - require.Equal(t, netip.MustParseAddr("192.168.1.1").AsSlice(), - config.PacketHandler.SourceAddressV4.Bytes, - "source_address_v4 should be preserved from previous update") - - require.NotNil(t, config.PacketHandler.SessionsTimeouts) - require.Equal( - t, - uint32(10), - config.PacketHandler.SessionsTimeouts.TcpSynAck, - "tcp_syn_ack timeout should be preserved", - ) - }) - - // Test 3: Partial State update - only session_table_capacity - t.Run("PartialStateUpdate_OnlyCapacity", func(t *testing.T) { - newCapacity := uint64(2000) - - updateConfig := &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - SessionTableCapacity: &newCapacity, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify capacity changed - config := manager.Config() - require.NotNil(t, config.State) - require.NotNil(t, config.State.SessionTableCapacity) - require.LessOrEqual(t, newCapacity, *config.State.SessionTableCapacity, - "session_table_capacity should be updated") - - // Verify other fields preserved - require.NotNil(t, config.State.SessionTableMaxLoadFactor) - require.Equal(t, float32(0.75), *config.State.SessionTableMaxLoadFactor, - "max_load_factor should be preserved") - - require.NotNil(t, config.State.Wlc) - require.NotNil(t, config.State.Wlc.Power) - require.Equal(t, uint64(10), *config.State.Wlc.Power, - "wlc.power should be preserved") - }) - - // Test 4: Partial WLC update - only power - t.Run("PartialWlcUpdate_OnlyPower", func(t *testing.T) { - newPower := uint64(20) - - updateConfig := &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - Wlc: &balancerpb.WlcConfig{ - Power: &newPower, - }, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify power changed - config := manager.Config() - require.NotNil(t, config.State) - require.NotNil(t, config.State.Wlc) - require.NotNil(t, config.State.Wlc.Power) - require.Equal(t, newPower, *config.State.Wlc.Power, - "wlc.power should be updated") - - // Verify max_weight preserved - require.NotNil(t, config.State.Wlc.MaxWeight) - require.Equal(t, uint32(1024), *config.State.Wlc.MaxWeight, - "wlc.max_weight should be preserved") - - // Verify other state fields preserved - require.NotNil(t, config.State.SessionTableCapacity) - require.LessOrEqual(t, uint64(2000), *config.State.SessionTableCapacity, - "session_table_capacity should be preserved from previous update") - }) - - // Test 5: Nested partial update - State with partial Wlc - t.Run("NestedPartialUpdate_StateWithPartialWlc", func(t *testing.T) { - newMaxWeight := uint32(2048) - newLoadFactor := float32(0.85) - - updateConfig := &balancerpb.BalancerConfig{ - State: &balancerpb.StateConfig{ - SessionTableMaxLoadFactor: &newLoadFactor, - Wlc: &balancerpb.WlcConfig{ - MaxWeight: &newMaxWeight, - }, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify updated fields - config := manager.Config() - require.NotNil(t, config.State) - require.NotNil(t, config.State.SessionTableMaxLoadFactor) - require.Equal(t, newLoadFactor, *config.State.SessionTableMaxLoadFactor, - "max_load_factor should be updated") - - require.NotNil(t, config.State.Wlc) - require.NotNil(t, config.State.Wlc.MaxWeight) - require.Equal(t, newMaxWeight, *config.State.Wlc.MaxWeight, - "wlc.max_weight should be updated") - - // Verify wlc.power preserved (recursive fallback) - require.NotNil(t, config.State.Wlc.Power) - require.Equal(t, uint64(20), *config.State.Wlc.Power, - "wlc.power should be preserved from previous update") - - // Verify other state fields preserved - require.NotNil(t, config.State.SessionTableCapacity) - require.LessOrEqual(t, uint64(2000), *config.State.SessionTableCapacity, - "session_table_capacity should be preserved") - }) - - // Test 6: Update with empty PacketHandler (all fields nil) - t.Run("EmptyPacketHandlerUpdate", func(t *testing.T) { - updateConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - // All fields nil - should fallback to current - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify all fields preserved - config := manager.Config() - require.NotNil(t, config.PacketHandler) - - require.NotNil(t, config.PacketHandler.SourceAddressV4) - require.NotNil(t, config.PacketHandler.SourceAddressV6) - require.NotNil(t, config.PacketHandler.SessionsTimeouts) - require.NotNil(t, config.PacketHandler.Vs) - require.NotNil(t, config.PacketHandler.DecapAddresses) - - // Verify values match previous state - require.Equal(t, 1, len(config.PacketHandler.Vs), - "should preserve VS from previous update") - }) - - // Test 7: Full replacement - all fields provided - t.Run("FullReplacement", func(t *testing.T) { - newCapacity := uint64(3000) - newMaxLoadFactor := float32(0.9) - newPower := uint64(15) - newMaxWeight := uint32(512) - - fullConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 5, - TcpSyn: 10, - TcpFin: 8, - Tcp: 50, - Udp: 6, - Default: 10, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.70"). - AsSlice(), - }, - Port: 53, - Proto: balancerpb.TransportProto_UDP, - }, - Flags: &balancerpb.VsFlags{FixMss: false}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.50.60.71"). - AsSlice(), - }, - Port: 5353, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.18.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 50, - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.100.100.100").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::100").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.200.200.200").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &newCapacity, - SessionTableMaxLoadFactor: &newMaxLoadFactor, - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: &newPower, - MaxWeight: &newMaxWeight, - }, - }, - } - - _, err := manager.Update(fullConfig, now) - require.NoError(t, err, "failed to update config") - - // Verify all fields updated - config := manager.Config() - - // Check PacketHandler - require.NotNil(t, config.PacketHandler) - require.Equal( - t, - uint32(5), - config.PacketHandler.SessionsTimeouts.TcpSynAck, - ) - require.Equal(t, netip.MustParseAddr("10.100.100.100").AsSlice(), - config.PacketHandler.SourceAddressV4.Bytes) - require.Equal(t, netip.MustParseAddr("2001:db8::100").AsSlice(), - config.PacketHandler.SourceAddressV6.Bytes) - require.Equal(t, 1, len(config.PacketHandler.Vs)) - require.Equal(t, uint32(53), config.PacketHandler.Vs[0].Id.Port) - require.Equal(t, 1, len(config.PacketHandler.DecapAddresses)) - - // Check State - require.NotNil(t, config.State) - require.LessOrEqual(t, newCapacity, *config.State.SessionTableCapacity) - require.Equal( - t, - newMaxLoadFactor, - *config.State.SessionTableMaxLoadFactor, - ) - require.NotNil(t, config.State.Wlc) - require.Equal(t, newPower, *config.State.Wlc.Power) - require.Equal(t, newMaxWeight, *config.State.Wlc.MaxWeight) - }) -} - -// TestManagerTagFieldInConfig tests that tag field is properly shown in manager config -func TestManagerTagFieldInConfig(t *testing.T) { - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 1 << 28, - DpMemory: 1 << 24, - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - { - ID: 0, - Name: deviceName, - }, - }, - }) - require.NoError(t, err, "failed to create mock") - require.NotNil(t, m, "mock is nil") - - // Create balancer agent - agent, err := ffi.NewBalancerAgent(m.SharedMemory(), 1<<25) - require.NoError(t, err, "failed to create balancer agent") - require.NotNil(t, agent, "balancer agent is nil") - - // Create logger - logger, err := zap.NewDevelopment() - require.NoError(t, err, "failed to create logger") - sugaredLogger := logger.Sugar() - - // Create config with specific tag values - capacity := uint64(1000) - maxLoadFactor := float32(0.75) - power := uint64(10) - maxWeight := uint32(1024) - - protoConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: true}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.1.1.1"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "11111"; return &s }(), // First tag - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.12.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "22222"; return &s }(), // Second tag - }, - }, - }, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: netip.MustParseAddr("10.13.11.215").AsSlice()}, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: &maxLoadFactor, - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: &power, - MaxWeight: &maxWeight, - }, - }, - } - - // Convert to FFI config and create manager - managerConfig, err := ProtoToManagerConfig(protoConfig) - require.NoError(t, err, "failed to convert config") - - managerHandle, err := agent.NewManager("test_tag_manager", managerConfig) - require.NoError(t, err, "failed to create balancer manager") - require.NotNil(t, managerHandle, "balancer manager handle is nil") - - manager := NewBalancerManager(managerHandle, sugaredLogger) - require.NotNil(t, manager, "manager should not be nil") - - now := m.CurrentTime() - - // Test 1: Verify initial config has correct tags - t.Run("InitialConfigTags", func(t *testing.T) { - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.NotNil( - t, - config.PacketHandler, - "packet handler should not be nil", - ) - require.Len( - t, - config.PacketHandler.Vs, - 1, - "should have 1 virtual service", - ) - require.Len( - t, - config.PacketHandler.Vs[0].AllowedSrcs, - 2, - "should have 2 allowed sources", - ) - - // Verify tags - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first tag should not be nil", - ) - assert.Equal( - t, - "11111", - *config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first tag should be 11111", - ) - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "second tag should not be nil", - ) - assert.Equal( - t, - "22222", - *config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "second tag should be 22222", - ) - }) - - // Test 2: Update config with different tags - t.Run("UpdateConfigTags", func(t *testing.T) { - updateConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Flags: &balancerpb.VsFlags{FixMss: true}, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.12.13.213"). - AsSlice(), - }, - Port: 8080, - }, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - Weight: 100, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.1.1.1"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "33333"; return &s }(), // Changed tag - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.12.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "0"; return &s }(), // Changed to zero - }, - }, - }, - }, - }, - } - - _, err := manager.Update(updateConfig, now) - require.NoError(t, err, "failed to update manager config") - - // Verify updated tags - config := manager.Config() - require.NotNil(t, config, "config should not be nil") - require.NotNil( - t, - config.PacketHandler, - "packet handler should not be nil", - ) - require.Len( - t, - config.PacketHandler.Vs, - 1, - "should have 1 virtual service", - ) - require.Len( - t, - config.PacketHandler.Vs[0].AllowedSrcs, - 2, - "should have 2 allowed sources", - ) - - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first tag should not be nil", - ) - assert.Equal( - t, - "33333", - *config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first tag should be updated to 33333", - ) - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "second tag should not be nil", - ) - assert.Equal( - t, - "0", - *config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "second tag should be updated to 0", - ) - }) -} diff --git a/modules/balancer/agent/go/metrics.go b/modules/balancer/agent/go/metrics.go deleted file mode 100644 index a169baa0e..000000000 --- a/modules/balancer/agent/go/metrics.go +++ /dev/null @@ -1,334 +0,0 @@ -package balancer - -import ( - "time" - - "github.com/yanet-platform/yanet2/common/commonpb" - "github.com/yanet-platform/yanet2/common/go/metrics" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" -) - -var commonCounters = []struct { - name string - getter func(*ffi.BalancerStats) uint64 -}{ - { - name: "incoming_bits", - getter: func(s *ffi.BalancerStats) uint64 { - return s.Common.IncomingBytes * 8 - }, - }, - { - name: "incoming_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.Common.IncomingPackets - }, - }, - { - name: "outgoing_bits", - getter: func(s *ffi.BalancerStats) uint64 { - return s.Common.OutgoingBytes * 8 - }, - }, - { - name: "outgoing_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.Common.OutgoingPackets - }, - }, - { - name: "l4_incoming_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.L4.IncomingPackets - }, - }, - { - name: "l4_outgoing_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.L4.OutgoingPackets - }, - }, - { - name: "l4_select_vs_failed", - getter: func(s *ffi.BalancerStats) uint64 { - return s.L4.SelectVsFailed - }, - }, - { - name: "icmp_ipv4_incoming_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv4.IncomingPackets - }, - }, - { - name: "icmp_ipv6_incoming_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv6.IncomingPackets - }, - }, - { - name: "icmp_ipv4_forwarded_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv4.ForwardedPackets - }, - }, - { - name: "icmp_ipv4_packet_clones_sent", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv4.PacketClonesSent - }, - }, - { - name: "icmp_ipv4_packet_clones_received", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv4.PacketClonesReceived - }, - }, - { - name: "icmp_ipv4_packet_clone_failures", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv4.PacketCloneFailures - }, - }, - { - name: "icmp_ipv6_forwarded_packets", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv6.ForwardedPackets - }, - }, - { - name: "icmp_ipv6_packet_clones_sent", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv6.PacketClonesSent - }, - }, - { - name: "icmp_ipv6_packet_clones_received", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv6.PacketClonesReceived - }, - }, - { - name: "icmp_ipv6_packet_clone_failures", - getter: func(s *ffi.BalancerStats) uint64 { - return s.IcmpIpv6.PacketCloneFailures - }, - }, -} - -var vsCounters = []struct { - name string - getter func(*ffi.VsStats) uint64 -}{ - { - name: "vs_incoming_bits", - getter: func(s *ffi.VsStats) uint64 { - return s.IncomingBytes * 8 - }, - }, - { - name: "vs_incoming_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.IncomingPackets - }, - }, - { - name: "vs_outgoing_bits", - getter: func(s *ffi.VsStats) uint64 { - return s.OutgoingBytes * 8 - }, - }, - { - name: "vs_outgoing_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.OutgoingPackets - }, - }, - { - name: "vs_created_sessions", - getter: func(s *ffi.VsStats) uint64 { - return s.CreatedSessions - }, - }, - { - name: "vs_packet_src_not_allowed", - getter: func(s *ffi.VsStats) uint64 { - return s.PacketSrcNotAllowed - }, - }, - { - name: "vs_no_reals", - getter: func(s *ffi.VsStats) uint64 { - return s.NoReals - }, - }, - { - name: "vs_ops_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.OpsPackets - }, - }, - { - name: "vs_session_table_overflow", - getter: func(s *ffi.VsStats) uint64 { - return s.SessionTableOverflow - }, - }, - { - name: "vs_echo_icmp_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.EchoIcmpPackets - }, - }, - { - name: "vs_error_icmp_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.ErrorIcmpPackets - }, - }, - { - name: "vs_real_is_disabled", - getter: func(s *ffi.VsStats) uint64 { - return s.RealIsDisabled - }, - }, - { - name: "vs_real_is_removed", - getter: func(s *ffi.VsStats) uint64 { - return s.RealIsRemoved - }, - }, - { - name: "vs_not_rescheduled_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.NotRescheduledPackets - }, - }, - { - name: "vs_broadcasted_icmp_packets", - getter: func(s *ffi.VsStats) uint64 { - return s.BroadcastedIcmpPackets - }, - }, -} - -var realCounters = []struct { - name string - getter func(*ffi.RealStats) uint64 -}{ - { - name: "real_incoming_bits", - getter: func(s *ffi.RealStats) uint64 { - return s.Bytes * 8 - }, - }, - { - name: "real_incoming_packets", - getter: func(s *ffi.RealStats) uint64 { - return s.Packets - }, - }, - { - name: "real_created_sessions", - getter: func(s *ffi.RealStats) uint64 { - return s.CreatedSessions - }, - }, - { - name: "real_icmp_error_packets", - getter: func(s *ffi.RealStats) uint64 { - return s.ErrorIcmpPackets - }, - }, - { - name: "real_ops_packets", - getter: func(s *ffi.RealStats) uint64 { - return s.OpsPackets - }, - }, - { - name: "packets_real_disabled", - getter: func(s *ffi.RealStats) uint64 { - return s.PacketsRealDisabled - }, - }, -} - -//////////////////////////////////////////////////////////////////////////////// - -type handlersMetrics struct { - callLatencies *metrics.MetricMap[*metrics.Histogram] -} - -func newHandlersMetrics() handlersMetrics { - return handlersMetrics{ - callLatencies: metrics.NewMetricMap[*metrics.Histogram](), - } -} - -func (m *handlersMetrics) collect() []*commonpb.Metric { - return commonpb.MetricRefsToProto(m.callLatencies.Metrics()) -} - -var defaultLatencyBoundsMS = []float64{ - 1, - 2, - 5, - 10, - 25, - 50, - 75, - 100, - 150, - 200, - 300, - 400, - 500, - 600, - 700, - 800, - 900, - 1000, - 1500, - 2000, - 3000, - 4000, - 5000, -} - -type handlerMetricTracker struct { - metricID metrics.MetricID - startTime time.Time - metrics *handlersMetrics - latencies []float64 -} - -func newHandlerMetricTracker( - handlerName string, - handlerMetrics *handlersMetrics, - latencies []float64, - labels metrics.Labels, -) *handlerMetricTracker { - if handlerMetrics == nil || latencies == nil { - return nil - } - id := metrics.MetricID{ - Name: handlerName, - Labels: labels, - } - return &handlerMetricTracker{ - metricID: id, - startTime: time.Now(), - metrics: handlerMetrics, - latencies: latencies, - } -} - -func (m *handlerMetricTracker) Fix() { - duration := time.Since(m.startTime) - - // update latencies - m.metrics.callLatencies.GetOrCreate(m.metricID, func() *metrics.Histogram { - return metrics.NewHistogram(m.latencies) - }).Observe(float64(duration.Milliseconds())) -} diff --git a/modules/balancer/agent/go/mod.go b/modules/balancer/agent/go/mod.go deleted file mode 100644 index 303da492c..000000000 --- a/modules/balancer/agent/go/mod.go +++ /dev/null @@ -1,60 +0,0 @@ -package balancer - -// BalancerModule provides the gRPC module interface for the balancer service, -// integrating with YANET's module system for service registration and lifecycle management. - -import ( - "fmt" - - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -type BalancerModule struct { - cfg *Config - service *BalancerService -} - -func NewBalancerModule( - cfg *Config, - log *zap.SugaredLogger, -) (*BalancerModule, error) { - log = log.With(zap.String("module", "balancerpb.BalancerService")) - - shm, err := yanet.AttachSharedMemory(cfg.MemoryPath.Unwrap()) - if err != nil { - return nil, fmt.Errorf("failed to attach shared memory: %w", err) - } - - svc, err := NewBalancerService(shm, cfg.MemoryRequirements.Unwrap(), log) - if err != nil { - return nil, fmt.Errorf("failed to create balancer service: %w", err) - } - - return &BalancerModule{ - cfg: cfg, - service: svc, - }, nil -} - -func (m *BalancerModule) Name() string { - return "balancer" -} - -func (m *BalancerModule) Endpoint() string { - return m.cfg.Endpoint.Unwrap() -} - -func (m *BalancerModule) ServicesNames() []string { - return []string{"balancerpb.BalancerService"} -} - -func (m *BalancerModule) RegisterService(server *grpc.Server) { - balancerpb.RegisterBalancerServiceServer(server, m.service) -} - -func (m *BalancerModule) Close() error { - return nil -} diff --git a/modules/balancer/agent/go/reuse_test.go b/modules/balancer/agent/go/reuse_test.go deleted file mode 100644 index ca7d0fce7..000000000 --- a/modules/balancer/agent/go/reuse_test.go +++ /dev/null @@ -1,2224 +0,0 @@ -package balancer - -import ( - "math/rand/v2" - "net/netip" - "strconv" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/durationpb" -) - -// TestACLAndFilterReuse is a comprehensive test that verifies the balancer's ACL and filter -// reuse optimization during configuration updates. It analyzes UpdateInfo returned by Update() -// to verify that: -// -// 1. IPv4 VS matcher is reused when the set of IPv4 virtual services remains unchanged -// 2. IPv6 VS matcher is reused when the set of IPv6 virtual services remains unchanged -// 3. IPv4/IPv6 VS matcher comparison is order-independent (different VS order = same matcher) -// 4. ACL filters are reused when allowed_srcs configuration remains the same -// 5. ACL comparison is order-independent (different ACL rule order = same ACL) -// 6. ACL comparison handles duplicates correctly -// 7. Partial changes are detected correctly (some VS changed, some unchanged) -// -// This test does NOT send packets - it only analyzes the UpdateInfo structure. -func TestACLAndFilterReuse(t *testing.T) { - // Create mock Yanet instance - m, err := mock.NewYanetMock(&mock.YanetMockConfig{ - AgentsMemory: 512 << 20, // 512 MB - DpMemory: 64 << 20, // 64 MB - Workers: 1, - Devices: []mock.YanetMockDeviceConfig{ - {ID: 0, Name: "eth0"}, - }, - }) - require.NoError(t, err) - defer m.Free() - - // Create logger for tests - log := zap.NewNop().Sugar() - - // Create balancer agent - agent, err := NewBalancerAgent(m.SharedMemory(), 256*datasize.MB, log) - require.NoError(t, err) - - // Helper to create a simple ACL (allow all) - createSimpleACL := func(isIPv6 bool) []*balancerpb.AllowedSources { - var addr, mask netip.Addr - if isIPv6 { - addr = netip.AddrFrom16([16]byte{}) - mask = netip.AddrFrom16([16]byte{}) - } else { - addr = netip.AddrFrom4([4]byte{0, 0, 0, 0}) - mask = netip.AddrFrom4([4]byte{0, 0, 0, 0}) - } - return []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{Bytes: addr.AsSlice()}, - Mask: &balancerpb.Addr{Bytes: mask.AsSlice()}, - }, - }, - }, - } - } - - // Helper to create a complex ACL with multiple rules - createComplexACL := func(variant int, isIPv6 bool) []*balancerpb.AllowedSources { - var acl []*balancerpb.AllowedSources - if isIPv6 { - // Rule 1: 2001:db8:1::/48 - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{{From: 1024, To: 65535}}, - }) - // Rule 2: 2001:db8:2::/48 with specific ports - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, 0, 2, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - }, - }) - } else { - // Rule 1: 10.0.0.0/8 - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{10, 0, 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 0, 0, 0}).AsSlice()}, - }}, - Ports: []*balancerpb.PortsRange{{From: 1024, To: 65535}}, - }) - // Rule 2: 192.168.0.0/16 with specific ports - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{192, 168, 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice()}, - }}, - Ports: []*balancerpb.PortsRange{{From: 80, To: 80}, {From: 443, To: 443}}, - }) - } - - // Add variant-specific rule - if variant > 0 { - if isIPv6 { - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, 0, byte(variant), 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }).AsSlice(), - }, - }}, - }) - } else { - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{172, byte(variant), 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice()}, - }}, - }) - } - } - return acl - } - - // Helper to create a large complex ACL with 15-20 rules and random duplicates - // This tests that ACL comparison handles duplicates correctly and works with many rules - createLargeComplexACL := func(variant int, isIPv6 bool, rng *rand.Rand) []*balancerpb.AllowedSources { - var acl []*balancerpb.AllowedSources - numRules := 15 + rng.IntN(6) // 15-20 rules - - for i := 0; i < numRules; i++ { - var rule *balancerpb.AllowedSources - if isIPv6 { - // Generate IPv6 rule with varying prefixes - addr := [16]byte{ - 0x20, - 0x01, - 0x0d, - 0xb8, - byte(variant), - byte(i), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - } - mask := [16]byte{ - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - } - rule = &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16(addr).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16(mask).AsSlice(), - }, - }}, - } - } else { - // Generate IPv4 rule with varying prefixes - addr := [4]byte{byte(10 + variant%240), byte(i), 0, 0} - mask := [4]byte{255, 255, 0, 0} - rule = &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4(addr).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4(mask).AsSlice()}, - }}, - } - } - - // Add port ranges to some rules - if i%3 == 0 { - rule.Ports = []*balancerpb.PortsRange{ - {From: uint32(1024 + i*100), To: uint32(2024 + i*100)}, - } - } else if i%3 == 1 { - rule.Ports = []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - {From: uint32(8000 + i), To: uint32(8100 + i)}, - } - } - - acl = append(acl, rule) - - // Randomly add duplicates (about 30% of rules will be duplicated) - if rng.Float32() < 0.3 { - acl = append(acl, rule) - } - } - - // Shuffle the ACL to ensure order independence is tested - rng.Shuffle(len(acl), func(i, j int) { - acl[i], acl[j] = acl[j], acl[i] - }) - - return acl - } - - // Helper to create a VS - createVS := func(ip netip.Addr, port uint16, proto balancerpb.TransportProto, acl []*balancerpb.AllowedSources) *balancerpb.VirtualService { - var realIP netip.Addr - var srcAddr, srcMask netip.Addr - if ip.Is4() { - realIP = netip.AddrFrom4([4]byte{192, 168, 1, 1}) - srcAddr = netip.AddrFrom4([4]byte{172, 16, 0, 1}) - srcMask = netip.AddrFrom4([4]byte{255, 255, 255, 255}) - } else { - realIP = netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) - srcAddr = netip.AddrFrom16([16]byte{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) - srcMask = netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - } - - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: proto, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: acl, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: realIP.AsSlice()}, - Port: 0, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{Bytes: srcAddr.AsSlice()}, - SrcMask: &balancerpb.Addr{Bytes: srcMask.AsSlice()}, - }, - }, - Flags: &balancerpb.VsFlags{}, - Peers: []*balancerpb.Addr{}, - } - } - - // Helper to create config - createConfig := func(vsList []*balancerpb.VirtualService) *balancerpb.BalancerConfig { - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: vsList, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(10000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - } - - // Helper to generate many virtual services for large-scale tests - // ipBase parameter allows using different IP ranges to avoid overlap between tests - generateManyVS := func(numIPv4, numIPv6 int, ipBase byte, aclGenerator func(idx int, isIPv6 bool) []*balancerpb.AllowedSources) []*balancerpb.VirtualService { - vsList := make([]*balancerpb.VirtualService, 0, numIPv4+numIPv6) - - // Generate IPv4 VS - for i := 0; i < numIPv4; i++ { - ip := netip.AddrFrom4( - [4]byte{ipBase, byte(i / 256), byte(i % 256), 1}, - ) - proto := balancerpb.TransportProto_TCP - if i%3 == 0 { - proto = balancerpb.TransportProto_UDP - } - port := uint16(80 + (i % 10)) - acl := aclGenerator(i, false) - vsList = append(vsList, createVS(ip, port, proto, acl)) - } - - // Generate IPv6 VS - for i := 0; i < numIPv6; i++ { - ip := netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, - ipBase, byte(i), 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, - }) - proto := balancerpb.TransportProto_TCP - if i%3 == 0 { - proto = balancerpb.TransportProto_UDP - } - port := uint16(80 + (i % 10)) - acl := aclGenerator(i, true) - vsList = append(vsList, createVS(ip, port, proto, acl)) - } - - return vsList - } - - // Helper to verify UpdateInfo - verifyUpdateInfo := func(t *testing.T, info *ffi.UpdateInfo, expectIPv4Reused, expectIPv6Reused bool, expectACLReusedCount int) { - t.Helper() - assert.Equal( - t, - expectIPv4Reused, - info.VsIpv4MatcherReused, - "IPv4 matcher reuse mismatch", - ) - assert.Equal( - t, - expectIPv6Reused, - info.VsIpv6MatcherReused, - "IPv6 matcher reuse mismatch", - ) - assert.Equal( - t, - expectACLReusedCount, - len(info.ACLReusedVs), - "ACL reused count mismatch", - ) - } - - // Test 1: Initial configuration - t.Run("InitialConfiguration", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - err := agent.NewBalancerManager("test", config) - require.NoError(t, err) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - require.NotNil(t, manager) - }) - - // Test 2: Identical configuration - everything should be reused - t.Run("IdenticalConfiguration", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Everything should be reused - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 3: VS order independence - IPv4 VS in different order - t.Run("IPv4_VSOrderIndependence", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - // Swap IPv4 VS order - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), - // Keep IPv6 VS order the same - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers should be reused (order doesn't matter), all ACLs reused - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 4: VS order independence - IPv6 VS in different order - t.Run("IPv6_VSOrderIndependence", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - // Keep IPv4 VS order from previous test - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), - // Swap IPv6 VS order - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers should be reused (order doesn't matter), all ACLs reused - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 5: VS order independence - both IPv4 and IPv6 in different order - t.Run("BothIPv4AndIPv6_VSOrderIndependence", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - // Different IPv4 order - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), - // Different IPv6 order - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers should be reused (order doesn't matter), all ACLs reused - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 6: Same VS identifiers, different ACL for some VS - t.Run("SameVS_DifferentACLForSome", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(2, false), - ), // Changed ACL - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // Same ACL - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), // Same ACL - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(2, true), - ), // Changed ACL - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers reused (VS identifiers unchanged), but only 2 ACLs reused - verifyUpdateInfo(t, updateInfo, true, true, 2) - }) - - // Test 7: ACL order independence - shuffled ACL should be considered the same - t.Run("ACLOrderIndependence", func(t *testing.T) { - // Create ACL with rules in different order - acl1 := createComplexACL(2, false) - acl2 := make([]*balancerpb.AllowedSources, len(acl1)) - // Reverse order - for i := range acl1 { - acl2[len(acl1)-1-i] = acl1[i] - } - - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.1"), - 80, - balancerpb.TransportProto_TCP, - acl2, - ), // Reversed order, should match - createVS( - netip.MustParseAddr("10.0.0.2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // Same as before - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(2, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (order doesn't matter) - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 8: Different IPv4 VS set, same IPv6 VS set - t.Run("DifferentIPv4_SameIPv6", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.3"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), // New IPv4 VS - createVS( - netip.MustParseAddr("10.0.0.4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // New IPv4 VS - createVS( - netip.MustParseAddr("2001:db8::1"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), // Same IPv6 VS - createVS( - netip.MustParseAddr("2001:db8::2"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(2, true), - ), // Same IPv6 VS - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // IPv4 matcher NOT reused (different VS set), IPv6 matcher reused, 2 IPv6 ACLs reused - verifyUpdateInfo(t, updateInfo, false, true, 2) - }) - - // Test 9: Same IPv4 VS set, different IPv6 VS set - t.Run("SameIPv4_DifferentIPv6", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.3"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(false), - ), // Same IPv4 VS - createVS( - netip.MustParseAddr("10.0.0.4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // Same IPv4 VS - createVS( - netip.MustParseAddr("2001:db8::3"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), // New IPv6 VS - createVS( - netip.MustParseAddr("2001:db8::4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), // New IPv6 VS - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // IPv4 matcher reused, IPv6 matcher NOT reused (different VS set), 2 IPv4 ACLs reused - verifyUpdateInfo(t, updateInfo, true, false, 2) - }) - - // Test 10: Protocol change for some VS - t.Run("ProtocolChange", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.3"), - 80, - balancerpb.TransportProto_UDP, - createSimpleACL(false), - ), // Changed protocol - createVS( - netip.MustParseAddr("10.0.0.4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // Same - createVS( - netip.MustParseAddr("2001:db8::3"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // IPv4 matcher NOT reused (protocol changed = different VS identifier), IPv6 reused - verifyUpdateInfo( - t, - updateInfo, - false, - true, - 3, - ) // 1 IPv4 VS with same ACL + 2 IPv6 VS - }) - - // Test 11: Port change for some VS - t.Run("PortChange", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.0.3"), - 8080, - balancerpb.TransportProto_UDP, - createSimpleACL(false), - ), // Changed port - createVS( - netip.MustParseAddr("10.0.0.4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, false), - ), // Same - createVS( - netip.MustParseAddr("2001:db8::3"), - 80, - balancerpb.TransportProto_TCP, - createSimpleACL(true), - ), - createVS( - netip.MustParseAddr("2001:db8::4"), - 80, - balancerpb.TransportProto_TCP, - createComplexACL(1, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // IPv4 matcher NOT reused (port changed = different VS identifier), IPv6 reused - verifyUpdateInfo(t, updateInfo, false, true, 3) // 1 IPv4 VS + 2 IPv6 VS - }) - - // Test 12: Completely different configuration - t.Run("CompletelyDifferent", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.1.1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, false), - ), - createVS( - netip.MustParseAddr("10.0.1.2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, false), - ), - createVS( - netip.MustParseAddr("2001:db8:1::1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, true), - ), - createVS( - netip.MustParseAddr("2001:db8:1::2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused - verifyUpdateInfo(t, updateInfo, false, false, 0) - }) - - // Test 13: Back to previous configuration - everything should be reused - t.Run("BackToPrevious", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.1.1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, false), - ), - createVS( - netip.MustParseAddr("10.0.1.2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, false), - ), - createVS( - netip.MustParseAddr("2001:db8:1::1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, true), - ), - createVS( - netip.MustParseAddr("2001:db8:1::2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Everything reused (same as previous) - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 14: ACL with duplicates - should be considered the same - t.Run("ACLWithDuplicates", func(t *testing.T) { - acl := createComplexACL(3, false) - aclWithDuplicates := make([]*balancerpb.AllowedSources, 0, len(acl)*2) - for _, rule := range acl { - aclWithDuplicates = append(aclWithDuplicates, rule) - aclWithDuplicates = append( - aclWithDuplicates, - rule, - ) // Duplicate each rule - } - - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.1.1"), - 443, - balancerpb.TransportProto_TCP, - aclWithDuplicates, - ), // ACL with duplicates - createVS( - netip.MustParseAddr("10.0.1.2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, false), - ), - createVS( - netip.MustParseAddr("2001:db8:1::1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, true), - ), - createVS( - netip.MustParseAddr("2001:db8:1::2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (duplicates don't matter) - verifyUpdateInfo(t, updateInfo, true, true, 4) - }) - - // Test 15: Mixed IPv4/IPv6 VS order with ACL changes - t.Run("MixedOrderWithACLChanges", func(t *testing.T) { - vsList := []*balancerpb.VirtualService{ - // Different order and one ACL change - createVS( - netip.MustParseAddr("10.0.1.2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(5, false), - ), // Changed ACL - createVS( - netip.MustParseAddr("10.0.1.1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, false), - ), // Same ACL (with duplicates from prev test) - createVS( - netip.MustParseAddr("2001:db8:1::2"), - 443, - balancerpb.TransportProto_UDP, - createComplexACL(4, true), - ), - createVS( - netip.MustParseAddr("2001:db8:1::1"), - 443, - balancerpb.TransportProto_TCP, - createComplexACL(3, true), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers reused (same VS set, different order), 3 ACLs reused (1 IPv4 changed) - verifyUpdateInfo(t, updateInfo, true, true, 3) - }) - - // Test 16: Large complex ACL with 15-20 rules and random duplicates - t.Run("LargeComplexACLWithDuplicates", func(t *testing.T) { - rng := rand.New( - rand.NewPCG(42, 0), - ) // Deterministic seed for reproducibility - - // Create initial configuration with large complex ACLs - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.2.1"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(1, false, rng), - ), - createVS( - netip.MustParseAddr("10.0.2.2"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(2, false, rng), - ), - createVS( - netip.MustParseAddr("2001:db8:2::1"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(1, true, rng), - ), - createVS( - netip.MustParseAddr("2001:db8:2::2"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(2, true, rng), - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused (completely new VS set) - verifyUpdateInfo(t, updateInfo, false, false, 0) - - // Now update with the same ACLs but regenerated with same seed (should be identical) - rng2 := rand.New(rand.NewPCG(42, 0)) // Same seed - vsList2 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.2.1"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(1, false, rng2), - ), - createVS( - netip.MustParseAddr("10.0.2.2"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(2, false, rng2), - ), - createVS( - netip.MustParseAddr("2001:db8:2::1"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(1, true, rng2), - ), - createVS( - netip.MustParseAddr("2001:db8:2::2"), - 80, - balancerpb.TransportProto_TCP, - createLargeComplexACL(2, true, rng2), - ), - } - config2 := createConfig(vsList2) - - updateInfo2, err := manager.Update(config2, m.CurrentTime()) - require.NoError(t, err) - - // Everything should be reused (same ACLs despite duplicates and shuffling) - verifyUpdateInfo(t, updateInfo2, true, true, 4) - }) - - // Test 17: Many virtual services (15 IPv4 + 15 IPv6 = 30 total) - t.Run("ManyVirtualServices", func(t *testing.T) { - rng := rand.New(rand.NewPCG(100, 0)) - - // Generate 15 IPv4 and 15 IPv6 VS with large complex ACLs - // Use ipBase=10 for this test group - vsList := generateManyVS( - 15, - 15, - 10, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - return createLargeComplexACL(idx, isIPv6, rng) - }, - ) - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused (completely new VS set) - verifyUpdateInfo(t, updateInfo, false, false, 0) - - // Update with identical configuration - rng2 := rand.New(rand.NewPCG(100, 0)) // Same seed - vsList2 := generateManyVS( - 15, - 15, - 10, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - return createLargeComplexACL(idx, isIPv6, rng2) - }, - ) - config2 := createConfig(vsList2) - - updateInfo2, err := manager.Update(config2, m.CurrentTime()) - require.NoError(t, err) - - // Everything should be reused (30 VS total) - verifyUpdateInfo(t, updateInfo2, true, true, 30) - }) - - // Test 18: Many VS with shuffled order - should still reuse - t.Run("ManyVS_ShuffledOrder", func(t *testing.T) { - rng := rand.New(rand.NewPCG(100, 0)) - - // Generate same VS as previous test (ipBase=10) - vsList := generateManyVS( - 15, - 15, - 10, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - return createLargeComplexACL(idx, isIPv6, rng) - }, - ) - - // Shuffle the VS list - shuffleRng := rand.New(rand.NewPCG(999, 0)) - shuffleRng.Shuffle(len(vsList), func(i, j int) { - vsList[i], vsList[j] = vsList[j], vsList[i] - }) - - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers should be reused (order doesn't matter), all 30 ACLs reused - verifyUpdateInfo(t, updateInfo, true, true, 30) - }) - - // Test 19: Many VS with some ACL changes - t.Run("ManyVS_SomeACLChanges", func(t *testing.T) { - rng := rand.New(rand.NewPCG(100, 0)) - - // Generate VS with same base but change ACL for first 5 IPv4 and first 5 IPv6 (ipBase=10) - vsList := generateManyVS( - 15, - 15, - 10, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - // Change ACL for indices 0-4 by using different variant - if idx < 5 { - return createLargeComplexACL( - idx+100, - isIPv6, - rng, - ) // Different variant - } - return createLargeComplexACL(idx, isIPv6, rng) - }, - ) - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Both matchers reused (same VS identifiers), but only 20 ACLs reused (10 changed) - verifyUpdateInfo(t, updateInfo, true, true, 20) - }) - - // Test 20: Large scale with 25 IPv4 + 25 IPv6 = 50 VS total - // Use different IP range (ipBase=20) to avoid overlap with previous tests - t.Run("LargeScale_50VS", func(t *testing.T) { - rng := rand.New(rand.NewPCG(200, 0)) - - // Generate 25 IPv4 and 25 IPv6 VS with ipBase=20 (different from previous tests) - vsList := generateManyVS( - 25, - 25, - 20, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - return createLargeComplexACL(idx, isIPv6, rng) - }, - ) - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused (completely new VS set with different IP range) - verifyUpdateInfo(t, updateInfo, false, false, 0) - - // Update with identical configuration - rng2 := rand.New(rand.NewPCG(200, 0)) // Same seed - vsList2 := generateManyVS( - 25, - 25, - 20, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - return createLargeComplexACL(idx, isIPv6, rng2) - }, - ) - config2 := createConfig(vsList2) - - updateInfo2, err := manager.Update(config2, m.CurrentTime()) - require.NoError(t, err) - - // Everything should be reused (50 VS total) - verifyUpdateInfo(t, updateInfo2, true, true, 50) - }) - - // Test 21: Large scale with shuffled ACL rules (order independence with many rules) - t.Run("LargeScale_ShuffledACLRules", func(t *testing.T) { - rng := rand.New(rand.NewPCG(200, 0)) - - // Generate VS with same ACLs but shuffle the rules within each ACL (ipBase=20) - vsList := generateManyVS( - 25, - 25, - 20, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - acl := createLargeComplexACL(idx, isIPv6, rng) - // Additional shuffle of the ACL rules - shuffleRng := rand.New(rand.NewPCG(uint64(idx+1000), 0)) - shuffleRng.Shuffle(len(acl), func(i, j int) { - acl[i], acl[j] = acl[j], acl[i] - }) - return acl - }, - ) - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (order doesn't matter even with many rules) - verifyUpdateInfo(t, updateInfo, true, true, 50) - }) - - // Test 22: Large scale with additional duplicates in ACLs - t.Run("LargeScale_ExtraDuplicates", func(t *testing.T) { - rng := rand.New(rand.NewPCG(200, 0)) - - // Generate VS with same ACLs but add extra duplicates (ipBase=20) - vsList := generateManyVS( - 25, - 25, - 20, - func(idx int, isIPv6 bool) []*balancerpb.AllowedSources { - acl := createLargeComplexACL(idx, isIPv6, rng) - // Add extra duplicates (duplicate first 5 rules again) - extraDuplicates := make( - []*balancerpb.AllowedSources, - 0, - len(acl)+5, - ) - extraDuplicates = append(extraDuplicates, acl...) - for i := 0; i < 5 && i < len(acl); i++ { - extraDuplicates = append(extraDuplicates, acl[i]) - } - // Shuffle - shuffleRng := rand.New(rand.NewPCG(uint64(idx+2000), 0)) - shuffleRng.Shuffle(len(extraDuplicates), func(i, j int) { - extraDuplicates[i], extraDuplicates[j] = extraDuplicates[j], extraDuplicates[i] - }) - return extraDuplicates - }, - ) - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (duplicates don't matter) - verifyUpdateInfo(t, updateInfo, true, true, 50) - }) - - // Test 23: ACL with different tags - should be considered equal (tags don't affect ACL comparison) - t.Run("ACLWithDifferentTags_ShouldBeEqual", func(t *testing.T) { - // Create initial config with specific tags - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.3.1"), - 80, - balancerpb.TransportProto_TCP, - []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, - }, - Tag: func() *string { s := "100"; return &s }(), // Tag = 100 - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - }, - Tag: func() *string { s := "200"; return &s }(), // Tag = 200 - }, - }, - ), - createVS( - netip.MustParseAddr("10.0.3.2"), - 80, - balancerpb.TransportProto_TCP, - []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "300"; return &s }(), // Tag = 300 - }, - }, - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused (new VS set) - verifyUpdateInfo(t, updateInfo, false, false, 0) - - // Update with same ACL rules but different tags - should be considered equal - vsList2 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.3.1"), - 80, - balancerpb.TransportProto_TCP, - []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, - }, - Tag: func() *string { s := "999"; return &s }(), // Different tag (was 100) - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - }, - Tag: func() *string { s := "888"; return &s }(), // Different tag (was 200) - }, - }, - ), - createVS( - netip.MustParseAddr("10.0.3.2"), - 80, - balancerpb.TransportProto_TCP, - []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "0"; return &s }(), // Different tag (was 300) - }, - }, - ), - } - config2 := createConfig(vsList2) - - updateInfo2, err := manager.Update(config2, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (tags don't affect ACL comparison) - verifyUpdateInfo(t, updateInfo2, true, true, 2) - }) - - // Test 24: Verify tag values are preserved in config after update - t.Run("TagValuesPreservedInConfig", func(t *testing.T) { - // Get current config and verify tag values - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - config := manager.Config() - require.NotNil(t, config) - require.NotNil(t, config.PacketHandler) - require.Len(t, config.PacketHandler.Vs, 2) - - // Verify first VS tags - require.Len(t, config.PacketHandler.Vs[0].AllowedSrcs, 2) - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first VS first tag should not be nil", - ) - assert.Equal( - t, - "999", - *config.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first VS first tag should be 999", - ) - require.NotNil( - t, - config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "first VS second tag should not be nil", - ) - assert.Equal( - t, - "888", - *config.PacketHandler.Vs[0].AllowedSrcs[1].Tag, - "first VS second tag should be 888", - ) - - // Verify second VS tag - require.Len(t, config.PacketHandler.Vs[1].AllowedSrcs, 1) - require.NotNil( - t, - config.PacketHandler.Vs[1].AllowedSrcs[0].Tag, - "second VS tag should not be nil", - ) - assert.Equal( - t, - "0", - *config.PacketHandler.Vs[1].AllowedSrcs[0].Tag, - "second VS tag should be 0", - ) - }) - - // Test 25: ACL reuse with many nets and port ranges in different order and with different tags - t.Run("ACLReuseWithManyNetsAndPortRanges", func(t *testing.T) { - // Helper to create AllowedSources with many nets and port ranges - createManyNetsACL := func(variant int, isIPv6 bool, tag uint32, rng *rand.Rand) []*balancerpb.AllowedSources { - numNets := 10 + rng.IntN(6) // 10-15 nets - numPorts := 3 + rng.IntN(3) // 3-5 port ranges - - nets := make([]*balancerpb.Net, numNets) - for i := range numNets { - if isIPv6 { - // Generate IPv6 networks - addr := [16]byte{ - 0x20, 0x01, 0x0d, 0xb8, - byte(variant), byte(i), 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - } - mask := [16]byte{ - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - } - nets[i] = &balancerpb.Net{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16(addr).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16(mask).AsSlice(), - }, - } - } else { - // Generate IPv4 networks - addr := [4]byte{byte(10 + variant%240), byte(i), 0, 0} - mask := [4]byte{255, 255, 0, 0} - nets[i] = &balancerpb.Net{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4(addr).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4(mask).AsSlice()}, - } - } - } - - ports := make([]*balancerpb.PortsRange, numPorts) - for i := range numPorts { - switch i { - case 0: - ports[i] = &balancerpb.PortsRange{From: 80, To: 80} - case 1: - ports[i] = &balancerpb.PortsRange{From: 443, To: 443} - case 2: - ports[i] = &balancerpb.PortsRange{From: 1024, To: 2048} - case 3: - ports[i] = &balancerpb.PortsRange{From: 8000, To: 9000} - case 4: - ports[i] = &balancerpb.PortsRange{From: 3000, To: 3999} - } - } - - return []*balancerpb.AllowedSources{ - { - Nets: nets, - Ports: ports, - Tag: func() *string { s := strconv.FormatUint(uint64(tag), 10); return &s }(), - }, - } - } - - // Helper to shuffle nets in AllowedSources - shuffleNets := func(acl []*balancerpb.AllowedSources, rng *rand.Rand) []*balancerpb.AllowedSources { - result := make([]*balancerpb.AllowedSources, len(acl)) - for i, rule := range acl { - newNets := make([]*balancerpb.Net, len(rule.Nets)) - copy(newNets, rule.Nets) - rng.Shuffle(len(newNets), func(i, j int) { - newNets[i], newNets[j] = newNets[j], newNets[i] - }) - result[i] = &balancerpb.AllowedSources{ - Nets: newNets, - Ports: rule.Ports, - Tag: rule.Tag, - } - } - return result - } - - // Helper to shuffle port ranges in AllowedSources - shufflePorts := func(acl []*balancerpb.AllowedSources, rng *rand.Rand) []*balancerpb.AllowedSources { - result := make([]*balancerpb.AllowedSources, len(acl)) - for i, rule := range acl { - newPorts := make([]*balancerpb.PortsRange, len(rule.Ports)) - copy(newPorts, rule.Ports) - rng.Shuffle(len(newPorts), func(i, j int) { - newPorts[i], newPorts[j] = newPorts[j], newPorts[i] - }) - result[i] = &balancerpb.AllowedSources{ - Nets: rule.Nets, - Ports: newPorts, - Tag: rule.Tag, - } - } - return result - } - - // Helper to change tags in AllowedSources - changeTags := func(acl []*balancerpb.AllowedSources, newTag uint32) []*balancerpb.AllowedSources { - result := make([]*balancerpb.AllowedSources, len(acl)) - for i, rule := range acl { - result[i] = &balancerpb.AllowedSources{ - Nets: rule.Nets, - Ports: rule.Ports, - Tag: func() *string { s := strconv.FormatUint(uint64(newTag), 10); return &s }(), - } - } - return result - } - - rng := rand.New(rand.NewPCG(300, 0)) // Deterministic seed - - // Scenario 1: Initial configuration with many nets and port ranges - acl1IPv4 := createManyNetsACL(1, false, 100, rng) - acl2IPv4 := createManyNetsACL(2, false, 200, rng) - acl1IPv6 := createManyNetsACL(1, true, 100, rng) - acl2IPv6 := createManyNetsACL(2, true, 200, rng) - - vsList := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6, - ), - } - config := createConfig(vsList) - - manager, err := agent.BalancerManager("test") - require.NoError(t, err) - - updateInfo, err := manager.Update(config, m.CurrentTime()) - require.NoError(t, err) - - // Nothing reused (new VS set) - verifyUpdateInfo(t, updateInfo, false, false, 0) - - // Scenario 2: Same ACL with shuffled net order - shuffleRng := rand.New(rand.NewPCG(301, 0)) - acl1IPv4Shuffled := shuffleNets(acl1IPv4, shuffleRng) - acl2IPv4Shuffled := shuffleNets(acl2IPv4, shuffleRng) - acl1IPv6Shuffled := shuffleNets(acl1IPv6, shuffleRng) - acl2IPv6Shuffled := shuffleNets(acl2IPv6, shuffleRng) - - vsList2 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4Shuffled, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Shuffled, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Shuffled, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Shuffled, - ), - } - config2 := createConfig(vsList2) - - updateInfo2, err := manager.Update(config2, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (net order doesn't matter) - verifyUpdateInfo(t, updateInfo2, true, true, 4) - - // Scenario 3: Same ACL with shuffled port range order - shuffleRng2 := rand.New(rand.NewPCG(302, 0)) - acl1IPv4PortShuffled := shufflePorts(acl1IPv4, shuffleRng2) - acl2IPv4PortShuffled := shufflePorts(acl2IPv4, shuffleRng2) - acl1IPv6PortShuffled := shufflePorts(acl1IPv6, shuffleRng2) - acl2IPv6PortShuffled := shufflePorts(acl2IPv6, shuffleRng2) - - vsList3 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4PortShuffled, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4PortShuffled, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6PortShuffled, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6PortShuffled, - ), - } - config3 := createConfig(vsList3) - - updateInfo3, err := manager.Update(config3, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (port range order doesn't matter) - verifyUpdateInfo(t, updateInfo3, true, true, 4) - - // Scenario 4: Same ACL with different tags - acl1IPv4NewTag := changeTags(acl1IPv4, 999) - acl2IPv4NewTag := changeTags(acl2IPv4, 888) - acl1IPv6NewTag := changeTags(acl1IPv6, 777) - acl2IPv6NewTag := changeTags(acl2IPv6, 666) - - vsList4 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4NewTag, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4NewTag, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6NewTag, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6NewTag, - ), - } - config4 := createConfig(vsList4) - - updateInfo4, err := manager.Update(config4, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (tags don't affect ACL comparison) - verifyUpdateInfo(t, updateInfo4, true, true, 4) - - // Scenario 5: Combined - shuffled nets, shuffled ports, and different tags - shuffleRng3 := rand.New(rand.NewPCG(303, 0)) - shuffleRng4 := rand.New(rand.NewPCG(304, 0)) - - acl1IPv4Combined := shuffleNets(acl1IPv4, shuffleRng3) - acl1IPv4Combined = shufflePorts(acl1IPv4Combined, shuffleRng4) - acl1IPv4Combined = changeTags(acl1IPv4Combined, 111) - - acl2IPv4Combined := shuffleNets(acl2IPv4, shuffleRng3) - acl2IPv4Combined = shufflePorts(acl2IPv4Combined, shuffleRng4) - acl2IPv4Combined = changeTags(acl2IPv4Combined, 222) - - acl1IPv6Combined := shuffleNets(acl1IPv6, shuffleRng3) - acl1IPv6Combined = shufflePorts(acl1IPv6Combined, shuffleRng4) - acl1IPv6Combined = changeTags(acl1IPv6Combined, 333) - - acl2IPv6Combined := shuffleNets(acl2IPv6, shuffleRng3) - acl2IPv6Combined = shufflePorts(acl2IPv6Combined, shuffleRng4) - acl2IPv6Combined = changeTags(acl2IPv6Combined, 444) - - vsList5 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4Combined, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config5 := createConfig(vsList5) - - updateInfo5, err := manager.Update(config5, m.CurrentTime()) - require.NoError(t, err) - - // All ACLs should be reused (order and tags don't matter) - verifyUpdateInfo(t, updateInfo5, true, true, 4) - - // Verify that tags are preserved in the final config - finalConfig := manager.Config() - require.NotNil(t, finalConfig) - require.NotNil(t, finalConfig.PacketHandler) - require.Len(t, finalConfig.PacketHandler.Vs, 4) - - // Check that the new tags are stored correctly - require.NotNil( - t, - finalConfig.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first IPv4 VS tag should not be nil", - ) - assert.Equal( - t, - "111", - *finalConfig.PacketHandler.Vs[0].AllowedSrcs[0].Tag, - "first IPv4 VS tag should be 111", - ) - require.NotNil( - t, - finalConfig.PacketHandler.Vs[1].AllowedSrcs[0].Tag, - "second IPv4 VS tag should not be nil", - ) - assert.Equal( - t, - "222", - *finalConfig.PacketHandler.Vs[1].AllowedSrcs[0].Tag, - "second IPv4 VS tag should be 222", - ) - require.NotNil( - t, - finalConfig.PacketHandler.Vs[2].AllowedSrcs[0].Tag, - "first IPv6 VS tag should not be nil", - ) - assert.Equal( - t, - "333", - *finalConfig.PacketHandler.Vs[2].AllowedSrcs[0].Tag, - "first IPv6 VS tag should be 333", - ) - require.NotNil( - t, - finalConfig.PacketHandler.Vs[3].AllowedSrcs[0].Tag, - "second IPv6 VS tag should not be nil", - ) - assert.Equal( - t, - "444", - *finalConfig.PacketHandler.Vs[3].AllowedSrcs[0].Tag, - "second IPv6 VS tag should be 444", - ) - - // Scenario 6: ALMOST matching - one net is different (should NOT reuse) - rng6 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4AlmostMatch := createManyNetsACL(1, false, 111, rng6) - // Modify one net in the middle - acl1IPv4AlmostMatch[0].Nets[5] = &balancerpb.Net{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{99, 99, 0, 0}).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice(), - }, - } - - vsList6 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4AlmostMatch, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config6 := createConfig(vsList6) - - updateInfo6, err := manager.Update(config6, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has different net) - verifyUpdateInfo(t, updateInfo6, true, true, 3) - - // Scenario 7: ALMOST matching - one port range is different (should NOT reuse) - rng7 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4AlmostMatchPort := createManyNetsACL(1, false, 111, rng7) - // Modify one port range - acl1IPv4AlmostMatchPort[0].Ports[1] = &balancerpb.PortsRange{ - From: 8443, - To: 8443, - } // Changed from 443 - - vsList7 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4AlmostMatchPort, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config7 := createConfig(vsList7) - - updateInfo7, err := manager.Update(config7, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has different port range) - verifyUpdateInfo(t, updateInfo7, true, true, 3) - - // Scenario 8: ALMOST matching - one net is missing (should NOT reuse) - rng8 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4MissingNet := createManyNetsACL(1, false, 111, rng8) - // Remove one net from the middle - acl1IPv4MissingNet[0].Nets = append( - acl1IPv4MissingNet[0].Nets[:3], - acl1IPv4MissingNet[0].Nets[4:]...) - - vsList8 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4MissingNet, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config8 := createConfig(vsList8) - - updateInfo8, err := manager.Update(config8, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has missing net) - verifyUpdateInfo(t, updateInfo8, true, true, 3) - - // Scenario 9: ALMOST matching - one port range is missing (should NOT reuse) - rng9 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4MissingPort := createManyNetsACL(1, false, 111, rng9) - // Remove one port range - acl1IPv4MissingPort[0].Ports = acl1IPv4MissingPort[0].Ports[:len(acl1IPv4MissingPort[0].Ports)-1] - - vsList9 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4MissingPort, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config9 := createConfig(vsList9) - - updateInfo9, err := manager.Update(config9, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has missing port range) - verifyUpdateInfo(t, updateInfo9, true, true, 3) - - // Scenario 10: ALMOST matching - one extra net added (should NOT reuse) - rng10 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4ExtraNet := createManyNetsACL(1, false, 111, rng10) - // Add one extra net - acl1IPv4ExtraNet[0].Nets = append( - acl1IPv4ExtraNet[0].Nets, - &balancerpb.Net{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{88, 88, 0, 0}).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice(), - }, - }, - ) - - vsList10 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4ExtraNet, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config10 := createConfig(vsList10) - - updateInfo10, err := manager.Update(config10, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has extra net) - verifyUpdateInfo(t, updateInfo10, true, true, 3) - - // Scenario 11: ALMOST matching - one extra port range added (should NOT reuse) - rng11 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4ExtraPort := createManyNetsACL(1, false, 111, rng11) - // Add one extra port range - acl1IPv4ExtraPort[0].Ports = append( - acl1IPv4ExtraPort[0].Ports, - &balancerpb.PortsRange{From: 9999, To: 9999}, - ) - - vsList11 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4ExtraPort, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config11 := createConfig(vsList11) - - updateInfo11, err := manager.Update(config11, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has extra port range) - verifyUpdateInfo(t, updateInfo11, true, true, 3) - - // Scenario 12: ALMOST matching - net mask is different (should NOT reuse) - rng12 := rand.New(rand.NewPCG(300, 0)) - acl1IPv4DifferentMask := createManyNetsACL(1, false, 111, rng12) - // Change mask of one net - acl1IPv4DifferentMask[0].Nets[2].Mask = &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 255, 0}).AsSlice(), - } // Changed from /16 to /24 - - vsList12 := []*balancerpb.VirtualService{ - createVS( - netip.MustParseAddr("10.0.4.1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv4DifferentMask, - ), - createVS( - netip.MustParseAddr("10.0.4.2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv4Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::1"), - 80, - balancerpb.TransportProto_TCP, - acl1IPv6Combined, - ), - createVS( - netip.MustParseAddr("2001:db8:4::2"), - 80, - balancerpb.TransportProto_TCP, - acl2IPv6Combined, - ), - } - config12 := createConfig(vsList12) - - updateInfo12, err := manager.Update(config12, m.CurrentTime()) - require.NoError(t, err) - - // Only 3 ACLs should be reused (first IPv4 VS has different mask) - verifyUpdateInfo(t, updateInfo12, true, true, 3) - }) -} diff --git a/modules/balancer/agent/go/service.go b/modules/balancer/agent/go/service.go deleted file mode 100644 index de70cba5a..000000000 --- a/modules/balancer/agent/go/service.go +++ /dev/null @@ -1,465 +0,0 @@ -package balancer - -// BalancerService implements the gRPC service interface for balancer management, -// providing RPC methods for configuration updates, real server management, statistics -// retrieval, and session inspection with automatic manager selection support. - -import ( - "context" - "fmt" - "time" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/c2h5oh/datasize" - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "go.uber.org/zap" -) - -//////////////////////////////////////////////////////////////////////////////// - -// BalancerService is agRPC service for controlling balancer -type BalancerService struct { - balancerpb.UnimplementedBalancerServiceServer - - agent *BalancerAgent - - log *zap.SugaredLogger -} - -//////////////////////////////////////////////////////////////////////////////// - -func NewBalancerService( - shm *yanet.SharedMemory, - memory datasize.ByteSize, - log *zap.SugaredLogger, -) (*BalancerService, error) { - log.Info("initializing balancer service") - - agent, err := NewBalancerAgent(shm, memory, log) - if err != nil { - log.Errorw("failed to create balancer agent", "error", err) - return nil, err - } - - service := &BalancerService{ - agent: agent, - log: log, - } - - return service, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// getManagerWithAutoSelection retrieves a balancer manager by name. -// If name is nil or empty, attempts to auto-select when exactly one manager exists. -// Returns the manager, the actual name used, and any error. -func (m *BalancerService) getManagerWithAutoSelection( - name *string, -) (*BalancerManager, string, error) { - // If name is provided and not empty, use it directly - if name != nil && *name != "" { - manager, err := m.agent.BalancerManager(*name) - if err != nil { - return nil, "", err - } - return manager, *name, nil - } - - // Name not provided - attempt auto-selection - managers := m.agent.Managers() - - if len(managers) == 0 { - return nil, "", status.Error( - codes.NotFound, - "no balancer managers found", - ) - } - - if len(managers) > 1 { - return nil, "", status.Error( - codes.InvalidArgument, - fmt.Sprintf( - "multiple balancer managers found (%d), please specify name explicitly", - len(managers), - ), - ) - } - - // Exactly one manager - auto-select it - selectedName := managers[0] - m.log.Infow("auto-selected balancer manager", "name", selectedName) - - manager, err := m.agent.BalancerManager(selectedName) - if err != nil { - return nil, "", err - } - - return manager, selectedName, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// UpdateConfig updates or enables balancer config -func (m *BalancerService) UpdateConfig( - ctx context.Context, - req *balancerpb.UpdateConfigRequest, -) (*balancerpb.UpdateConfigResponse, error) { - name := req.GetName() - if name == "" { - return nil, status.Error( - codes.InvalidArgument, - "module config name is required", - ) - } - - manager, _ := m.agent.BalancerManager(name) - if manager != nil { - m.log.Infow("updating balancer config", "name", name) - updateInfo, err := manager.Update(req.Config, time.Now()) - if err != nil { - m.log.Errorw( - "failed to update balancer", - "name", - name, - "error", - err, - ) - return nil, fmt.Errorf("failed to update balancer: %v", err) - } - m.log.Infow("balancer config updated", "name", name) - return &balancerpb.UpdateConfigResponse{ - Name: req.Name, - UpdateInfo: ConvertUpdateInfoToProto( - updateInfo, - false, - ), // created=false for updates - }, nil - } else { - m.log.Infow("creating new balancer", "name", name) - if err := m.agent.NewBalancerManager(name, req.Config); err != nil { - m.log.Errorw("failed to create balancer", "name", name, "error", err) - return nil, fmt.Errorf("failed to create balancer: %v", err) - } - m.log.Infow("balancer created", "name", name) - return &balancerpb.UpdateConfigResponse{ - Name: req.Name, - // Return update info with created=true for new balancer - UpdateInfo: ConvertUpdateInfoToProto(&ffi.UpdateInfo{ - VsIpv4MatcherReused: false, - VsIpv6MatcherReused: false, - ACLReusedVs: []ffi.VsIdentifier{}, - }, true), // created=true for new balancer - }, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// UpdateReals updates reals with optional buffering -func (m *BalancerService) UpdateReals( - ctx context.Context, - req *balancerpb.UpdateRealsRequest, -) (*balancerpb.UpdateRealsResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - count, err := manager.UpdateReals(req.Updates, req.Buffer) - if err != nil { - m.log.Errorw("failed to update reals", "name", name, "error", err) - msg := fmt.Sprintf("failed to make reals update: %v", err) - return nil, status.Error(codes.Internal, msg) - } - - if req.Buffer { - m.log.Debugw("real updates buffered", "name", name, "count", count) - } else { - m.log.Infow("real updates applied", "name", name, "count", count) - } - - return &balancerpb.UpdateRealsResponse{ - Name: name, - UpdatesApplied: uint32(count), - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// FlushRealUpdates flushes buffered reals updates -func (m *BalancerService) FlushRealUpdates( - ctx context.Context, - req *balancerpb.FlushRealUpdatesRequest, -) (*balancerpb.FlushRealUpdatesResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - count, err := manager.FlushRealUpdates() - if err != nil { - m.log.Errorw("failed to flush updates", "name", name, "error", err) - msg := fmt.Sprintf("failed to flush updates: %v", err) - return nil, status.Error(codes.Internal, msg) - } - - m.log.Infow("real updates flushed", "name", name, "count", count) - - return &balancerpb.FlushRealUpdatesResponse{ - UpdatesFlushed: uint32(count), - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ShowConfig shows balancer config -func (m *BalancerService) ShowConfig( - ctx context.Context, - req *balancerpb.ShowConfigRequest, -) (*balancerpb.ShowConfigResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - config := manager.Config() - bufferedUpdates := manager.BufferedUpdates() - - return &balancerpb.ShowConfigResponse{ - Name: name, - Config: config, - BufferedRealUpdates: bufferedUpdates, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ListConfigs lists balancer configs -func (m *BalancerService) ListConfigs( - ctx context.Context, - req *balancerpb.ListConfigsRequest, -) (*balancerpb.ListConfigsResponse, error) { - managers := m.agent.Managers() - m.log.Debugw("listing managers", "count", len(managers)) - return &balancerpb.ListConfigsResponse{ - Configs: managers, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ShowInfo returns info of the balancer state -func (m *BalancerService) ShowInfo( - ctx context.Context, - req *balancerpb.ShowInfoRequest, -) (*balancerpb.ShowInfoResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - info, err := manager.Info(time.Now()) - if err != nil { - msg := fmt.Sprintf("failed to get info: %v", err) - return nil, status.Error(codes.Internal, msg) - } - - return &balancerpb.ShowInfoResponse{ - Name: name, - Info: info, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -/* -ShowStats returns stats for balancer dataplane positions. - -Behavior: -- If req.name is specified: only positions belonging to that balancer instance are considered. -- If req.name is not specified: positions for all balancer instances are considered. -- PacketHandlerRef fields are filters (strict equality on specified fields). - -The enumeration pattern matches [`BalancerAgent.Metrics()`](modules/balancer/agent/go/agent.go:146): -we enumerate all balancer positions from DPConfig and then pick the manager per position. -*/ -func (m *BalancerService) ShowStats( - ctx context.Context, - req *balancerpb.ShowStatsRequest, -) (*balancerpb.ShowStatsResponse, error) { - entries, err := m.agent.StatsEntries(req.Name, req.Ref) - if err != nil { - msg := fmt.Sprintf("failed to get stats: %v", err) - return nil, status.Error(codes.Internal, msg) - } - - return &balancerpb.ShowStatsResponse{ - Entries: entries, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ShowSessions returns info about active balancer sessions -func (m *BalancerService) ShowSessions( - ctx context.Context, - req *balancerpb.ShowSessionsRequest, -) (*balancerpb.ShowSessionsResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - sessions, err := manager.Sessions(time.Now()) - if err != nil { - msg := fmt.Sprintf("failed to get sessions: %v", err) - return nil, status.Error(codes.Internal, msg) - } - - return &balancerpb.ShowSessionsResponse{ - Name: name, - Sessions: sessions, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ShowGraph returns the balancer topology graph -func (m *BalancerService) ShowGraph( - ctx context.Context, - req *balancerpb.ShowGraphRequest, -) (*balancerpb.ShowGraphResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - graph := manager.Graph() - - return &balancerpb.ShowGraphResponse{ - Name: name, - Graph: graph, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// ShowInspect returns memory inspection information for the balancer agent -func (m *BalancerService) ShowInspect( - ctx context.Context, - req *balancerpb.ShowInspectRequest, -) (*balancerpb.ShowInspectResponse, error) { - // Get inspect data from agent - inspect := m.agent.Inspect() - - return &balancerpb.ShowInspectResponse{ - Inspect: inspect, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -func (m *BalancerService) GetMetrics( - ctx context.Context, - req *balancerpb.GetMetricsRequest, -) (*balancerpb.GetMetricsResponse, error) { - metrics, err := m.agent.Metrics() - if err != nil { - return nil, err - } else { - return &balancerpb.GetMetricsResponse{ - Metrics: metrics, - }, nil - } -} - -// UpdateVS updates specific virtual services in the balancer configuration. -func (m *BalancerService) UpdateVS( - ctx context.Context, - req *balancerpb.UpdateVSRequest, -) (*balancerpb.UpdateVSResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - m.log.Infow( - "updating virtual services", - "name", - name, - "vs_count", - len(req.Vs), - ) - - updateInfo, err := manager.UpdateVS(req.Vs, time.Now()) - if err != nil { - m.log.Errorw( - "failed to update virtual services", - "name", name, - "error", err, - ) - return nil, fmt.Errorf("failed to update virtual services: %v", err) - } - - m.log.Infow( - "virtual services updated", - "name", - name, - "vs_count", - len(req.Vs), - ) - - return &balancerpb.UpdateVSResponse{ - Name: name, - Info: ConvertUpdateInfoToProto(updateInfo, false), - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// DeleteVS deletes specific virtual services from the balancer configuration. -func (m *BalancerService) DeleteVS( - ctx context.Context, - req *balancerpb.DeleteVSRequest, -) (*balancerpb.DeleteVSResponse, error) { - manager, name, err := m.getManagerWithAutoSelection(req.Name) - if err != nil { - return nil, err - } - - m.log.Infow( - "deleting virtual services", - "name", - name, - "vs_count", - len(req.Vs), - ) - - updateInfo, err := manager.DeleteVS(req.Vs, time.Now()) - if err != nil { - m.log.Errorw( - "failed to delete virtual services", - "name", name, - "error", err, - ) - return nil, fmt.Errorf("failed to delete virtual services: %v", err) - } - - m.log.Infow( - "virtual services deleted", - "name", - name, - "vs_count", - len(req.Vs), - ) - - return &balancerpb.DeleteVSResponse{ - Name: name, - Info: ConvertUpdateInfoToProto(updateInfo, false), - }, nil -} diff --git a/modules/balancer/agent/go/validate_test.go b/modules/balancer/agent/go/validate_test.go deleted file mode 100644 index d89088b84..000000000 --- a/modules/balancer/agent/go/validate_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" -) - -// TestValidation_InvalidPortRange tests that port ranges with from > to are rejected -func TestValidation_InvalidPortRange(t *testing.T) { - config := &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.100").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.1"). - AsSlice(), - }, - Port: 8080, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - }, - Ports: []*balancerpb.PortsRange{ - { - From: 8080, // Invalid: from > to - To: 80, - }, - }, - }, - }, - }, - }, - } - - _, err := ProtoToHandlerConfig(config) - require.Error(t, err, "Expected error for invalid port range") - assert.Contains( - t, - err.Error(), - "invalid range", - "Error should mention invalid range", - ) - assert.Contains( - t, - err.Error(), - "from=8080", - "Error should mention from value", - ) - assert.Contains(t, err.Error(), "to=80", "Error should mention to value") -} - -// TestValidation_TransportProtoTcpAndUdp tests that only TCP and UDP protocols are handled -func TestValidation_TransportProtoTcpAndUdp(t *testing.T) { - testCases := []struct { - name string - proto balancerpb.TransportProto - }{ - { - name: "TCP protocol", - proto: balancerpb.TransportProto_TCP, - }, - { - name: "UDP protocol", - proto: balancerpb.TransportProto_UDP, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - config := &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.100"). - AsSlice(), - }, - Port: 80, - Proto: tc.proto, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.1"). - AsSlice(), - }, - Port: 8080, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - }, - }, - }, - }, - }, - } - - result, err := ProtoToHandlerConfig(config) - require.NoError( - t, - err, - "Valid TCP/UDP protocol should not produce error", - ) - require.NotNil(t, result, "Result should not be nil") - require.Len( - t, - result.VirtualServices, - 1, - "Should have one virtual service", - ) - }) - } -} - -// TestValidation_AllowedSrcIPVersionMismatch tests that allowed_src networks must match VS IP version -func TestValidation_AllowedSrcIPVersionMismatch(t *testing.T) { - testCases := []struct { - name string - vsAddr string - allowedSrcAddr string - allowedSrcMask string - expectError bool - errorContains string - }{ - { - name: "IPv4 VS with IPv6 allowed_src", - vsAddr: "10.0.0.100", - allowedSrcAddr: "2001:db8::", - allowedSrcMask: "ffff:ffff::", - expectError: true, - errorContains: "IP version", - }, - { - name: "IPv6 VS with IPv4 allowed_src", - vsAddr: "2001:db8::100", - allowedSrcAddr: "192.168.0.0", - allowedSrcMask: "255.255.0.0", - expectError: true, - errorContains: "IP version", - }, - { - name: "IPv4 VS with IPv4 allowed_src (valid)", - vsAddr: "10.0.0.100", - allowedSrcAddr: "192.168.0.0", - allowedSrcMask: "255.255.0.0", - expectError: false, - }, - { - name: "IPv6 VS with IPv6 allowed_src (valid)", - vsAddr: "2001:db8::100", - allowedSrcAddr: "2001:db8:1::", - allowedSrcMask: "ffff:ffff:ffff::", - expectError: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - config := &balancerpb.PacketHandlerConfig{ - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 20, - TcpFin: 15, - Tcp: 100, - Udp: 11, - Default: 19, - }, - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.1").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::1").AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{}, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr(tc.vsAddr).AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_SOURCE_HASH, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.1"). - AsSlice(), - }, - Port: 8080, - }, - Weight: 100, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr(tc.allowedSrcAddr). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr(tc.allowedSrcMask). - AsSlice(), - }, - }, - }, - }, - }, - }, - }, - } - - _, err := ProtoToHandlerConfig(config) - - if tc.expectError { - require.Error(t, err, "Expected error for IP version mismatch") - assert.Contains( - t, - err.Error(), - tc.errorContains, - "Error should mention IP version", - ) - } else { - require.NoError(t, err, "Valid IP version match should not produce error") - } - }) - } -} diff --git a/modules/balancer/agent/go/wlc.go b/modules/balancer/agent/go/wlc.go deleted file mode 100644 index 473c41773..000000000 --- a/modules/balancer/agent/go/wlc.go +++ /dev/null @@ -1,118 +0,0 @@ -package balancer - -// WLC (Weighted Least Connection) algorithm implementation for dynamic weight adjustment -// based on active connection counts, calculating effective weights to balance load across -// real servers according to their capacity and current utilization. - -import ( - "math" - - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" -) - -func WlcUpdates( - config *ffi.BalancerManagerConfig, - graph *ffi.BalancerGraph, - info *ffi.BalancerInfo, -) []ffi.RealUpdate { - wlcVs := map[int]bool{} - for _, vs := range config.Wlc.Vs { - wlcVs[int(vs)] = true - } - - updates := []ffi.RealUpdate{} - - for vsIdx := range config.Balancer.Handler.VirtualServices { - if !wlcVs[vsIdx] { - continue - } - - vsConfig := &config.Balancer.Handler.VirtualServices[vsIdx] - vsGraph := &graph.VirtualServices[vsIdx] - vsInfo := &info.Vs[vsIdx] - - vsUpdates := vsWlcUpdates(&config.Wlc, vsConfig, vsGraph, vsInfo) - updates = append(updates, vsUpdates...) - } - - return updates -} - -func vsWlcUpdates( - wlc *ffi.BalancerManagerWlcConfig, - vsConfig *ffi.VsConfig, - vsGraph *ffi.GraphVs, - vsInfo *ffi.VsInfo, -) []ffi.RealUpdate { - realsCnt := len(vsConfig.Reals) - if realsCnt != len(vsGraph.Reals) || realsCnt != len(vsInfo.Reals) { - panic("invalid reals number") - } - - connectionsSum := uint64(0) - weightsSum := uint64(0) - for idx := range realsCnt { - if vsGraph.Reals[idx].Enabled { - connectionsSum += vsInfo.Reals[idx].ActiveSessions - weightsSum += uint64(vsConfig.Reals[idx].Weight) - } - } - - updates := []ffi.RealUpdate{} - - for idx := range realsCnt { - realConfig := &vsConfig.Reals[idx] - realGraph := &vsGraph.Reals[idx] - realInfo := &vsInfo.Reals[idx] - - // Only generate weight updates for enabled reals - if !realGraph.Enabled { - continue - } - - newWeight := calcWlcWeight( - wlc, - realConfig.Weight, - realInfo.ActiveSessions, - weightsSum, - connectionsSum, - ) - - if newWeight != realGraph.Weight { - updates = append(updates, ffi.RealUpdate{ - Identifier: ffi.RealIdentifier{ - Relative: realConfig.Identifier, - VsIdentifier: vsConfig.Identifier, - }, - Weight: newWeight, - Enabled: ffi.DontUpdateRealEnabled, - }) - } - } - - return updates -} - -func calcWlcWeight( - wlc *ffi.BalancerManagerWlcConfig, - weight uint16, - connections uint64, - weightSum uint64, - connectionsSum uint64, -) uint16 { - if weight == 0 || weightSum == 0 || connectionsSum < weightSum { - return weight - } - - scaledConnections := float64(connections) * float64(weightSum) - scaledWeight := float64(connectionsSum) * float64(weight) - connectionsRatio := scaledConnections / scaledWeight - - const minRatio = 1.0 - wlcRatio := math.Max(minRatio, float64(wlc.Power)*(1.0-connectionsRatio)) - - newWeight := uint64(math.Round(float64(weight) * wlcRatio)) - newWeight = min(newWeight, uint64(wlc.MaxRealWeight)) - - return uint16(newWeight) -} diff --git a/modules/balancer/agent/go/wlc_test.go b/modules/balancer/agent/go/wlc_test.go deleted file mode 100644 index fddb38d5d..000000000 --- a/modules/balancer/agent/go/wlc_test.go +++ /dev/null @@ -1,1302 +0,0 @@ -package balancer - -import ( - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" -) - -// TestCalcWlcWeight tests the core WLC weight calculation algorithm -func TestCalcWlcWeight(t *testing.T) { - tests := []struct { - name string - wlc *ffi.BalancerManagerWlcConfig - weight uint16 - connections uint64 - weightSum uint64 - connectionsSum uint64 - expected uint16 - }{ - { - name: "Zero weight returns zero", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 0, - connections: 100, - weightSum: 200, - connectionsSum: 500, - expected: 0, - }, - { - name: "Zero weightSum returns original weight", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 50, - weightSum: 0, - connectionsSum: 500, - expected: 100, - }, - { - name: "connectionsSum less than weightSum returns original weight", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 50, - weightSum: 200, - connectionsSum: 100, - expected: 100, - }, - { - name: "connectionsSum equals weightSum - algorithm proceeds", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 50, - weightSum: 200, - connectionsSum: 200, - // scaledConnections = 50 * 200 = 10000 - // scaledWeight = 200 * 100 = 20000 - // connectionsRatio = 10000 / 20000 = 0.5 - // wlcRatio = max(1.0, 10 * (1.0 - 0.5)) = max(1.0, 5.0) = 5.0 - // newWeight = round(100 * 5.0) = 500 - expected: 500, - }, - { - name: "Equal distribution - connections proportional to weight", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 200, // connections/connectionsSum = 200/400 = 0.5 = weight/weightSum = 100/200 - weightSum: 200, - connectionsSum: 400, - expected: 100, // ratio = 1.0, wlcRatio = max(1.0, 10*(1-1)) = 1.0 - }, - { - name: "Underloaded server gets higher weight", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 100, // connections/connectionsSum = 100/400 = 0.25, but weight/weightSum = 100/200 = 0.5 - weightSum: 200, - connectionsSum: 400, - // scaledConnections = 100 * 200 = 20000 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 20000 / 40000 = 0.5 - // wlcRatio = max(1.0, 10 * (1.0 - 0.5)) = max(1.0, 5.0) = 5.0 - // newWeight = round(100 * 5.0) = 500 - expected: 500, - }, - { - name: "Overloaded server keeps minimum weight (ratio >= 1)", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 300, // connections/connectionsSum = 300/400 = 0.75, but weight/weightSum = 100/200 = 0.5 - weightSum: 200, - connectionsSum: 400, - // scaledConnections = 300 * 200 = 60000 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 60000 / 40000 = 1.5 - // wlcRatio = max(1.0, 10 * (1.0 - 1.5)) = max(1.0, -5.0) = 1.0 - // newWeight = round(100 * 1.0) = 100 - expected: 100, - }, - { - name: "Weight capped at MaxRealWeight", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 20, - MaxRealWeight: 200, - }, - weight: 100, - connections: 50, // Very underloaded - weightSum: 200, - connectionsSum: 400, - // scaledConnections = 50 * 200 = 10000 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 10000 / 40000 = 0.25 - // wlcRatio = max(1.0, 20 * (1.0 - 0.25)) = max(1.0, 15.0) = 15.0 - // newWeight = round(100 * 15.0) = 1500, but capped at 200 - expected: 200, - }, - { - name: "Power factor of 1", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 1, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 100, - weightSum: 200, - connectionsSum: 400, - // connectionsRatio = 0.5 - // wlcRatio = max(1.0, 1 * (1.0 - 0.5)) = max(1.0, 0.5) = 1.0 - expected: 100, - }, - { - name: "Power factor of 2 with underloaded server", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 2, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 100, - weightSum: 200, - connectionsSum: 400, - // connectionsRatio = 0.5 - // wlcRatio = max(1.0, 2 * (1.0 - 0.5)) = max(1.0, 1.0) = 1.0 - expected: 100, - }, - { - name: "Power factor of 4 with underloaded server", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 4, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 100, - weightSum: 200, - connectionsSum: 400, - // connectionsRatio = 0.5 - // wlcRatio = max(1.0, 4 * (1.0 - 0.5)) = max(1.0, 2.0) = 2.0 - // newWeight = round(100 * 2.0) = 200 - expected: 200, - }, - { - name: "Zero connections (very underloaded)", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - }, - weight: 100, - connections: 0, - weightSum: 200, - connectionsSum: 400, - // scaledConnections = 0 * 200 = 0 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 0 / 40000 = 0 - // wlcRatio = max(1.0, 10 * (1.0 - 0)) = max(1.0, 10.0) = 10.0 - // newWeight = round(100 * 10.0) = 1000 - expected: 1000, - }, - { - name: "Large values", - wlc: &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 65535, - }, - weight: 1000, - connections: 1000000, - weightSum: 10000, - connectionsSum: 5000000, - // scaledConnections = 1000000 * 10000 = 10000000000 - // scaledWeight = 5000000 * 1000 = 5000000000 - // connectionsRatio = 10000000000 / 5000000000 = 2.0 - // wlcRatio = max(1.0, 10 * (1.0 - 2.0)) = max(1.0, -10.0) = 1.0 - // newWeight = round(1000 * 1.0) = 1000 - expected: 1000, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := calcWlcWeight( - tt.wlc, - tt.weight, - tt.connections, - tt.weightSum, - tt.connectionsSum, - ) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestVsWlcUpdates tests the virtual service level WLC update calculation -func TestVsWlcUpdates(t *testing.T) { - wlc := &ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - } - - vsIdentifier := ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - } - - t.Run("Empty reals returns no updates", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{}, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{}, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{}, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - assert.Empty(t, updates) - }) - - t.Run("All disabled reals returns no updates", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - Enabled: false, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - Enabled: false, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - {Dst: netip.MustParseAddr("192.168.1.1"), ActiveSessions: 50}, - {Dst: netip.MustParseAddr("192.168.1.2"), ActiveSessions: 50}, - }, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - assert.Len(t, updates, 0) - }) - - t.Run("Single enabled real with no weight change", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - {Dst: netip.MustParseAddr("192.168.1.1"), ActiveSessions: 50}, - }, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - // connectionsSum=50, weightSum=100, connectionsSum < weightSum, so weight unchanged - assert.Empty(t, updates) - }) - - t.Run( - "Multiple reals with unequal load generates updates", - func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 100, - }, // Underloaded - { - Dst: netip.MustParseAddr("192.168.1.2"), - ActiveSessions: 300, - }, // Overloaded - }, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - // connectionsSum=400, weightSum=200 - // Real 1: connections=100, weight=100 - // scaledConnections = 100 * 200 = 20000 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 0.5 - // wlcRatio = max(1.0, 10 * 0.5) = 5.0 - // newWeight = 500 - // Real 2: connections=300, weight=100 - // scaledConnections = 300 * 200 = 60000 - // scaledWeight = 400 * 100 = 40000 - // connectionsRatio = 1.5 - // wlcRatio = max(1.0, 10 * -0.5) = 1.0 - // newWeight = 100 (no change) - - require.Len(t, updates, 1) // Only real 1 has weight change - assert.Equal(t, uint16(500), updates[0].Weight) - assert.Equal( - t, - netip.MustParseAddr("192.168.1.1"), - updates[0].Identifier.Relative.Addr, - ) - }, - ) - - t.Run("Update includes correct identifiers", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 50, // Different from calculated weight - Enabled: true, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - {Dst: netip.MustParseAddr("192.168.1.1"), ActiveSessions: 50}, - }, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - // connectionsSum=50, weightSum=100, connectionsSum < weightSum - // newWeight = 100 (original), but graph.Weight=50, so update generated - - require.Len(t, updates, 1) - assert.Equal(t, vsIdentifier, updates[0].Identifier.VsIdentifier) - assert.Equal( - t, - netip.MustParseAddr("192.168.1.1"), - updates[0].Identifier.Relative.Addr, - ) - assert.Equal(t, uint16(8080), updates[0].Identifier.Relative.Port) - assert.Equal(t, ffi.DontUpdateRealEnabled, updates[0].Enabled) - }) - - t.Run( - "Mismatched reals count panics - config vs graph", - func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{}, // Mismatched count - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 50, - }, - }, - } - - assert.Panics(t, func() { - vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - }) - }, - ) - - t.Run("Mismatched reals count panics - config vs info", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{}, // Mismatched count - } - - assert.Panics(t, func() { - vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - }) - }) - - t.Run("Mixed enabled and disabled reals", func(t *testing.T) { - vsConfig := &ffi.VsConfig{ - Identifier: vsIdentifier, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.3"), - Port: 8080, - }, - Weight: 100, - }, - }, - } - vsGraph := &ffi.GraphVs{ - Identifier: vsIdentifier, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.2"), - Port: 8080, - }, - Weight: 100, - Enabled: false, // Disabled - }, - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.3"), - Port: 8080, - }, - Weight: 100, - Enabled: true, - }, - }, - } - vsInfo := &ffi.VsInfo{ - Identifier: vsIdentifier, - Reals: []ffi.RealInfo{ - {Dst: netip.MustParseAddr("192.168.1.1"), ActiveSessions: 100}, - { - Dst: netip.MustParseAddr("192.168.1.2"), - ActiveSessions: 50, - }, // Disabled, but has sessions - {Dst: netip.MustParseAddr("192.168.1.3"), ActiveSessions: 300}, - }, - } - - updates := vsWlcUpdates(wlc, vsConfig, vsGraph, vsInfo) - - // Expect updates for real 1 (weight change) - require.Len(t, updates, 1) - - // Find updates by address - var real1Update *ffi.RealUpdate - for i := range updates { - if updates[i].Identifier.Relative.Addr == netip.MustParseAddr( - "192.168.1.1", - ) { - real1Update = &updates[i] - } - } - - require.NotNil(t, real1Update) - assert.Equal(t, uint16(500), real1Update.Weight) - }) -} - -// TestWlcUpdates tests the main WLC updates function -func TestWlcUpdates(t *testing.T) { - t.Run("No WLC virtual services returns empty updates", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{}, // No WLC VS - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 50, - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 100, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - assert.Empty(t, updates) - }) - - t.Run("Single WLC virtual service", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, // VS index 0 has WLC - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 50, // Different from expected - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 50, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - // connectionsSum=50, weightSum=100, connectionsSum < weightSum - // newWeight = 100, but graph.Weight=50, so update generated - require.Len(t, updates, 1) - assert.Equal(t, uint16(100), updates[0].Weight) - }) - - t.Run("Multiple WLC virtual services", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.2.1", - ), - Port: 8443, - }, - Weight: 200, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0, 1}, // Both VS have WLC - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 50, - Enabled: true, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.2.1"), - Port: 8443, - }, - Weight: 100, - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 50, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.2.1"), - ActiveSessions: 100, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - // Both VS should generate updates since graph weights differ from calculated - require.Len(t, updates, 2) - }) - - t.Run("Mixed WLC and non-WLC virtual services", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.2.1", - ), - Port: 8443, - }, - Weight: 200, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 8080, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.3.1", - ), - Port: 9090, - }, - Weight: 150, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{ - 0, - 2, - }, // Only VS 0 and 2 have WLC (not VS 1) - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 8080, - }, - Weight: 50, - Enabled: true, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.2.1"), - Port: 8443, - }, - Weight: 50, // Different, but VS 1 is not WLC - Enabled: true, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 8080, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.3.1"), - Port: 9090, - }, - Weight: 75, - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 50, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.2"), - Port: 443, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.2.1"), - ActiveSessions: 100, - }, - }, - }, - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.3"), - Port: 8080, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.3.1"), - ActiveSessions: 75, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - // Only VS 0 and VS 2 should be processed (WLC enabled) - // VS 1 should be skipped even though its weight differs - require.Len(t, updates, 2) - - // Verify updates are for VS 0 and VS 2 only - for _, update := range updates { - vsAddr := update.Identifier.VsIdentifier.Addr - assert.True( - t, - vsAddr == netip.MustParseAddr("10.0.0.1") || - vsAddr == netip.MustParseAddr("10.0.0.3"), - "Update should be for VS 0 or VS 2, got VS with addr %s", - vsAddr, - ) - } - }) - - t.Run("Empty virtual services", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{}, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, // WLC enabled for non-existent VS - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{}, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{}, - } - - updates := WlcUpdates(config, graph, info) - assert.Empty(t, updates) - }) - - t.Run("IPv6 virtual service", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::1", - ), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "2001:db8::100", - ), - Port: 8080, - }, - Weight: 100, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("2001:db8::100"), - Port: 8080, - }, - Weight: 50, - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("2001:db8::1"), - Port: 80, - TransportProto: ffi.VsTransportProtoTCP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr( - "2001:db8::100", - ), - ActiveSessions: 50, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - require.Len(t, updates, 1) - assert.Equal( - t, - netip.MustParseAddr("2001:db8::1"), - updates[0].Identifier.VsIdentifier.Addr, - ) - assert.Equal( - t, - netip.MustParseAddr("2001:db8::100"), - updates[0].Identifier.Relative.Addr, - ) - }) - - t.Run("UDP transport protocol", func(t *testing.T) { - config := &ffi.BalancerManagerConfig{ - Balancer: ffi.BalancerConfig{ - Handler: ffi.PacketHandlerConfig{ - VirtualServices: []ffi.VsConfig{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.RealConfig{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr( - "192.168.1.1", - ), - Port: 5353, - }, - Weight: 100, - }, - }, - }, - }, - }, - }, - Wlc: ffi.BalancerManagerWlcConfig{ - Power: 10, - MaxRealWeight: 1000, - Vs: []uint32{0}, - }, - } - graph := &ffi.BalancerGraph{ - VirtualServices: []ffi.GraphVs{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.GraphReal{ - { - Identifier: ffi.RelativeRealIdentifier{ - Addr: netip.MustParseAddr("192.168.1.1"), - Port: 5353, - }, - Weight: 50, - Enabled: true, - }, - }, - }, - }, - } - info := &ffi.BalancerInfo{ - Vs: []ffi.VsInfo{ - { - Identifier: ffi.VsIdentifier{ - Addr: netip.MustParseAddr("10.0.0.1"), - Port: 53, - TransportProto: ffi.VsTransportProtoUDP, - }, - Reals: []ffi.RealInfo{ - { - Dst: netip.MustParseAddr("192.168.1.1"), - ActiveSessions: 50, - }, - }, - }, - }, - } - - updates := WlcUpdates(config, graph, info) - require.Len(t, updates, 1) - assert.Equal( - t, - ffi.VsTransportProtoUDP, - updates[0].Identifier.VsIdentifier.TransportProto, - ) - }) -} diff --git a/modules/balancer/agent/manager.c b/modules/balancer/agent/manager.c deleted file mode 100644 index d03bb760f..000000000 --- a/modules/balancer/agent/manager.c +++ /dev/null @@ -1,553 +0,0 @@ -#include "manager.h" -#include "api/agent.h" -#include "common/memory.h" -#include "common/memory_address.h" -#include "lib/controlplane/agent/agent.h" -#include "lib/controlplane/diag/diag.h" -#include "modules/balancer/controlplane/api/balancer.h" -#include "modules/balancer/controlplane/api/handler.h" -#include "modules/balancer/controlplane/api/real.h" -#include -#include -#include - -struct balancer_agent; - -struct balancer_manager { - struct balancer_handle *balancer; - struct balancer_manager_config config; - struct balancer_agent *agent; - struct diag diag; -}; - -//////////////////////////////////////////////////////////////////////////////// - -static struct memory_context * -balancer_manager_memory_context(struct balancer_manager *manager) { - struct balancer_agent *balancer_agent = ADDR_OF(&manager->agent); - struct agent *agent = (struct agent *)balancer_agent; - return &agent->memory_context; -} - -static void -setup_session_table_capacity(struct balancer_manager *manager) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - manager->config.balancer.state.table_capacity = - balancer_session_table_capacity(balancer); -} - -//////////////////////////////////////////////////////////////////////////////// - -const char * -balancer_manager_take_error(struct balancer_manager *manager) { - return diag_take_msg(&manager->diag); -} - -extern int -clone_balancer_config_to_relative( - struct balancer_config *dst, - struct balancer_config *src, - struct memory_context *ctx -); - -int -clone_manager_config_to_relative( - struct balancer_manager_config *dst, - struct balancer_manager_config *src, - struct memory_context *mctx -) { - // Clone balancer config - if (clone_balancer_config_to_relative( - &dst->balancer, &src->balancer, mctx - ) != 0) { - PUSH_ERROR("failed to clone balancer config"); - return -1; - } - - // Copy WLC scalar fields - dst->wlc.power = src->wlc.power; - dst->wlc.max_real_weight = src->wlc.max_real_weight; - dst->wlc.vs_count = src->wlc.vs_count; - - // Clone WLC vs array to relative pointers - if (src->wlc.vs_count > 0) { - uint32_t *vs_array = memory_balloc( - mctx, sizeof(uint32_t) * src->wlc.vs_count - ); - if (vs_array == NULL) { - PUSH_ERROR("failed to allocate wlc vs array"); - return -1; - } - memcpy(vs_array, - src->wlc.vs, - sizeof(uint32_t) * src->wlc.vs_count); - SET_OFFSET_OF(&dst->wlc.vs, vs_array); - } else { - SET_OFFSET_OF(&dst->wlc.vs, NULL); - } - - // Copy remaining scalar fields - dst->refresh_period = src->refresh_period; - dst->max_load_factor = src->max_load_factor; - - return 0; -} - -extern int -clone_balancer_config_from_relative( - struct balancer_config *dst, struct balancer_config *src -); - -static void -clone_manager_config_from_relative( - struct balancer_manager_config *dst, struct balancer_manager_config *src -) { - // Clone balancer config - clone_balancer_config_from_relative(&dst->balancer, &src->balancer); - - // Copy WLC scalar fields - dst->wlc.power = src->wlc.power; - dst->wlc.max_real_weight = src->wlc.max_real_weight; - dst->wlc.vs_count = src->wlc.vs_count; - - // Clone WLC vs array from relative pointers to normal pointers - if (src->wlc.vs_count > 0) { - uint32_t *src_vs = ADDR_OF(&src->wlc.vs); - dst->wlc.vs = calloc(src->wlc.vs_count, sizeof(uint32_t)); - memcpy(dst->wlc.vs, src_vs, sizeof(uint32_t) * src->wlc.vs_count - ); - } else { - dst->wlc.vs = NULL; - } - - // Copy remaining scalar fields - dst->refresh_period = src->refresh_period; - dst->max_load_factor = src->max_load_factor; -} - -//////////////////////////////////////////////////////////////////////////////// - -const char * -balancer_manager_name(struct balancer_manager *manager) { - return balancer_name(ADDR_OF(&manager->balancer)); -} - -void -balancer_manager_config( - struct balancer_manager *manager, struct balancer_manager_config *config -) { - clone_manager_config_from_relative(config, &manager->config); -} - -//////////////////////////////////////////////////////////////////////////////// - -static void -take_balancer_error(struct balancer_handle *balancer, struct diag *diag) { - const char *msg = balancer_take_error_msg(balancer); - if (msg == NULL) { - diag_reset(diag); - } else { - NEW_ERROR("%s", msg); - diag_fill(diag); - } -} - -int -balancer_manager_update_reals( - struct balancer_manager *manager, - size_t count, - struct real_update *updates -) { - diag_reset(&manager->diag); - - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - int res = balancer_update_reals(balancer, count, updates); - if (res != 0) { - take_balancer_error(balancer, &manager->diag); - return -1; - } - - struct balancer_config *config = &manager->config.balancer; - struct packet_handler_config *handler_config = &config->handler; - - for (size_t i = 0; i < count; i++) { - struct real_update *update = &updates[i]; - if (update->weight != DONT_UPDATE_REAL_WEIGHT) { - struct real_ph_index index; - int ec = balancer_real_ph_idx( - balancer, &update->identifier, &index - ); - assert(ec == 0); - - struct named_vs_config *vs_config = - ADDR_OF(&handler_config->vs) + index.vs_idx; - struct named_real_config *real_config = - ADDR_OF(&vs_config->config.reals) + - index.real_idx; - - real_config->config.weight = update->weight; - } - } - - return 0; -} - -int -balancer_manager_update_reals_wlc( - struct balancer_manager *manager, - size_t count, - struct real_update *updates -) { - diag_reset(&manager->diag); - - // Validate that WLC updates only change weights, not enable state - for (size_t i = 0; i < count; i++) { - struct real_update *update = &updates[i]; - if (update->enabled != DONT_UPDATE_REAL_ENABLED) { - NEW_ERROR( - "WLC update at index %lu attempts to change " - "enable state (not allowed)", - i - ); - diag_fill(&manager->diag); - return -1; - } - } - - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - int res = balancer_update_reals(balancer, count, updates); - if (res != 0) { - take_balancer_error(balancer, &manager->diag); - return -1; - } - - // Note: Unlike balancer_manager_update_reals(), this function does NOT - // update the config weights. The config weight should remain the - // original static weight. WLC calculations use the config weight as the - // baseline and adjust the state weight dynamically based on load. - - return 0; -} - -int -balancer_manager_update( - struct balancer_manager *manager, - struct balancer_manager_config *config, - struct balancer_update_info *update_info, - uint32_t now -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - - diag_reset(&manager->diag); - - struct balancer_manager_config old_config; - memcpy(&old_config, - &manager->config, - sizeof(struct balancer_manager_config)); - - // first, try to resize session table - size_t requested_session_table_capacity = - config->balancer.state.table_capacity; - if (requested_session_table_capacity != - manager->config.balancer.state.table_capacity) { - if (balancer_resize_session_table( - balancer, requested_session_table_capacity, now - ) != 0) { - NEW_ERROR("%s", balancer_take_error_msg(balancer)); - PUSH_ERROR("failed to resize session table"); - goto restore_config_on_error; - } - - size_t new_session_table_capacity = - balancer_session_table_capacity(balancer); - config->balancer.state.table_capacity = - new_session_table_capacity; - old_config.balancer.state.table_capacity = - new_session_table_capacity; - } - - // clone config - if (clone_manager_config_to_relative( - &manager->config, - config, - balancer_manager_memory_context(manager) - ) != 0) { - NEW_ERROR("failed to clone config"); - goto restore_config_on_error; - } - - // update state (resize session table) - - // update packet handler - if (balancer_update_packet_handler( - balancer, &config->balancer.handler, update_info - ) != 0) { - NEW_ERROR("%s", balancer_take_error_msg(balancer)); - PUSH_ERROR("failed to update packet handler"); - goto restore_config_on_error; - } - - return 0; - -restore_config_on_error: - memcpy(&manager->config, - &old_config, - sizeof(struct balancer_manager_config)); - - diag_fill(&manager->diag); - - return -1; -} - -int -balancer_manager_resize_session_table( - struct balancer_manager *manager, size_t new_size, uint32_t now -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - if (balancer_resize_session_table(balancer, new_size, now) != 0) { - NEW_ERROR("%s", balancer_take_error_msg(balancer)); - return -1; - } - setup_session_table_capacity(manager); - return 0; -} - -int -balancer_manager_info( - struct balancer_manager *manager, - struct balancer_info *info, - uint32_t now -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - if (balancer_info(balancer, info, now) != 0) { - NEW_ERROR("%s", balancer_take_error_msg(balancer)); - return -1; - } - return 0; -} - -void -balancer_manager_sessions( - struct balancer_manager *manager, - struct sessions *sessions, - uint32_t now -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - balancer_sessions(balancer, sessions, now); -} - -int -balancer_manager_stats( - struct balancer_manager *manager, - struct balancer_stats *stats, - struct packet_handler_ref *ref -) { - diag_reset(&manager->diag); - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - if (balancer_stats(balancer, stats, ref) != 0) { - take_balancer_error(balancer, &manager->diag); - return -1; - } - return 0; -} - -void -balancer_manager_graph( - struct balancer_manager *manager, struct balancer_graph *graph -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - balancer_graph(balancer, graph); -} - -//////////////////////////////////////////////////////////////////////////////// - -extern const char *agent_name; -extern const char *storage_name; - -void -balancer_agent_managers( - struct balancer_agent *agent, struct balancer_managers *managers -) { - struct balancer_managers *stored_managers = - agent_storage_read((struct agent *)agent, storage_name); - assert(stored_managers != NULL); - managers->count = stored_managers->count; - managers->managers = - calloc(managers->count, sizeof(struct balancer_manager *)); - - struct balancer_manager **stored_managers_array = - ADDR_OF(&stored_managers->managers); - - for (size_t i = 0; i < managers->count; ++i) { - managers->managers[i] = ADDR_OF(stored_managers_array + i); - } -} - -static int -find_manager(struct balancer_agent *balancer_agent, const char *name) { - struct balancer_managers *stored_managers = agent_storage_read( - (struct agent *)balancer_agent, storage_name - ); - assert(stored_managers != NULL); - struct balancer_manager **managers = - ADDR_OF(&stored_managers->managers); - for (size_t i = 0; i < stored_managers->count; ++i) { - struct balancer_manager *manager = ADDR_OF(managers + i); - if (strcmp(name, balancer_manager_name(manager)) == 0) { - return 1; - } - } - return 0; -} - -struct balancer_manager * -balancer_agent_new_manager( - struct balancer_agent *balancer_agent, - const char *name, - struct balancer_manager_config *config -) { - struct agent *agent = (struct agent *)balancer_agent; - diag_reset(&agent->diag); - - if (find_manager(balancer_agent, name) != 0) { - NEW_ERROR("manager with name '%s' already exists", name); - diag_fill(&agent->diag); - return NULL; - } - - struct memory_context *mctx = &agent->memory_context; - struct balancer_manager *new_manager = - memory_balloc(mctx, sizeof(struct balancer_manager)); - if (new_manager == NULL) { - NEW_ERROR("failed to allocate manager"); - diag_fill(&agent->diag); - return NULL; - } - - memset(new_manager, 0, sizeof(struct balancer_manager)); - SET_OFFSET_OF(&new_manager->agent, balancer_agent); - - if (clone_manager_config_to_relative( - &new_manager->config, config, mctx - ) != 0) { - NEW_ERROR("failed to allocate manager config"); - diag_fill(&agent->diag); - memory_bfree( - mctx, new_manager, sizeof(struct balancer_manager) - ); - return NULL; - } - - struct balancer_managers *stored_managers = agent_storage_read( - (struct agent *)balancer_agent, storage_name - ); - assert(stored_managers != NULL); - - struct balancer_manager **new_managers = memory_balloc( - mctx, - sizeof(struct balancer_manager *) * (stored_managers->count + 1) - ); - if (new_managers == NULL) { - NEW_ERROR("failed to allocate managers storage"); - diag_fill(&agent->diag); - memory_bfree( - mctx, new_manager, sizeof(struct balancer_manager) - ); - return NULL; - } - for (size_t i = 0; i < stored_managers->count; ++i) { - EQUATE_OFFSET( - new_managers + i, - ADDR_OF(&stored_managers->managers) + i - ); - } - - struct balancer_handle *handle = - balancer_create(agent, name, &config->balancer); - if (handle == NULL) { - PUSH_ERROR("failed to create balancer"); - return NULL; - } - - SET_OFFSET_OF(&new_manager->balancer, handle); - - SET_OFFSET_OF(new_managers + stored_managers->count, new_manager); - - memory_bfree( - mctx, - ADDR_OF(&stored_managers->managers), - sizeof(struct balancer_manager *) * stored_managers->count - ); - - SET_OFFSET_OF(&stored_managers->managers, new_managers); - - ++stored_managers->count; - - return new_manager; -} - -//////////////////////////////////////////////////////////////////////////////// -// Memory Management -//////////////////////////////////////////////////////////////////////////////// - -void -balancer_manager_info_free(struct balancer_info *info) { - balancer_info_free(info); -} - -void -balancer_manager_sessions_free(struct sessions *sessions) { - balancer_sessions_free(sessions); -} - -void -balancer_manager_stats_free(struct balancer_stats *stats) { - balancer_stats_free(stats); -} - -void -balancer_manager_graph_free(struct balancer_graph *graph) { - balancer_graph_free(graph); -} - -//////////////////////////////////////////////////////////////////////////////// - -void -balancer_manager_inspect( - struct balancer_manager *manager, struct balancer_inspect *inspect -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - balancer_inspect(balancer, inspect); -} - -void -balancer_manager_inspect_free(struct balancer_inspect *inspect) { - if (inspect == NULL) { - return; - } - - // Free packet handler inspect nested structures - if (inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects != - NULL) { - free(inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects - ); - inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects = - NULL; - } - - if (inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects != - NULL) { - free(inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects - ); - inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects = - NULL; - } -} - -void -balancer_manager_active_sessions( - struct balancer_manager *manager, struct balancer_info *info -) { - struct balancer_handle *balancer = ADDR_OF(&manager->balancer); - balancer_active_sessions(balancer, info); -} \ No newline at end of file diff --git a/modules/balancer/agent/manager.h b/modules/balancer/agent/manager.h deleted file mode 100644 index 960b8df23..000000000 --- a/modules/balancer/agent/manager.h +++ /dev/null @@ -1,624 +0,0 @@ -#include -#include - -#include "agent.h" -#include "modules/balancer/controlplane/api/balancer.h" - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Weighted Least Connection (WLC) algorithm configuration. - * - * Configures the WLC scheduling algorithm that dynamically adjusts real - * server weights based on active session counts to achieve better load - * distribution. The WLC algorithm is particularly useful when session - * durations vary significantly, preventing overloading of individual reals. - * - * ALGORITHM OVERVIEW: - * The WLC algorithm calculates effective weights using this formula: - * - * effective_weight = min( - * config_weight * max(1.0, power * (1.0 - connections_ratio)), - * max_real_weight - * ) - * - * where: - * connections_ratio = (real_sessions * total_weight) / - * (total_sessions * real_weight) - * - * BEHAVIOR: - * - If a real has fewer sessions than expected (ratio < 1.0): - * Weight increases to attract more traffic - * - If a real has more sessions than expected (ratio > 1.0): - * Weight stays at baseline (no decrease below config_weight) - * - The 'power' parameter controls adjustment aggressiveness - * - The 'max_real_weight' parameter caps maximum weight - * - * EXECUTION: - * - Runs every refresh_period (configured in BalancerManagerConfig) - * - Only affects virtual services listed in the 'vs' array - * - Uses update_reals_wlc() to preserve config weights - * - * CONFIGURATION REQUIREMENTS: - * - Must be fully specified if any VS has WLC flag enabled - * - Can be empty (power=0, max_real_weight=0, vs=[]) if no VS uses WLC - */ -struct balancer_manager_wlc_config { - /** - * Power factor for weight adjustment aggressiveness. - * - * Controls how aggressively weights are adjusted based on session - * distribution imbalance. Higher values cause more dramatic weight - * changes in response to load imbalance. - * - * RECOMMENDED VALUES: - * - Conservative (stable): 1-2 - * - Moderate (balanced): 2-4 - * - Aggressive (responsive): 4-8 - * - Very aggressive: 8-16 - * - * EXAMPLE IMPACT: - * If a real has 50% fewer sessions than expected (ratio=0.5): - * - power=2: weight increases by 1.0x (doubles) - * - power=4: weight increases by 2.0x (triples) - * - power=1: weight increases by 0.5x (1.5x original) - */ - size_t power; - - /** - * Maximum effective weight limit. - * - * Caps the maximum weight a real server can have after WLC - * adjustment. Prevents any single real from dominating traffic - * distribution even when severely underloaded. - * - * RECOMMENDED VALUES: - * - Conservative: 2-3x maximum configured weight - * - Moderate: 5-10x maximum configured weight - * - Aggressive: 10-20x maximum configured weight - * - * EXAMPLE: - * If configured weights range from 1-100 and max_real_weight=500: - * - A real with weight=50 can reach effective_weight=500 (10x) - * - A real with weight=100 can reach effective_weight=500 (5x) - */ - size_t max_real_weight; - - /** - * Number of virtual service indices in the 'vs' array. - * - * Specifies how many virtual services have WLC enabled. - * If 0, no virtual services use WLC (algorithm is disabled). - */ - size_t vs_count; - - /** - * Array of virtual service indices with WLC enabled. - * - * Contains indices into the PacketHandlerConfig.vs array, - * identifying which virtual services should have WLC applied. - * Only these VSs will have their weights dynamically adjusted. - * - * OWNERSHIP: - * - Caller allocates and manages this array - * - Must remain valid for the lifetime of the config - * - Array length must match vs_count - * - * EXAMPLE: - * If vs = [0, 2, 5], then virtual services at indices 0, 2, and 5 - * in the configuration will have WLC enabled, while others won't. - */ - uint32_t *vs; -}; - -/** - * Complete configuration for a balancer manager. - * - * Combines balancer instance configuration with WLC algorithm parameters - * and operational settings for periodic refresh operations. The manager - * coordinates the balancer instance and applies scheduling algorithms - * like WLC. - * - * CONFIGURATION DEPENDENCIES: - * The fields have interdependencies that must be satisfied: - * - * 1. If refresh_period > 0, then: - * - max_load_factor must be set (typically 0.7-0.9) - * - wlc must be configured (even if no VS uses WLC) - * - * 2. If any VS has wlc flag enabled, then: - * - refresh_period must be > 0 - * - max_load_factor must be set - * - wlc.power and wlc.max_real_weight must be set - * - wlc.vs must include the VS index - * - * 3. If refresh_period == 0: - * - No periodic operations (no auto-resize, no WLC, no stats updates) - * - WLC flag cannot be enabled on any VS - * - max_load_factor is ignored - * - * REFRESH CYCLE OPERATIONS: - * When refresh_period > 0, the manager performs these operations every - * refresh_period milliseconds: - * - * 1. Session Statistics Collection: - * - Scans session table to count active sessions - * - Updates per-VS and per-real session counts - * - Updates last_packet_timestamp fields - * - Makes data available via balancer_manager_info() - * - * 2. Automatic Session Table Resizing: - * - Calculates: load_factor = active_sessions / table_capacity - * - If load_factor > max_load_factor: - * * Doubles table capacity - * * Migrates existing sessions - * * Prevents session table overflow - * - * 3. WLC Weight Adjustment: - * - For each VS in wlc.vs array: - * * Calculates new effective weights based on session distribution - * * Calls balancer_manager_update_reals_wlc() to update weights - * * Preserves original config weights for future calculations - * - * TYPICAL CONFIGURATIONS: - * - * Static configuration (no WLC, no auto-resize): - * ```c - * config.refresh_period = 0; - * config.max_load_factor = 0.0; // ignored - * config.wlc = {0, 0, 0, NULL}; // ignored - * ``` - * - * Auto-resize only (no WLC): - * ```c - * config.refresh_period = 30000; // 30 seconds - * config.max_load_factor = 0.8; - * config.wlc = {0, 0, 0, NULL}; // no VSs use WLC - * ``` - * - * Full dynamic configuration (auto-resize + WLC): - * ```c - * config.refresh_period = 10000; // 10 seconds - * config.max_load_factor = 0.75; - * config.wlc = { - * .power = 4, - * .max_real_weight = 1000, - * .vs_count = 2, - * .vs = (uint32_t[]){0, 1} // VSs 0 and 1 use WLC - * }; - * ``` - */ -struct balancer_manager_config { - /** - * Core balancer configuration. - * - * Contains packet handler config (virtual services, reals, timeouts) - * and state config (session table capacity). - */ - struct balancer_config balancer; - - /** - * WLC algorithm configuration. - * - * Specifies WLC parameters (power, max_weight) and which virtual - * services have WLC enabled (vs array). - * - * REQUIREMENTS: - * - Must be fully configured if any VS has wlc flag enabled - * - Can be empty (all zeros) if no VS uses WLC - * - Requires refresh_period > 0 to function - */ - struct balancer_manager_wlc_config wlc; - - /** - * Periodic refresh interval in milliseconds. - * - * Controls how often the manager performs background operations: - * - Session statistics collection - * - Automatic session table resizing - * - WLC weight adjustment - * - * SPECIAL VALUES: - * - 0: Disables all periodic operations - * - > 0: Enables periodic operations at specified interval - * - * RECOMMENDED VALUES: - * - High-traffic dynamic: 5,000-10,000 ms (5-10 seconds) - * - Moderate traffic: 15,000-30,000 ms (15-30 seconds) - * - Stable traffic: 30,000-60,000 ms (30-60 seconds) - * - Static config: 0 (disabled) - * - * PERFORMANCE IMPACT: - * - Shorter periods: More responsive, higher CPU overhead - * - Longer periods: Less overhead, slower response - * - Cost scales with active_sessions and vs_count - */ - uint32_t refresh_period; - - /** - * Maximum session table load factor (0.0 to 1.0). - * - * Threshold for automatic session table resizing. When the load - * factor (active_sessions / table_capacity) exceeds this value - * during a refresh cycle, the table capacity is doubled. - * - * RECOMMENDED VALUES: - * - Conservative (frequent resize): 0.6-0.7 - * - Balanced: 0.7-0.8 - * - Aggressive (rare resize): 0.8-0.9 - * - Very aggressive: 0.9-0.95 - * - * TRADE-OFFS: - * - Lower values: More frequent resizing, lower collision rate - * - Higher values: Less frequent resizing, higher collision rate - * - Typical hash table performance degrades above 0.9 - * - * REQUIREMENTS: - * - Must be set if refresh_period > 0 - * - Ignored if refresh_period == 0 - * - Valid range: 0.0 < max_load_factor < 1.0 - */ - float max_load_factor; -}; - -struct balancer_handle; - -/** - * Opaque handle to a balancer manager instance. - * - * A manager coordinates one balancer instance, applying scheduling - * algorithms (like WLC) and managing configuration updates. It provides - * a higher-level interface for balancer lifecycle management. - * - * Thread-Safety: Not thread-safe. External synchronization required for - * concurrent access. - */ -struct balancer_manager; - -//////////////////////////////////////////////////////////////////////////////// -// Query Operations -//////////////////////////////////////////////////////////////////////////////// - -/** - * Get the name of the balancer manager. - * - * @param manager Manager handle. - * @return Pointer to the manager name string (owned by the manager, do not - * free). - */ -const char * -balancer_manager_name(struct balancer_manager *manager); - -/** - * Retrieve the current configuration of the manager. - * - * Fills the provided config structure with the manager's current settings. - * The config structure should be allocated by the caller. - * - * @param manager Manager handle. - * @param config Output structure to be filled with current configuration. - */ -void -balancer_manager_config( - struct balancer_manager *manager, struct balancer_manager_config *config -); - -//////////////////////////////////////////////////////////////////////////////// -// Update Operations -//////////////////////////////////////////////////////////////////////////////// - -/** - * Update the manager's configuration. - * - * Applies a new configuration to the manager, updating balancer settings, - * WLC parameters, refresh period, and load factor. This may trigger - * reconfiguration of underlying balancer instances. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param config New configuration to apply. - * @param update_info Optional output structure to receive update metadata. - * Pass NULL if metadata is not needed. - * @param now Current monotonic timestamp for bookkeeping. - * @return 0 on success, -1 on error. - */ -int -balancer_manager_update( - struct balancer_manager *manager, - struct balancer_manager_config *config, - struct balancer_update_info *update_info, - uint32_t now -); - -/** - * Apply a batch of real server updates. - * - * Updates the state (weight, enabled status) of one or more real servers - * managed by this manager. Each update may change weight and/or enabled state. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param count Number of updates in the array. - * @param updates Array of real server updates to apply. - * @return 0 on success, -1 on error. - */ -int -balancer_manager_update_reals( - struct balancer_manager *manager, - size_t count, - struct real_update *updates -); - -/** - * Apply a batch of real server weight updates for WLC algorithm. - * - * This function is specifically designed for WLC (Weighted Least Connection) - * algorithm to update effective weights without modifying the baseline - * configuration weights. It provides a way to dynamically adjust weights - * based on load while preserving the original configured values. - * - * KEY DIFFERENCES FROM balancer_manager_update_reals(): - * - * 1. WEIGHT UPDATE SCOPE: - * - update_reals(): Updates BOTH config weight AND state/graph weight - * - update_reals_wlc(): Updates ONLY state/graph weight, preserves config - * - * 2. CONFIG PRESERVATION: - * - update_reals(): New weight becomes the baseline for future WLC - * calculations - * - update_reals_wlc(): Original config weight remains the WLC baseline - * - * 3. USE CASES: - * - update_reals(): Manual weight changes by administrator - * - update_reals_wlc(): Automatic weight adjustments by WLC algorithm - * - * 4. ENABLE STATE: - * - update_reals(): Can change both weight and enabled state - * - update_reals_wlc(): MUST NOT change enabled state (validation error) - * - * EXAMPLE SCENARIO: - * ``` - * Initial config: real1.weight = 100 - * - * // WLC adjusts weight based on load - * update_reals_wlc(real1, weight=150) - * Result: config.weight=100, state.weight=150 - * Next WLC cycle uses config.weight=100 as baseline - * - * // Admin changes weight - * update_reals(real1, weight=200) - * Result: config.weight=200, state.weight=200 - * Next WLC cycle uses config.weight=200 as baseline - * ``` - * - * VALIDATION: - * - All updates must have weight != DONT_UPDATE_REAL_WEIGHT - * - All updates must have enabled == DONT_UPDATE_REAL_ENABLED - * - Returns error if any update attempts to change enabled state - * - * TYPICAL USAGE: - * This function is called automatically by the background refresh task - * when WLC is enabled. Manual calls are rare and should be done with - * caution to avoid interfering with the WLC algorithm. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param count Number of updates in the array. - * @param updates Array of real server weight updates. Each update MUST: - * - Have weight != DONT_UPDATE_REAL_WEIGHT - * - Have enabled == DONT_UPDATE_REAL_ENABLED - * @return 0 on success, -1 on error (including validation failures). - */ -int -balancer_manager_update_reals_wlc( - struct balancer_manager *manager, - size_t count, - struct real_update *updates -); - -/** - * Resize the session table used by the manager's balancer. - * - * Changes the capacity of the session table to accommodate more or fewer - * concurrent sessions. This operation may involve memory reallocation and - * session migration. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param new_size New number of session table entries to allocate. - * @param now Current monotonic timestamp for migration bookkeeping. - * @return 0 on success, -1 on error. - */ -int -balancer_manager_resize_session_table( - struct balancer_manager *manager, size_t new_size, uint32_t now -); - -//////////////////////////////////////////////////////////////////////////////// -// Statistics and Information Retrieval -//////////////////////////////////////////////////////////////////////////////// - -/** - * Query aggregated balancer information from the manager. - * - * Retrieves comprehensive information including active sessions, virtual - * services, and real server states. On success, allocates arrays inside - * the info structure that must be freed with balancer_manager_info_free(). - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param info Output structure to be filled with balancer information. - * @param now Current monotonic timestamp for time-based calculations. - * @return 0 on success, -1 on error. - */ -int -balancer_manager_info( - struct balancer_manager *manager, - struct balancer_info *info, - uint32_t now -); - -/** - * Enumerate active sessions tracked by the manager's balancer. - * - * Returns a snapshot of all active sessions. The sessions structure will - * contain heap-allocated data that must be freed with - * balancer_manager_sessions_free(). - * - * @param manager Manager handle. - * @param sessions Output structure to be filled with session information. - * @param now Current monotonic timestamp for session state evaluation. - */ -void -balancer_manager_sessions( - struct balancer_manager *manager, - struct sessions *sessions, - uint32_t now -); - -/** - * Read balancer statistics from the manager. - * - * Retrieves statistics for the manager's balancers, optionally filtered - * by packet handler reference. On success, allocates data inside the stats - * structure that must be freed with balancer_manager_stats_free(). - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_manager_take_error(). - * - * @param manager Manager handle. - * @param stats Output structure to be filled with statistics. - * @param ref Optional filter for specific packet handler; pass NULL for - * aggregate. - * @return 0 on success, -1 on error. - */ -int -balancer_manager_stats( - struct balancer_manager *manager, - struct balancer_stats *stats, - struct packet_handler_ref *ref -); - -/** - * Retrieve graph representation of the manager's balancer topology. - * - * Generates a graph structure representing the relationships between - * virtual services, real servers, and their connections. The graph structure - * must be freed with balancer_manager_graph_free(). - * - * @param manager Manager handle. - * @param graph Output structure to be filled with graph data. - */ -void -balancer_manager_graph( - struct balancer_manager *manager, struct balancer_graph *graph -); - -//////////////////////////////////////////////////////////////////////////////// -// Memory Management -//////////////////////////////////////////////////////////////////////////////// - -/** - * Free all allocations inside a balancer_info structure. - * - * Releases memory allocated by balancer_manager_info(). Safe to call with - * partially-initialized structures; ignores NULL pointers. - * - * @param info Structure to release. The struct itself is not freed. - */ -void -balancer_manager_info_free(struct balancer_info *info); - -/** - * Free all allocations inside a sessions structure. - * - * Releases memory allocated by balancer_manager_sessions(). Safe to call - * with partially-initialized structures; ignores NULL pointers. - * - * @param sessions Structure to release. The struct itself is not freed. - */ -void -balancer_manager_sessions_free(struct sessions *sessions); - -/** - * Free all allocations inside a balancer_stats structure. - * - * Releases memory allocated by balancer_manager_stats(). Safe to call with - * partially-initialized structures; ignores NULL pointers. - * - * @param stats Structure to release. The struct itself is not freed. - */ -void -balancer_manager_stats_free(struct balancer_stats *stats); - -/** - * Free all allocations inside a balancer_graph structure. - * - * Releases memory allocated by balancer_manager_graph(). - * - * @param graph Structure to release. The struct itself is not freed. - */ -void -balancer_manager_graph_free(struct balancer_graph *graph); - -//////////////////////////////////////////////////////////////////////////////// -// Error Handling -//////////////////////////////////////////////////////////////////////////////// - -/** - * Retrieve memory inspection for this manager's balancer. - * - * Fills the provided balancer_inspect structure with detailed memory - * usage information for the balancer instance managed by this manager. - * The structure contains nested allocations that must be freed with - * balancer_manager_inspect_free(). - * - * @param manager Manager handle. - * @param inspect Output structure to be filled with inspection data. - */ -void -balancer_manager_inspect( - struct balancer_manager *manager, struct balancer_inspect *inspect -); - -/** - * Free all allocations inside a balancer_inspect structure. - * - * Releases memory allocated by balancer_manager_inspect() for nested - * arrays and structures. Safe to call with partially-initialized - * structures; ignores NULL pointers. - * - * @param inspect Structure to release. The struct itself is not freed. - */ -void -balancer_manager_inspect_free(struct balancer_inspect *inspect); - -/** - * Retrieve the last diagnostic error message for this manager. - * - * Returns the most recent error message recorded by manager operations. - * After calling this function, the error state is cleared. - * - * Ownership: The returned string is heap-allocated for the caller; you must - * free() it when no longer needed. Returns NULL if no error is available. - * - * @param manager Manager handle. - * @return Null-terminated error message string to be freed by caller, or NULL. - */ -const char * -balancer_manager_take_error(struct balancer_manager *manager); - -void -balancer_manager_active_sessions( - struct balancer_manager *manager, struct balancer_info *info -); \ No newline at end of file diff --git a/modules/balancer/bench/alloc.c b/modules/balancer/bench/alloc.c deleted file mode 100644 index 30be04a63..000000000 --- a/modules/balancer/bench/alloc.c +++ /dev/null @@ -1,24 +0,0 @@ -#include "alloc.h" - -void -allocator_init(struct allocator *alloc, void *arena, size_t size) { - alloc->arena = arena; - alloc->size = size; - alloc->allocated = 0; -} - -uint8_t * -allocator_alloc(struct allocator *alloc, size_t align, size_t size) { - size_t shift = 0; - uintptr_t start = (uintptr_t)alloc->arena + alloc->allocated; - if (start % align != 0) { - shift = align - start % align; - } - size += shift; - if (alloc->allocated + size > alloc->size) { - return NULL; - } - uint8_t *ptr = (uint8_t *)alloc->arena + alloc->allocated; - alloc->allocated += size; - return ptr + shift; -} \ No newline at end of file diff --git a/modules/balancer/bench/alloc.h b/modules/balancer/bench/alloc.h deleted file mode 100644 index 886aace42..000000000 --- a/modules/balancer/bench/alloc.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -struct allocator { - size_t allocated; - size_t size; - void *arena; -}; - -void -allocator_init(struct allocator *alloc, void *arena, size_t size); - -uint8_t * -allocator_alloc(struct allocator *alloc, size_t align, size_t size); diff --git a/modules/balancer/bench/bench.c b/modules/balancer/bench/bench.c deleted file mode 100644 index c41426039..000000000 --- a/modules/balancer/bench/bench.c +++ /dev/null @@ -1,121 +0,0 @@ -#include "bench.h" -#include "controlplane/diag/diag.h" -#include "mock/config.h" -#include "mock/mock.h" -#include "mock/packet.h" -#include -#include -#include - -#define DP_MEMORY (1 << 20) - -int -bench_init(struct bench *bench, struct bench_config *config) { - // Initialize fields to safe defaults before any operation that might - // fail - memset(&bench->yanet, 0, sizeof(bench->yanet)); - bench->shared_memory = NULL; - bench->total_memory = 0; - - diag_reset(&bench->diag); - - if (config->total_memory < DP_MEMORY + config->cp_memory) { - NEW_ERROR( - "memory is to small (required at least %lu)", - DP_MEMORY + config->cp_memory - ); - goto error; - } - - errno = 0; - void *shared_memory = - mmap(NULL, - config->total_memory, - PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, - -1, - 0); - if (shared_memory == MAP_FAILED) { - NEW_ERROR("mmap failed: %s", strerror(errno)); - goto error; - } - - memset(shared_memory, 0, config->total_memory); - - bench->shared_memory = shared_memory; - bench->total_memory = config->total_memory; - - struct yanet_mock_config yanet_config = { - .worker_count = config->workers, - .device_count = 1, - .dp_memory = DP_MEMORY, - .cp_memory = config->cp_memory, - .devices = {(struct yanet_mock_device_config){ - .id = 0, - .name = "01:00.0", - }} - }; - - if (yanet_mock_init(&bench->yanet, &yanet_config, shared_memory) != 0) { - NEW_ERROR("failed to init mock"); - goto error_unmap; - } - - allocator_init( - &bench->alloc, - shared_memory + DP_MEMORY + config->cp_memory, - config->total_memory - config->cp_memory - DP_MEMORY - ); - - return 0; - -error_unmap: - munmap(shared_memory, config->total_memory); - -error: - diag_fill(&bench->diag); - return -1; -} - -#undef DP_MEMORY - -const char * -bench_take_error(struct bench *bench) { - return diag_take_msg(&bench->diag); -} - -void -bench_free(struct bench *bench) { - yanet_mock_free(&bench->yanet); - munmap(bench->shared_memory, bench->total_memory); -} - -int -bench_handle_packets( - struct bench *bench, - size_t worker, - struct packet_list *packets_batch, - size_t batches_count -) { - struct packet_handle_result result; - size_t dropped_count = 0; - for (size_t i = 0; i < batches_count; i++) { - memset(&result, 0, sizeof(result)); - yanet_mock_handle_packets( - &bench->yanet, packets_batch + i, worker, &result - ); - dropped_count += result.drop_packets.count; - } - return dropped_count > 0; -} - -uint8_t * -bench_alloc(void *bench, size_t align, size_t size) { - struct bench *b = (struct bench *)bench; - return allocator_alloc(&b->alloc, align, size); -} - -void * -bench_shared_memory(struct bench *bench) { - return bench->shared_memory; -} \ No newline at end of file diff --git a/modules/balancer/bench/bench.h b/modules/balancer/bench/bench.h deleted file mode 100644 index 78c6122f4..000000000 --- a/modules/balancer/bench/bench.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "dataplane/packet/packet.h" -#include "lib/controlplane/diag/diag.h" -#include "mock/mock.h" - -#include "config.h" -#include "modules/balancer/bench/alloc.h" - -struct bench { - struct yanet_mock yanet; - struct diag diag; - void *shared_memory; - size_t total_memory; - struct allocator alloc; -}; - -int -bench_init(struct bench *bench, struct bench_config *config); - -void * -bench_shared_memory(struct bench *bench); - -const char * -bench_take_error(struct bench *bench); - -void -bench_free(struct bench *bench); - -int -bench_handle_packets( - struct bench *bench, - size_t worker, - struct packet_list *packets_batch, - size_t batches_count -); - -uint8_t * -bench_alloc(void *bench, size_t align, size_t size); diff --git a/modules/balancer/bench/config.h b/modules/balancer/bench/config.h deleted file mode 100644 index 39505aa99..000000000 --- a/modules/balancer/bench/config.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -struct balancer_config; - -struct bench_config { - size_t workers; - size_t cp_memory; - size_t total_memory; -}; \ No newline at end of file diff --git a/modules/balancer/bench/configs/some.yaml b/modules/balancer/bench/configs/some.yaml deleted file mode 100644 index fbb6d18f4..000000000 --- a/modules/balancer/bench/configs/some.yaml +++ /dev/null @@ -1,27 +0,0 @@ -gre_prob: 0.10 -fix_mss_prob: 0.05 -pure_l3_prob: 0.20 -ops_prob: 0.15 - -tcp_ipv4_vs: 100 -tcp_ipv6_vs: 50 -udp_ipv4_vs: 80 -udp_ipv6_vs: 40 - -reals_per_vs: 16 - -new_session_prob: 0.30 - -icmp_prob: 0.02 -icmp_redirect_prob: 0.01 - -batches_per_worker: 128 -packets_per_batch: 64 - -workers: 4 - -allowed_src_per_vs: 10 - -round_robin_prob: 0.5 - -session_table_capacity: 131072 \ No newline at end of file diff --git a/modules/balancer/bench/configs/udp_basic.yaml b/modules/balancer/bench/configs/udp_basic.yaml deleted file mode 100644 index fadfaef54..000000000 --- a/modules/balancer/bench/configs/udp_basic.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# UDP-only benchmark configuration -# This configuration generates only UDP virtual services and UDP packets - -# Virtual service flags probabilities (all set to 0 for basic UDP testing) -gre_prob: 0.0 -fix_mss_prob: 0.0 -pure_l3_prob: 0.0 -ops_prob: 0.0 - -# Scheduling algorithm probability -round_robin_prob: 0.5 - -# Virtual services configuration (only UDP, no TCP) -tcp_ipv4_vs: 0 -tcp_ipv6_vs: 0 -udp_ipv4_vs: 2 -udp_ipv6_vs: 2 - -# Real servers per virtual service -ipv4_reals: 10 -ipv6_reals: 10 - -# Number of allowed source networks per virtual service -allowed_src_per_vs: 5 - -# Session management -# Probability of creating a new session (vs reusing existing) -# 0.3 means 30% new sessions, 70% reuse existing sessions -new_session_prob: 0.01 - -# ICMP packet probabilities (set to 0 for UDP-only testing) -icmp_prob: 0.0 -icmp_redirect_prob: 0.0 - -# Packet generation configuration -batches_per_worker: 131072 -packets_per_batch: 32 - -# MSS option for TCP packets (not applicable for UDP, but must be specified) -mss: 0 - -# Number of worker threads -workers: 1 - -session_table_capacity: 33554432 \ No newline at end of file diff --git a/modules/balancer/bench/configs/udp_ipv6.yaml b/modules/balancer/bench/configs/udp_ipv6.yaml deleted file mode 100644 index 07b3c7735..000000000 --- a/modules/balancer/bench/configs/udp_ipv6.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# UDP-only benchmark configuration -# This configuration generates only UDP virtual services and UDP packets - -# Virtual service flags probabilities (all set to 0 for basic UDP testing) -gre_prob: 0.0 -fix_mss_prob: 0.0 -pure_l3_prob: 0.0 -ops_prob: 0.0 - -# Scheduling algorithm probability -round_robin_prob: 1 - -# Virtual services configuration (only UDP, no TCP) -tcp_ipv4_vs: 0 -tcp_ipv6_vs: 0 -udp_ipv4_vs: 0 -udp_ipv6_vs: 2 - -# Real servers per virtual service -ipv4_reals: 10 -ipv6_reals: 10 - -# Number of allowed source networks per virtual service -allowed_src_per_vs: 5 - -# Session management -# Probability of creating a new session (vs reusing existing) -# 0.3 means 30% new sessions, 70% reuse existing sessions -new_session_prob: 0.01 - -# ICMP packet probabilities (set to 0 for UDP-only testing) -icmp_prob: 0.0 -icmp_redirect_prob: 0.0 - -# Packet generation configuration -batches_per_worker: 262144 -packets_per_batch: 32 - -# MSS option for TCP packets (not applicable for UDP, but must be specified) -mss: 0 - -# Number of worker threads -workers: 1 - -session_table_capacity: 33554432 \ No newline at end of file diff --git a/modules/balancer/bench/go/bench.go b/modules/balancer/bench/go/bench.go deleted file mode 100644 index 092efa61a..000000000 --- a/modules/balancer/bench/go/bench.go +++ /dev/null @@ -1,677 +0,0 @@ -// Package main implements a high-performance benchmark tool for the YANET balancer module. -// It generates synthetic traffic, measures packet processing throughput (MPpS), and provides -// detailed statistics on balancer performance including per-worker metrics and load distribution. -package main - -import ( - "bufio" - "fmt" - "math/rand/v2" - "net/netip" - "os" - "runtime" - "sync" - "time" - - "github.com/c2h5oh/datasize" - "github.com/yanet-platform/yanet2/common/go/logging" - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" - dataplane "github.com/yanet-platform/yanet2/lib/utils/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "go.uber.org/zap/zapcore" - "golang.org/x/sys/unix" -) - -var ( - PacketsMemory int = (1 << 32) + (1 << 30) - TotalMemory int = CpMemory + PacketsMemory - CpMemory int = (1 << 33) - AgentMemory int = CpMemory - (1 << 27) -) - -var BalancerName string = "balancer0" - -// generate packets and run handlers -type workerInfo struct { - idx int - tid int - info string - isErr bool -} - -// worker performance metrics -type workerPerf struct { - idx int - packets int - duration time.Duration - mpps float64 -} - -func workerRoutine( - bench *Bench, - wg *sync.WaitGroup, - readyWg *sync.WaitGroup, - info chan workerInfo, - perf chan workerPerf, - start chan struct{}, - idx int, - packetList []dataplane.PacketList, - totalPackets int, -) { - defer wg.Done() - - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - tid := unix.Gettid() - - sendMsg := func(msg string) { - info <- workerInfo{idx: idx, tid: tid, info: msg, isErr: false} - } - - sendError := func(msg string) { - info <- workerInfo{idx: idx, tid: tid, info: msg, isErr: true} - } - - // pin - var set unix.CPUSet - set.Zero() - set.Set(idx) - if err := unix.SchedSetaffinity(0, &set); err != nil { - sendError(fmt.Sprintf("failed to set affinity: %s", err)) - readyWg.Done() - return - } - - // set priority - if err := unix.Setpriority(unix.PRIO_PROCESS, tid, -20); err != nil { - sendError(fmt.Sprintf("failed to set priority: %s", err)) - readyWg.Done() - return - } - - sendMsg(fmt.Sprintf("pinned to CPU %d with priority %d", idx, -20)) - readyWg.Done() - - <-start - - startTime := time.Now() - - if err := bench.HandlePackets(idx, packetList); err != nil { - msg := fmt.Sprintf("failed to handle packets: %s", err) - sendError(msg) - } else { - elapsed := time.Since(startTime) - mpps := float64(totalPackets) / elapsed.Seconds() / 1e6 - sendMsg(fmt.Sprintf("successfully handled %d packets in %s (%.2f MPpS)", totalPackets, elapsed, mpps)) - // Send performance metrics - perf <- workerPerf{ - idx: idx, - packets: totalPackets, - duration: elapsed, - mpps: mpps, - } - } -} - -func enableAllReals(bal *balancer.BalancerManager) error { - var updates []*balancerpb.RealUpdate - enableTrue := true - balancerConfig := bal.Config() - - for _, vs := range balancerConfig.PacketHandler.Vs { - for _, real := range vs.Reals { - updates = append(updates, &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: vs.Id, - Real: real.Id, - }, - Enable: &enableTrue, - }) - } - } - - // update reals - if _, err := bal.UpdateReals(updates, false); err != nil { - return fmt.Errorf("failed to enable reals: %s", err) - } - - return nil -} - -func balancerConfig(config *BenchConfig) *balancerpb.BalancerConfig { - // Create virtual services based on config - var virtualServices []*balancerpb.VirtualService - - rng := rand.New(rand.NewPCG(1, 2)) - - // Helper function to create a VS with reals - createVS := func(addr netip.Addr, port uint32, proto balancerpb.TransportProto) *balancerpb.VirtualService { - // Determine flags based on probabilities - flags := &balancerpb.VsFlags{ - Gre: rng.Float32() < config.GreProb, - FixMss: rng.Float32() < config.FixMSSProb, - PureL3: rng.Float32() < config.PureL3Prob, - Ops: rng.Float32() < config.OpsProb, - Wlc: false, - } - - // If PureL3 is enabled, port must be 0 - if flags.PureL3 { - port = 0 - } - - // Create reals for this VS - reals := make([]*balancerpb.Real, 0, config.Ipv4Reals+config.Ipv6Reals) - for i := 0; i < config.Ipv4Reals+config.Ipv6Reals; i++ { - var realAddr netip.Addr - if i < config.Ipv4Reals { - // Generate IPv4 real address (10.0.0.0/8 range) - realAddr = netip.AddrFrom4( - [4]byte{10, 0, byte(i / 256), byte(i % 256)}, - ) - } else { - // Generate IPv6 real address (fd00::/8 range) - realAddr = netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(i / 256), byte(i % 256)}) - } - - // Create source address and mask (preserve original source) - var srcAddr, srcMask []byte - if addr.Is4() { - srcAddr = []byte{0, 0, 0, 0} - srcMask = []byte{0, 0, 0, 0} - } else { - srcAddr = make([]byte, 16) - srcMask = make([]byte, 16) - } - - reals = append(reals, &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: realAddr.AsSlice()}, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{Bytes: srcAddr}, - SrcMask: &balancerpb.Addr{Bytes: srcMask}, - }) - } - - scheduler := balancerpb.VsScheduler_SOURCE_HASH - if rng.Float32() < config.RoundRobinProb { - scheduler = balancerpb.VsScheduler_ROUND_ROBIN - } - - allowedSrc := make([]*balancerpb.AllowedSources, 0, config.AllowedSrcPerVs) - for i := 0; i < config.AllowedSrcPerVs; i++ { - if addr.Is4() { - allowedSrc = append(allowedSrc, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: []byte{byte(i / 256), byte(i % 256), 5, 5}, - }, - Mask: &balancerpb.Addr{ - Bytes: []byte{255, 255, 255, 255}, - }, - }}, - }) - } else { - allowedSrc = append(allowedSrc, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: []byte{byte(i / 256), byte(i % 256), 0, 2, 3, 0, 0, 29, 0, 43, 0, 16, 0, 0, 0, 0}, - }, - Mask: &balancerpb.Addr{ - Bytes: []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, - }, - }}, - }) - } - } - - peers := make([]*balancerpb.Addr, 0, 2) - for i := range 2 { - peers = append( - peers, - &balancerpb.Addr{ - Bytes: []byte{byte(i / 256), byte(i % 256), 10, 11}, - }, - ) - } - - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: addr.AsSlice()}, - Port: port, - Proto: proto, - }, - Scheduler: scheduler, - AllowedSrcs: allowedSrc, - Reals: reals, - Flags: flags, - Peers: peers, - } - } - - // Generate TCP IPv4 virtual services - for i := 0; i < config.TCPIPv4VS; i++ { - addr := netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)}) - virtualServices = append( - virtualServices, - createVS(addr, 80, balancerpb.TransportProto_TCP), - ) - } - - // Generate TCP IPv6 virtual services - for i := 0; i < config.TCPIPv6VS; i++ { - addr := netip.AddrFrom16( - [16]byte{ - 0x20, - 0x01, - 0x0d, - 0xb8, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - byte(i / 256), - byte(i % 256), - }, - ) - virtualServices = append( - virtualServices, - createVS(addr, 80, balancerpb.TransportProto_TCP), - ) - } - - // Generate UDP IPv4 virtual services - for i := 0; i < config.UDPIPv4VS; i++ { - addr := netip.AddrFrom4([4]byte{172, 16, byte(i / 256), byte(i % 256)}) - virtualServices = append( - virtualServices, - createVS(addr, 53, balancerpb.TransportProto_UDP), - ) - } - - // Generate UDP IPv6 virtual services - for i := 0; i < config.UDPIPv6Vs; i++ { - addr := netip.AddrFrom16( - [16]byte{ - 0xfc, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - byte(i / 256), - byte(i % 256), - }, - ) - virtualServices = append( - virtualServices, - createVS(addr, 53, balancerpb.TransportProto_UDP), - ) - } - - // Session timeouts (in seconds) - sessionTimeouts := &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 120, - TcpFin: 120, - Tcp: 3600, - Udp: 300, - Default: 300, - } - - // Source addresses for encapsulation - sourceV4 := netip.AddrFrom4([4]byte{10, 255, 255, 254}) - sourceV6 := netip.AddrFrom16( - [16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, - ) - - // Packet handler configuration - packetHandler := &balancerpb.PacketHandlerConfig{ - Vs: virtualServices, - SourceAddressV4: &balancerpb.Addr{Bytes: sourceV4.AsSlice()}, - SourceAddressV6: &balancerpb.Addr{Bytes: sourceV6.AsSlice()}, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: sessionTimeouts, - } - - // State configuration - capacity := uint64( - config.SessionTableCapacity, - ) - - stateConfig := &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: nil, - RefreshPeriod: nil, - Wlc: nil, - } - - return &balancerpb.BalancerConfig{ - PacketHandler: packetHandler, - State: stateConfig, - } -} - -const ( - DeviceName = "01:00.0" - PipelineName = "pipeline0" - FunctionName = "function0" - ChainName = "chain0" -) - -func setupYanet(shm *yanet.SharedMemory) error { - // Attach bootstrap agent to configure the controlplane - bootstrap, err := shm.AgentReattach("bootstrap", 0, 1<<20) - if err != nil { - return fmt.Errorf("failed to attach to bootstrap agent: %w", err) - } - - // Update function configuration - { - functionConfig := yanet.FunctionConfig{ - Name: FunctionName, - Chains: []yanet.FunctionChainConfig{ - { - Weight: 1, - Chain: yanet.ChainConfig{ - Name: ChainName, - Modules: []yanet.ChainModuleConfig{ - { - Type: "balancer", - Name: BalancerName, - }, - }, - }, - }, - }, - } - - if err := bootstrap.UpdateFunction(functionConfig); err != nil { - return fmt.Errorf("failed to update function: %w", err) - } - } - - // Update pipelines - { - inputPipelineConfig := yanet.PipelineConfig{ - Name: PipelineName, - Functions: []string{FunctionName}, - } - - dummyPipelineConfig := yanet.PipelineConfig{ - Name: "dummy", - Functions: []string{}, - } - - if err := bootstrap.UpdatePipeline(inputPipelineConfig); err != nil { - return fmt.Errorf("failed to update pipeline: %w", err) - } - - if err := bootstrap.UpdatePipeline(dummyPipelineConfig); err != nil { - return fmt.Errorf("failed to update pipeline: %w", err) - } - } - - // Update devices - { - deviceConfig := yanet.DeviceConfig{ - Name: DeviceName, - Input: []yanet.DevicePipelineConfig{ - { - Name: PipelineName, - Weight: 1, - }, - }, - Output: []yanet.DevicePipelineConfig{ - { - Name: "dummy", - Weight: 1, - }, - }, - } - - if err := bootstrap.UpdatePlainDevices([]yanet.DeviceConfig{deviceConfig}); err != nil { - return fmt.Errorf("failed to update devices: %w", err) - } - } - - return nil -} - -func Run(config *BenchConfig) error { - bench, err := NewBench(config.Workers, TotalMemory, CpMemory) - if err != nil { - return fmt.Errorf("failed to create new bench: %s", err) - } - defer bench.Free() - - logLevel := zapcore.InfoLevel - logger, _, _ := logging.Init(&logging.Config{ - Level: logLevel, - }) - agent, err := balancer.NewBalancerAgent( - bench.SharedMemory(), - datasize.ByteSize(AgentMemory), - logger, - ) - if err != nil { - return fmt.Errorf("failed to create new balancer agent: %s", err) - } - - balancerConfig := balancerConfig(config) - if err := agent.NewBalancerManager(BalancerName, balancerConfig); err != nil { - return fmt.Errorf("failed to create new balancer manager: %s", err) - } - - if err := setupYanet(bench.SharedMemory()); err != nil { - return fmt.Errorf("failed to setup yanet: %s", err) - } - - // enable all reals - bal, err := agent.BalancerManager(BalancerName) - if err != nil { - panic("balancer manager is incorrect") - } - - if err := enableAllReals(bal); err != nil { - return fmt.Errorf("failed to enable reals: %s", err) - } - - start := make(chan struct{}) - info := make(chan workerInfo) - perf := make(chan workerPerf, config.Workers) - var readyWg sync.WaitGroup - var wg sync.WaitGroup - wg.Add(config.Workers) - readyWg.Add(config.Workers) - - generator := NewGenerator(config, balancerConfig) - - for worker := range config.Workers { - packetLists, err := bench.MakePacketLists(config.BatchesPerWorker) - if err != nil { - return fmt.Errorf("failed to create packet lists: %s", err) - } - for idx := range packetLists { - if idx%100 == 0 { - logger.Infow( - "generating packets", - "worker", - worker, - "progress", - fmt.Sprintf( - "%.2f%%", - 100.0*float32(idx)/float32(len(packetLists)), - ), - ) - } - packets := generator.generateWorkerPackets( - worker, - config.PacketsPerBatch, - ) - if err := bench.InitPacketList(&packetLists[idx], packets...); err != nil { - return fmt.Errorf( - "failed to init packet list at index %d: %s", - idx, - err, - ) - } - } - logger.Infow("generated all packets", "worker", worker) - - go workerRoutine( - bench, - &wg, - &readyWg, - info, - perf, - start, - worker, - packetLists, - config.PacketsPerBatch*config.BatchesPerWorker, - ) - } - - // Variables to track total benchmark duration - var benchStart time.Time - var benchDuration time.Duration - - go func() { - readyWg.Wait() - fmt.Printf("All workers are ready\nPress any key to start...\n") - _, _ = bufio.NewReader(os.Stdin).ReadBytes('\n') - fmt.Println("Benchmark started") - benchStart = time.Now() - close(start) - wg.Wait() - benchDuration = time.Since(benchStart) - - fmt.Printf("All workers are finished\n") - close(info) - }() - - isErr := false - workerPerfs := make([]workerPerf, 0, config.Workers) - - for info := range info { - if info.isErr { - logger.Errorw(info.info, "worker", info.idx, "tid", info.tid) - isErr = true - } else { - logger.Infow(info.info, "worker", info.idx, "tid", info.tid) - } - } - - // Collect performance metrics - close(perf) - for p := range perf { - workerPerfs = append(workerPerfs, p) - } - - logger.Infow("done") - - // Print comprehensive balancer stats - printSeparator() - fmt.Printf("\n") - fmt.Printf(" BALANCER BENCHMARK RESULTS\n") - printSeparator() - - // Print worker performance summary - printWorkerPerformance(workerPerfs, benchDuration, config.Workers) - - fmt.Println() - - if isErr { - return fmt.Errorf("some workers failed") - } else { - return nil - } -} - -// formatNumber adds comma separators to large numbers -func formatNumber(n uint64) string { - if n < 1000 { - return fmt.Sprintf("%d", n) - } - str := fmt.Sprintf("%d", n) - var result []byte - for i, c := range str { - if i > 0 && (len(str)-i)%3 == 0 { - result = append(result, ',') - } - result = append(result, byte(c)) - } - return string(result) -} - -// printSeparator prints a separator line -func printSeparator() { - fmt.Println("================================================================================") -} - -// printWorkerPerformance prints worker performance summary -func printWorkerPerformance(workerPerfs []workerPerf, benchDuration time.Duration, numWorkers int) { - fmt.Println("\nWORKER PERFORMANCE") - fmt.Println("------------------") - - // Calculate total packets - var totalPackets int - - if len(workerPerfs) > 0 { - // Sort by worker index for consistent output - sortedPerfs := make([]workerPerf, len(workerPerfs)) - copy(sortedPerfs, workerPerfs) - // Simple bubble sort by idx - for i := range sortedPerfs { - for j := i + 1; j < len(sortedPerfs); j++ { - if sortedPerfs[i].idx > sortedPerfs[j].idx { - sortedPerfs[i], sortedPerfs[j] = sortedPerfs[j], sortedPerfs[i] - } - } - } - - // Print per-worker stats - for _, p := range sortedPerfs { - fmt.Printf("Worker %d: %s packets in %s (%.2f MPpS)\n", - p.idx, - formatNumber(uint64(p.packets)), - p.duration, - p.mpps) - totalPackets += p.packets - } - fmt.Println() - } - - // Print aggregate stats based on total benchmark duration - if benchDuration > 0 { - aggregateMpps := float64(totalPackets) / benchDuration.Seconds() / 1e6 - fmt.Printf("Total Duration: %s\n", benchDuration) - fmt.Printf("Total Packets: %s\n", formatNumber(uint64(totalPackets))) - fmt.Printf("Aggregate Throughput: %.2f MPpS\n", aggregateMpps) - if numWorkers > 0 { - avgMpps := aggregateMpps / float64(numWorkers) - fmt.Printf("Average per Worker: %.2f MPpS\n", avgMpps) - } - } else { - fmt.Printf("Total Packets: %s\n", formatNumber(uint64(totalPackets))) - fmt.Println("(Benchmark duration not available)") - } -} diff --git a/modules/balancer/bench/go/cmd.go b/modules/balancer/bench/go/cmd.go deleted file mode 100644 index fc3ee6830..000000000 --- a/modules/balancer/bench/go/cmd.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -// Command-line entry point for the balancer benchmark tool. -// Parses YAML configuration and executes the benchmark with specified parameters. - -import ( - "fmt" - "os" - "runtime" - - "github.com/stretchr/testify/assert/yaml" -) - -func main() { - runtime.GOMAXPROCS(10) - - var cfgPath string - if len(os.Args) == 2 { - cfgPath = os.Args[1] - } else { - fmt.Fprintf(os.Stderr, "usage: %s config.yaml\n", os.Args[0]) - os.Exit(2) - } - - data, err := os.ReadFile(cfgPath) - if err != nil { - fmt.Fprintf(os.Stderr, "read config %q: %v\n", cfgPath, err) - os.Exit(1) - } - - var cfg BenchConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - fmt.Fprintf(os.Stderr, "parse yaml: %v\n", err) - os.Exit(1) - } - - if err := Run(&cfg); err != nil { - fmt.Fprintf(os.Stderr, "FAILED: %v\n", err) - os.Exit(1) - } else { - fmt.Println("OK!") - } -} diff --git a/modules/balancer/bench/go/config.go b/modules/balancer/bench/go/config.go deleted file mode 100644 index c8c25715f..000000000 --- a/modules/balancer/bench/go/config.go +++ /dev/null @@ -1,37 +0,0 @@ -package main - -// Benchmark configuration structure defining test parameters including virtual service counts, -// real server distribution, protocol probabilities, packet generation settings, and worker configuration. - -type BenchConfig struct { - GreProb float32 `yaml:"gre_prob"` - FixMSSProb float32 `yaml:"fix_mss_prob"` - PureL3Prob float32 `yaml:"pure_l3_prob"` - OpsProb float32 `yaml:"ops_prob"` - - RoundRobinProb float32 `yaml:"round_robin_prob"` - - TCPIPv4VS int `yaml:"tcp_ipv4_vs"` - TCPIPv6VS int `yaml:"tcp_ipv6_vs"` - UDPIPv4VS int `yaml:"udp_ipv4_vs"` - UDPIPv6Vs int `yaml:"udp_ipv6_vs"` - - Ipv4Reals int `yaml:"ipv4_reals"` - Ipv6Reals int `yaml:"ipv6_reals"` - - AllowedSrcPerVs int `yaml:"allowed_src_per_vs"` - - NewSessionProb float32 `yaml:"new_session_prob"` - - IcmpProb float32 `yaml:"icmp_prob"` - IcmpRedirectProb float32 `yaml:"icmp_redirect_prob"` - - BatchesPerWorker int `yaml:"batches_per_worker"` - PacketsPerBatch int `yaml:"packets_per_batch"` - - mss int `yaml:"mss"` - - Workers int `yaml:"workers"` - - SessionTableCapacity int `yaml:"session_table_capacity"` -} diff --git a/modules/balancer/bench/go/ffi.go b/modules/balancer/bench/go/ffi.go deleted file mode 100644 index 8277f268e..000000000 --- a/modules/balancer/bench/go/ffi.go +++ /dev/null @@ -1,114 +0,0 @@ -package main - -// FFI bindings to C benchmark infrastructure providing memory allocation, -// packet list management, and dataplane packet handling for performance testing. - -/* -#cgo CFLAGS: -I../ -I../../../../ -I../../../../lib -#cgo LDFLAGS: -L../../../../build/modules/balancer/bench -lbalancer_bench -L../../../../build/lib/utils -llib_utils -L../../../../build/mock -lyanet_mock -L../../../../build/lib/dataplane/pipeline -lpipeline -L../../../../build/lib/dataplane/worker -lworker_dp -lnuma -#cgo LDFLAGS: -L../../../../build/modules/balancer/dataplane -lbalancer_dp -#cgo LDFLAGS: -L../../../../build/modules/decap/dataplane -ldecap_dp -#cgo LDFLAGS: -L../../../../build/modules/dscp/dataplane -ldscp_dp -#cgo LDFLAGS: -L../../../../build/modules/acl/dataplane -lacl_dp -#cgo LDFLAGS: -L../../../../build/modules/fwstate/dataplane -lfwstate_dp -#cgo LDFLAGS: -L../../../../build/modules/forward/dataplane -lforward_dp -#cgo LDFLAGS: -L../../../../build/modules/route/dataplane -lroute_dp -#cgo LDFLAGS: -L../../../../build/modules/nat64/dataplane -lnat64_dp -#cgo LDFLAGS: -L../../../../build/modules/pdump/dataplane -lpdump_dp -#cgo LDFLAGS: -L../../../../build/devices/plain/dataplane -lplain_dp -#cgo LDFLAGS: -L../../../../build/devices/vlan/dataplane -lvlan_dp -#include -#include "bench.h" -#include -enum { packet_list_align = _Alignof(struct packet_list) }; -void *bench_alloc_func = bench_alloc; -*/ -import "C" - -import ( - "fmt" - "unsafe" - - yanet "github.com/yanet-platform/yanet2/controlplane/ffi" - dataplane "github.com/yanet-platform/yanet2/lib/utils/go" - - // Import mock to link with modules - _ "github.com/yanet-platform/yanet2/mock/go" -) - -type Bench struct { - bench C.struct_bench -} - -func NewBench(workers, totalMemory, cpMemory int) (*Bench, error) { - b := &Bench{} - config := C.struct_bench_config{} - config.workers = C.size_t(workers) - config.total_memory = C.size_t(totalMemory) - config.cp_memory = C.size_t(cpMemory) - ec := C.bench_init(&b.bench, &config) - if ec != 0 { - str := C.bench_take_error(&b.bench) - return nil, fmt.Errorf( - "failed to initialize bench: %s", - C.GoString(str), - ) - } - return b, nil -} - -func (b *Bench) Free() { - C.bench_free(&b.bench) -} - -func (b *Bench) MakePacketLists(count int) ([]dataplane.PacketList, error) { - if count == 0 { - return nil, nil - } - mem := (C.bench_alloc( - unsafe.Pointer(&b.bench), - C.size_t(C.packet_list_align), - C.size_t(C.sizeof_struct_packet_list)*C.size_t(count), - )) - p := (*dataplane.PacketList)(unsafe.Pointer(mem)) - if p == nil { - return nil, fmt.Errorf("failed to allocate memory") - } - return unsafe.Slice(p, count), nil -} - -func (b *Bench) InitPacketList( - packetList *dataplane.PacketList, - packets ...dataplane.PacketData, -) error { - return dataplane.FillPacketListFromDataWithCustomAlloc( - packetList, - dataplane.NewAlloc( - unsafe.Pointer(&b.bench), - unsafe.Pointer(C.bench_alloc_func), - ), - packets..., - ) -} - -func (b *Bench) HandlePackets( - worker int, - packets []dataplane.PacketList, -) error { - ec := C.bench_handle_packets( - &b.bench, - C.size_t(worker), - (*C.struct_packet_list)(unsafe.Pointer(&packets[0])), - C.size_t(len(packets)), - ) - if ec != 0 { - return fmt.Errorf("failed to run bench: %d", ec) - } - return nil -} - -func (b *Bench) SharedMemory() *yanet.SharedMemory { - return yanet.NewSharedMemoryFromRaw( - unsafe.Pointer(C.bench_shared_memory(&b.bench)), - ) -} diff --git a/modules/balancer/bench/go/gen.go b/modules/balancer/bench/go/gen.go deleted file mode 100644 index 766e46e01..000000000 --- a/modules/balancer/bench/go/gen.go +++ /dev/null @@ -1,301 +0,0 @@ -package main - -// Packet generator for benchmark traffic creation, producing realistic TCP/UDP packets -// with configurable session reuse, client IP distribution, and MSS options for testing -// various load balancing scenarios and performance characteristics. - -import ( - "fmt" - "math/rand/v2" - "net/netip" - - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - dataplane "github.com/yanet-platform/yanet2/lib/utils/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" -) - -type session struct { - clientIP netip.Addr - clientPort uint16 - vsIP netip.Addr - vsPort uint16 - proto balancerpb.TransportProto -} - -type Generator struct { - bench *BenchConfig - generated int - rand *rand.Rand - balancer *balancerpb.BalancerConfig - sessions []session // Per-worker session storage - worker int -} - -func NewGenerator( - bench *BenchConfig, - balancer *balancerpb.BalancerConfig, -) *Generator { - return &Generator{ - bench: bench, - balancer: balancer, - generated: 0, - sessions: []session{}, - rand: rand.New(rand.NewPCG(3, 5)), - worker: -1, - } -} - -// getAllVirtualServices returns all virtual services from the balancer config -func (ctx *Generator) getAllVirtualServices() []*balancerpb.VirtualService { - return ctx.balancer.PacketHandler.Vs -} - -// selectRandomVS selects a random virtual service -func (ctx *Generator) selectRandomVS() *balancerpb.VirtualService { - vsList := ctx.getAllVirtualServices() - if len(vsList) == 0 { - return nil - } - idx := ctx.rand.IntN(len(vsList)) - return vsList[idx] -} - -// generateRandomIPInNetwork generates a random IP address within the given network -func (ctx *Generator) generateRandomIPInNetwork( - netAddr netip.Addr, - prefixLen uint32, -) netip.Addr { - if netAddr.Is4() { - // IPv4 - addrBytes := netAddr.As4() - // Generate random bits for the host part - hostBits := 32 - prefixLen - if hostBits > 0 { - // Generate random host part - for i := range hostBits { - byteIdx := 3 - (i / 8) - bitIdx := i % 8 - if ctx.rand.IntN(2) == 1 { - addrBytes[byteIdx] |= (1 << bitIdx) - } - } - } - return netip.AddrFrom4(addrBytes) - } else { - // IPv6 - addrBytes := netAddr.As16() - // Generate random bits for the host part - hostBits := 128 - prefixLen - if hostBits > 0 { - // Generate random host part - for i := range hostBits { - byteIdx := 15 - (i / 8) - bitIdx := i % 8 - if ctx.rand.IntN(2) == 1 { - addrBytes[byteIdx] |= (1 << bitIdx) - } - } - } - return netip.AddrFrom16(addrBytes) - } -} - -// generateClientIP generates a client IP address for the given virtual service -func (ctx *Generator) generateClientIP( - vs *balancerpb.VirtualService, -) netip.Addr { - vsAddr, ok := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if !ok { - panic("invalid VS address") - } - - // If VS has allowed sources, pick a random one and generate IP within that network - if len(vs.AllowedSrcs) > 0 { - idx := ctx.rand.IntN(len(vs.AllowedSrcs)) - allowedSrc := vs.AllowedSrcs[idx] - netAddr, ok := netip.AddrFromSlice(allowedSrc.Nets[0].Addr.Bytes) - if !ok { - panic("invalid allowed source address") - } - maskAddr, ok := netip.AddrFromSlice(allowedSrc.Nets[0].Mask.Bytes) - if !ok { - panic("invalid allowed source mask") - } - // Calculate prefix length from mask - prefixLen := uint32(0) - if netAddr.Is4() { - maskBytes := maskAddr.As4() - for _, b := range maskBytes { - for i := 7; i >= 0; i-- { - if (b & (1 << i)) != 0 { - prefixLen++ - } else { - break - } - } - } - } else { - maskBytes := maskAddr.As16() - for _, b := range maskBytes { - for i := 7; i >= 0; i-- { - if (b & (1 << i)) != 0 { - prefixLen++ - } else { - break - } - } - } - } - return ctx.generateRandomIPInNetwork(netAddr, prefixLen) - } - - // Otherwise generate random IP in appropriate range - if vsAddr.Is4() { - // Generate IPv4: 10.x.x.x - return netip.AddrFrom4([4]byte{ - 10, - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - }) - } else { - // Generate IPv6: fd00::x:x:x:x - return netip.AddrFrom16([16]byte{ - 0xfd, 0x00, 0, 0, 0, 0, 0, 0, - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - byte(ctx.rand.IntN(256)), - }) - } -} - -// generateClientPort generates a random ephemeral port -func (ctx *Generator) generateClientPort() uint16 { - // Ephemeral port range: 32768-65535 - return uint16(ctx.rand.IntN(65535-32768+1) + 32768) -} - -// createNewSession creates a new session for a random virtual service -func (ctx *Generator) createNewSession() session { - vs := ctx.selectRandomVS() - if vs == nil { - panic("no virtual services configured") - } - - vsIP, ok := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if !ok { - panic("invalid VS address in session creation") - } - - return session{ - clientIP: ctx.generateClientIP(vs), - clientPort: ctx.generateClientPort(), - vsIP: vsIP, - vsPort: uint16(vs.Id.Port), - proto: vs.Id.Proto, - } -} - -// createPacketForSession creates a packet for the given session -func (ctx *Generator) createPacketForSession(s session) dataplane.PacketData { - var packetLayers []gopacket.SerializableLayer - - if s.proto == balancerpb.TransportProto_TCP { - // Use utility function from utils/packet.go - tcp := &layers.TCP{SYN: true} - packetLayers = utils.MakeTCPPacket( - s.clientIP, - s.clientPort, - s.vsIP, - s.vsPort, - tcp, - ) - } else { - // UDP packet - packetLayers = utils.MakeUDPPacket( - s.clientIP, - s.clientPort, - s.vsIP, - s.vsPort, - ) - } - - // Serialize packet - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - if err := gopacket.SerializeLayers(buf, opts, packetLayers...); err != nil { - panic(fmt.Sprintf("failed to serialize packet: %v", err)) - } - - packet := gopacket.NewPacket( - buf.Bytes(), - layers.LayerTypeEthernet, - gopacket.Default, - ) - - // Handle MSS for TCP packets - if s.proto == balancerpb.TransportProto_TCP && ctx.bench.mss > 0 { - modifiedPacket, err := utils.InsertOrUpdateMSS( - packet, - uint16(ctx.bench.mss), - ) - if err != nil { - panic(fmt.Sprintf("failed to insert MSS: %v", err)) - } - packet = *modifiedPacket - } - - return dataplane.PacketData{ - Data: packet.Data(), - TxDeviceId: 0, - RxDeviceId: 0, - } -} - -// generateWorkerPackets generates packets for a worker based on the bench config -func (ctx *Generator) generateWorkerPackets( - worker int, - count int, -) []dataplane.PacketData { - packets := make([]dataplane.PacketData, 0, ctx.bench.PacketsPerBatch) - - if worker != ctx.worker { - // new worker - ctx.sessions = []session{} - ctx.worker = worker - } - - for i := 0; i < count; i++ { - var s session - - // Decide: new session or reuse? - if ctx.rand.Float32() < ctx.bench.NewSessionProb || - len(ctx.sessions) == 0 { - // Create new session - s = ctx.createNewSession() - ctx.sessions = append(ctx.sessions, s) - } else { - // Reuse random existing session - idx := ctx.rand.IntN(len(ctx.sessions)) - s = ctx.sessions[idx] - } - - // Generate packet for this session - packetData := ctx.createPacketForSession(s) - packets = append(packets, packetData) - - ctx.generated++ - } - - return packets -} diff --git a/modules/balancer/bench/meson.build b/modules/balancer/bench/meson.build deleted file mode 100644 index f7c2696fb..000000000 --- a/modules/balancer/bench/meson.build +++ /dev/null @@ -1,19 +0,0 @@ -dependencies = [ - lib_common_dep, - lib_lib_utils_dep, -] - -sources = files( - 'alloc.c', - 'bench.c', -) - -static_library( - 'balancer_bench', - sources, - c_args: yanet_c_args, - link_args: yanet_link_args, - dependencies: dependencies, - include_directories: [yanet_rootdir], - install: false, -) \ No newline at end of file diff --git a/modules/balancer/cli/Makefile b/modules/balancer/cli/Makefile index 4291d0cc5..77009c2e9 100644 --- a/modules/balancer/cli/Makefile +++ b/modules/balancer/cli/Makefile @@ -5,22 +5,15 @@ BINDIR ?= $(PREFIX)/bin ROOT_DIR := ../../.. TARGET_DIR := $(ROOT_DIR)/target RELEASE_DIR := $(TARGET_DIR)/release -BUILD_DIR := $(ROOT_DIR)/build/modules/balancer/cli -MODULES := balancer -BINARIES := $(addprefix yanet-,$(MODULES)) -TARGETS := $(addprefix build/,$(BINARIES)) - -build: $(TARGETS) - -build/%: - $(CARGO) build --release --package $* +build: + $(CARGO) build --release --package yanet-cli-balancer install: install -d $(DESTDIR)$(BINDIR) - install -m 755 $(addprefix $(RELEASE_DIR)/,$(BINARIES)) $(DESTDIR)$(BINDIR)/ + install -m 755 $(RELEASE_DIR)/yanet-cli-balancer $(DESTDIR)$(BINDIR)/ clean: $(CARGO) clean || true -.PHONY: build install clean \ No newline at end of file +.PHONY: build install clean diff --git a/modules/balancer/cli/README.md b/modules/balancer/cli/README.md deleted file mode 100644 index eb733141b..000000000 --- a/modules/balancer/cli/README.md +++ /dev/null @@ -1,499 +0,0 @@ -# yanet-balancer CLI - -Command-line interface for managing the YANET balancer module. - -## Installation - -```bash -cd modules/balancer/cli -make build -sudo make install -``` - -The binary will be installed as `yanet-balancer` in `/usr/local/bin` by default. - -## Usage - -```bash -yanet-balancer [OPTIONS] -``` - -### Global Options - -- `--endpoint ` - gRPC endpoint (default: `grpc://[::1]:8080`) -- `-v, -vv, -vvv` - Increase verbosity (info, debug, trace) -- `--help` - Show help information -- `--version` - Show version information - -### Commands - -#### 1. update - -Update balancer configuration from a YAML file. - -```bash -yanet-balancer update \ - --name \ - --instance \ - --config -``` - -**Example:** -```bash -yanet-balancer update \ - --name my-balancer \ - --instance 0 \ - --config example-config.yaml -``` - -See [`example-config.yaml`](example-config.yaml) for configuration file format. - -#### 2. reals enable - -Enable a real server (buffered). - -```bash -yanet-balancer reals enable \ - --name \ - --instance \ - --virtual-ip \ - --proto \ - --virtual-port \ - --real-ip \ - [--weight ] -``` - -**Example:** -```bash -yanet-balancer reals enable \ - --name my-balancer \ - --virtual-ip 192.0.2.1 \ - --proto tcp \ - --virtual-port 80 \ - --real-ip 10.1.1.1 \ - --weight 200 -``` - -#### 3. reals disable - -Disable a real server (buffered). - -```bash -yanet-balancer reals disable \ - --name \ - --instance \ - --virtual-ip \ - --proto \ - --virtual-port \ - --real-ip -``` - -**Example:** -```bash -yanet-balancer reals disable \ - --name my-balancer \ - --virtual-ip 192.0.2.1 \ - --proto tcp \ - --virtual-port 80 \ - --real-ip 10.1.1.2 -``` - -#### 4. reals flush - -Flush buffered real server updates. - -```bash -yanet-balancer reals flush \ - --name \ - --instance -``` - -**Example:** -```bash -yanet-balancer reals flush --name my-balancer --instance 0 -``` - -#### 5. config - -Show balancer configuration. - -```bash -yanet-balancer config \ - --name \ - --instance \ - [--format ] -``` - -**Examples:** -```bash -# Show as table (default) -yanet-balancer config --name my-balancer - -# Show as JSON -yanet-balancer config --name my-balancer --format json - -# Show as tree -yanet-balancer config --name my-balancer --format tree -``` - -#### 6. list - -List all balancer configurations. - -```bash -yanet-balancer list [--format ] -``` - -**Example:** -```bash -yanet-balancer list --format table -``` - -#### 7. stats - -Show configuration statistics. - -```bash -yanet-balancer stats \ - --name \ - --instance \ - --device \ - --pipeline \ - --function \ - --chain \ - [--format ] -``` - -**Example:** -```bash -yanet-balancer stats \ - --name my-balancer \ - --instance 0 \ - --device eth0 \ - --pipeline main \ - --function balancer \ - --chain default \ - --format table -``` - -#### 8. state - -Show balancer state information (active sessions, VS info, real info). - -```bash -yanet-balancer state \ - --name \ - --instance \ - [--format ] -``` - -**Example:** -```bash -yanet-balancer state --name my-balancer --format table -``` - -#### 9. sessions - -Show active sessions information. - -```bash -yanet-balancer sessions \ - --name \ - --instance \ - [--format ] -``` - -**Example:** -```bash -yanet-balancer sessions --name my-balancer --format table -``` - -## Output Formats - -All display commands support three output formats: - -- **table** (default) - Human-readable formatted tables with colored output -- **json** - Machine-readable JSON format (full gRPC response) -- **tree** - Hierarchical tree structure (full gRPC response) - -## Configuration File Format - -The configuration file is in YAML format with two main sections: - -### packet_handler - -Packet processing configuration containing: - -- `vs` - List of virtual services - - `addr` - Virtual IP address (IPv4 or IPv6) - - `port` - Port number (0 for pure_l3 mode) - - `proto` - Protocol: `TCP`, `tcp`, `UDP`, or `udp` - - `scheduler` - Scheduler algorithm: - - `SOURCE_HASH`, `source_hash`, `SH`, `sh` - Source hash scheduling - - `ROUND_ROBIN`, `round_robin`, `RR`, `rr` - Round-robin scheduling - - `flags` - Service flags: - - `gre` - GRE encapsulation - - `fix_mss` - TCP MSS fixing - - `ops` - One Packet Scheduler (no session tracking) - - `pure_l3` - Match all ports (port must be 0) - - `wlc` - Enable dynamic weight adjustment - - `allowed_srcs` - Source access control (see [Source Port Filtering](#source-port-filtering)) - - `reals` - List of real servers: - - `ip` - Real server IP address - - `port` - Port (reserved for future use, currently must be 0) - - `weight` - Server weight for scheduling - - `src_addr` - Source address for forwarding - - `src_mask` - Source mask - - `peers` - List of peer balancer IPs for session synchronization - -- `source_address_v4` - IPv4 source address for encapsulation -- `source_address_v6` - IPv6 source address for encapsulation -- `decap_addresses` - List of decapsulation addresses for tunnel unwrapping -- `sessions_timeouts` - Session timeout configuration (in seconds): - - `tcp_syn_ack` - TCP SYN-ACK timeout - - `tcp_syn` - TCP SYN timeout - - `tcp_fin` - TCP FIN timeout - - `tcp` - Established TCP connection timeout - - `udp` - UDP session timeout - - `default` - Default timeout for other protocols -- `wlc` - WLC scheduler configuration - -### state - -State management configuration: - -- `session_table` - Session table configuration: - - `capacity` - Maximum concurrent sessions - - `max_load_factor` - Trigger resize at this load factor (0.0-1.0) -- `wlc` - WLC (Weighted Least Connections) configuration: - - `power` - Adjustment aggressiveness - - `max_weight` - Maximum weight after adjustment -- `refresh_period_ms` - Periodic refresh interval in milliseconds (0 to disable) - -### Source Port Filtering - -The `allowed_srcs` field provides fine-grained access control based on source IP addresses and optionally source ports. This feature is useful for: - -- Restricting access to specific client networks -- Limiting connections to known source port ranges -- Implementing security policies based on ephemeral port usage -- Controlling access from specific applications or services - -#### Format Options - -**Simple Format (Backward Compatible)** - -A simple string specifying the network in CIDR or netmask notation. All source ports are allowed. - -```yaml -allowed_srcs: - - "10.0.0.0/8" # CIDR notation - - "172.16.0.0/255.240.0.0" # Netmask notation - - "192.168.1.0/24" # Single /24 network - - "203.0.113.42/32" # Single host - - "2001:db8::/32" # IPv6 network -``` - -**Structured Format with Port Filtering** - -An object with `network` and optional `ports` fields for fine-grained control: - -```yaml -allowed_srcs: - # Network with single source port - - network: "10.0.0.0/8" - ports: "443" - - # Network with port range - - network: "172.16.0.0/12" - ports: "1024-65535" - - # Network with multiple ports and ranges - - network: "192.168.0.0/16" - ports: "80,443,8000-9000,3000-3010" - - # Netmask notation with ports - - network: "198.51.100.0/255.255.255.0" - ports: "22,3389,5900-5910" -``` - -**Mixed Format** - -You can mix simple and structured formats in the same `allowed_srcs` list: - -```yaml -allowed_srcs: - # Simple format - all ports allowed - - "10.0.0.0/8" - - # Structured with ports - - network: "172.16.0.0/12" - ports: "443,8443" - - # Another simple entry - - "192.168.0.0/16" - - # Structured with port range - - network: "203.0.113.0/24" - ports: "1024-65535" -``` - -#### Port Specification Format - -The `ports` field accepts a comma-separated list of ports and port ranges: - -- **Single port**: `"80"`, `"443"`, `"8080"` -- **Port range**: `"1024-65535"`, `"8000-9000"` -- **Multiple entries**: `"80,443,8000-9000,3000-3010"` - -**Rules:** -- Port numbers must be in range 1-65535 -- In ranges, the `from` port must be ≤ `to` port -- Whitespace around commas and hyphens is ignored -- Empty or missing `ports` field means all ports are allowed - -#### Access Control Semantics - -- **Empty list** (`allowed_srcs: []`) - **DENY ALL** traffic (useful for maintenance mode) -- **Allow all IPv4**: `["0.0.0.0/0"]` -- **Allow all IPv6**: `["::/0"]` -- **Multiple entries** - Traffic is allowed if it matches ANY entry (OR logic) -- **Port filtering** - When ports are specified, BOTH network AND port must match - -#### Examples - -**Example 1: Restrict to internal networks only** -```yaml -allowed_srcs: - - "10.0.0.0/8" - - "172.16.0.0/12" - - "192.168.0.0/16" -``` - -**Example 2: Allow specific networks with ephemeral ports only** -```yaml -allowed_srcs: - - network: "10.0.0.0/8" - ports: "1024-65535" # Only ephemeral ports - - network: "172.16.0.0/12" - ports: "32768-65535" # Linux default ephemeral range -``` - -**Example 3: Mixed access - some networks unrestricted, others port-limited** -```yaml -allowed_srcs: - # Trusted network - all ports - - "10.0.0.0/8" - - # DMZ network - only HTTPS source ports - - network: "172.16.0.0/12" - ports: "443,8443" - - # External network - only high ports - - network: "203.0.113.0/24" - ports: "1024-65535" -``` - -**Example 4: Service-specific restrictions** -```yaml -allowed_srcs: - # Allow SSH clients (typically use high ports) - - network: "192.168.1.0/24" - ports: "1024-65535" - - # Allow RDP clients - - network: "192.168.2.0/24" - ports: "3389" - - # Allow VNC clients - - network: "192.168.3.0/24" - ports: "5900-5910" -``` - -See [`example-config.yaml`](example-config.yaml) for a complete example with all features. - -## Workflow Example - -```bash -# 1. Update configuration -yanet-balancer update \ - --name my-balancer \ - --config config.yaml - -# 2. Check configuration -yanet-balancer config --name my-balancer - -# 3. Disable a real server (buffered) -yanet-balancer reals disable \ - --name my-balancer \ - --virtual-ip 192.0.2.1 \ - --proto tcp \ - --virtual-port 80 \ - --real-ip 10.1.1.1 - -# 4. Enable another real server (buffered) -yanet-balancer reals enable \ - --name my-balancer \ - --virtual-ip 192.0.2.1 \ - --proto tcp \ - --virtual-port 80 \ - --real-ip 10.1.1.2 - -# 5. Apply all buffered changes -yanet-balancer reals flush --name my-balancer - -# 6. Check state -yanet-balancer state --name my-balancer - -# 7. View statistics -yanet-balancer stats \ - --name my-balancer \ - --device eth0 \ - --pipeline main \ - --function balancer \ - --chain default - -# 8. View active sessions -yanet-balancer sessions --name my-balancer -``` - -## Development - -### Building - -```bash -cargo build --release -``` - -### Testing - -```bash -cargo test -``` - -### Linting - -```bash -cargo clippy -``` - -### Example Outputs - -To see example outputs for all commands without connecting to a gRPC server, run: - -```bash -cargo run --example show_outputs -``` - -This will display sample outputs in all three formats (table, tree, JSON) for: -- config -- list -- state -- stats -- sessions - -The example creates mock data structures and uses the actual output formatting functions to demonstrate what the CLI output looks like. - -## License - -See the main YANET project license. \ No newline at end of file diff --git a/modules/balancer/cli/build.rs b/modules/balancer/cli/build.rs index bac4d88e2..a478fb718 100644 --- a/modules/balancer/cli/build.rs +++ b/modules/balancer/cli/build.rs @@ -1,11 +1,12 @@ use core::error::Error; pub fn main() -> Result<(), Box> { - println!("cargo:rerun-if-changed=../agent/balancerpb/balancer.proto"); - println!("cargo:rerun-if-changed=../agent/balancerpb/info.proto"); - println!("cargo:rerun-if-changed=../agent/balancerpb/module.proto"); - println!("cargo:rerun-if-changed=../agent/balancerpb/stats.proto"); - println!("cargo:rerun-if-changed=../agent/balancerpb/graph.proto"); + println!("cargo:rerun-if-changed=../controlplane/balancerpb/balancer.proto"); + println!("cargo:rerun-if-changed=../controlplane/balancerpb/common.proto"); + println!("cargo:rerun-if-changed=../controlplane/balancerpb/filter.proto"); + println!("cargo:rerun-if-changed=../controlplane/balancerpb/state.proto"); + println!("cargo:rerun-if-changed=../controlplane/balancerpb/memory.proto"); + println!("cargo:rerun-if-changed=../../../common/filterpb/filter.proto"); println!("cargo:rerun-if-changed=../../../common/commonpb/metric.proto"); tonic_build::configure() @@ -20,11 +21,8 @@ pub fn main() -> Result<(), Box> { .field_attribute("timeout", "#[serde(skip)]") .compile_protos( &[ - "modules/balancer/agent/balancerpb/balancer.proto", - "modules/balancer/agent/balancerpb/info.proto", - "modules/balancer/agent/balancerpb/module.proto", - "modules/balancer/agent/balancerpb/stats.proto", - "modules/balancer/agent/balancerpb/graph.proto", + "modules/balancer/controlplane/balancerpb/balancer.proto", + "common/filterpb/filter.proto", "common/commonpb/metric.proto", ], &["../../.."], diff --git a/modules/balancer/cli/example-config.yaml b/modules/balancer/cli/example-config.yaml deleted file mode 100644 index 519ed5b51..000000000 --- a/modules/balancer/cli/example-config.yaml +++ /dev/null @@ -1,256 +0,0 @@ -# Example Balancer Configuration -# This demonstrates all features of the new YAML format - -# Packet processing configuration -packet_handler: - # Virtual services list - vs: - # HTTP service with SOURCE_HASH scheduler - # Demonstrates: flexible scheduler names, CIDR notation, multiple reals - - addr: "192.0.2.1" - port: 80 - proto: TCP # Accepts: TCP, tcp - scheduler: SOURCE_HASH # Accepts: SOURCE_HASH, source_hash, SH, sh - flags: - gre: false - fix_mss: true - ops: false - pure_l3: false - wlc: false # Dynamic weight adjustment disabled - # CIDR notation for allowed sources - allowed_srcs: - - "10.0.0.0/8" - - "172.16.0.0/12" - - "192.168.0.0/16" - reals: - - ip: "10.1.1.1" - port: 0 # Reserved for future use - weight: 100 - src_addr: "192.0.2.1" - src_mask: "255.255.255.255" - - ip: "10.1.1.2" - port: 0 - weight: 50 - src_addr: "192.0.2.1" - src_mask: "255.255.255.255" - peers: - - "192.0.2.10" - - "192.0.2.11" - - # HTTPS service with ROUND_ROBIN and WLC - # Demonstrates: lowercase proto, WLC flag, allow all sources - - addr: "192.0.2.2" - port: 443 - proto: tcp # Lowercase also accepted - scheduler: round_robin # Accepts: ROUND_ROBIN, round_robin, RR, rr - flags: - gre: false - fix_mss: true - ops: false - pure_l3: false - wlc: true # Enable dynamic weight adjustment - # Allow all IPv4 sources - allowed_srcs: - - "0.0.0.0/0" - reals: - - ip: "10.2.1.1" - port: 0 - weight: 100 - src_addr: "192.0.2.2" - src_mask: "255.255.255.255" - - ip: "10.2.1.2" - port: 0 - weight: 100 - src_addr: "192.0.2.2" - src_mask: "255.255.255.255" - - ip: "10.2.1.3" - port: 0 - weight: 50 - src_addr: "192.0.2.2" - src_mask: "255.255.255.255" - peers: - - "192.0.2.10" - - # API service with source port filtering - # Demonstrates: structured allowed_srcs with port restrictions - - addr: "192.0.2.3" - port: 8080 - proto: TCP - scheduler: SOURCE_HASH - flags: - gre: false - fix_mss: true - ops: false - pure_l3: false - wlc: false - # Source port filtering - only allow specific source ports - # Useful for restricting access to known client port ranges - allowed_srcs: - # Simple format (backward compatible) - all ports allowed - - "10.0.0.0/8" - # Structured format with single port and tag - - network: "172.16.0.0/12" - ports: "443" - tag: "123" - # Structured format with port ranges - - network: "192.168.0.0/16" - ports: "1024-65535" - # Multiple ports and ranges - - network: "203.0.113.0/24" - ports: "80,443,8000-9000" - # Netmask notation also supported - - network: "198.51.100.0/255.255.255.0" - ports: "22,3389,5900-5910" - reals: - - ip: "10.3.1.1" - port: 0 - weight: 100 - src_addr: "192.0.2.3" - src_mask: "255.255.255.255" - - ip: "10.3.1.2" - port: 0 - weight: 100 - src_addr: "192.0.2.3" - src_mask: "255.255.255.255" - peers: [] - - # UDP DNS service with OPS mode - # Demonstrates: short scheduler name, empty allowed_srcs (allow none) - - addr: "192.0.2.4" - port: 53 - proto: UDP - scheduler: SH # Short form - flags: - gre: false - fix_mss: false - ops: true # One Packet Scheduler - no session tracking - pure_l3: false - wlc: false - # Empty = allow NONE (reject all traffic) - # Useful for maintenance mode or testing - allowed_srcs: [] - reals: - - ip: "10.4.1.1" - port: 0 - weight: 1 - src_addr: "192.0.2.4" - src_mask: "255.255.255.255" - - ip: "10.4.1.2" - port: 0 - weight: 1 - src_addr: "192.0.2.4" - src_mask: "255.255.255.255" - peers: [] - - # Pure L3 service (matches all ports) - # Demonstrates: pure_l3 mode, port must be 0 - - addr: "192.0.2.5" - port: 0 # MUST be 0 for pure_l3 mode - proto: tcp - scheduler: sh # Short form lowercase - flags: - gre: false - fix_mss: false - ops: false - pure_l3: true # Matches ALL traffic to this IP regardless of port - wlc: false - allowed_srcs: - - "10.0.0.0/8" - reals: - - ip: "10.3.1.1" - port: 0 - weight: 1 - src_addr: "192.0.2.5" - src_mask: "255.255.255.255" - peers: [] - - # Service with arbitrary netmask (non-contiguous) - # Demonstrates: netmask that doesn't correspond to a CIDR prefix - - addr: "192.0.2.6" - port: 3306 - proto: TCP - scheduler: SOURCE_HASH - flags: - gre: false - fix_mss: true - ops: false - pure_l3: false - wlc: false - # Arbitrary netmasks (not representable as CIDR prefix) - # These will be stored and displayed as netmask notation - allowed_srcs: - # Non-contiguous mask: 255.0.255.0 (matches 10.x.y.0 where x can be anything) - - "10.0.0.0/255.0.255.0" - # Another non-contiguous mask with port filtering - - network: "172.16.0.0/255.240.15.0" - ports: "3306,33060" - reals: - - ip: "10.6.1.1" - port: 0 - weight: 100 - src_addr: "192.0.2.6" - src_mask: "255.255.255.255" - - ip: "10.6.1.2" - port: 0 - weight: 100 - src_addr: "192.0.2.6" - src_mask: "255.255.255.255" - peers: [] - - # IPv6 service - # Demonstrates: IPv6 addresses, single host allowed source - - addr: "2001:db8::1" - port: 443 - proto: TCP - scheduler: RR - flags: - gre: false - fix_mss: true - ops: false - pure_l3: false - wlc: false - # Allow single IPv6 host - allowed_srcs: - - "2001:db8:1::42/128" - reals: - - ip: "2001:db8:100::1" - port: 0 - weight: 100 - src_addr: "2001:db8::1" - src_mask: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff" - peers: [] - - # Source addresses for encapsulation - source_address_v4: "192.0.2.1" - source_address_v6: "2001:db8::1" - - # Addresses for decapsulation (return traffic) - decap_addresses: - - "192.0.2.1" - - "2001:db8::1" - - # Session timeouts (in seconds) - sessions_timeouts: - tcp_syn_ack: 10 - tcp_syn: 10 - tcp_fin: 10 - tcp: 60 - udp: 30 - default: 60 - -# State management configuration -state: - # Session table configuration - session_table: - capacity: 1000000 # Maximum concurrent sessions - max_load_factor: 0.75 # Trigger resize at 75% capacity - - # WLC (Weighted Least Connections) configuration - # Used when vs.flags.wlc = true - wlc: - power: 10 # Adjustment aggressiveness - max_weight: 1000 # Maximum weight after adjustment - - # Periodic refresh for session counting and weight adjustment - # Set to 0 to disable - refresh_period_ms: 5000 # 5 seconds diff --git a/modules/balancer/cli/src/cmd.rs b/modules/balancer/cli/src/cmd.rs deleted file mode 100644 index 55268bad8..000000000 --- a/modules/balancer/cli/src/cmd.rs +++ /dev/null @@ -1,652 +0,0 @@ -//! CLI command definitions - -use clap::{ArgAction, Parser}; -use ync::client::ConnectionArgs; - -use crate::{output, rpc::balancerpb}; - -//////////////////////////////////////////////////////////////////////////////// -// Main Command -//////////////////////////////////////////////////////////////////////////////// - -/// Balancer module CLI -#[derive(Debug, Clone, Parser)] -#[command(version, about)] -#[command(flatten_help = true)] -pub struct Cmd { - #[clap(subcommand)] - pub mode: Mode, - - #[command(flatten)] - pub connection: ConnectionArgs, - - /// Log verbosity level - #[clap(short, action = ArgAction::Count, global = true)] - pub verbosity: u8, -} - -//////////////////////////////////////////////////////////////////////////////// -// Output Format -//////////////////////////////////////////////////////////////////////////////// - -/// Helper struct for output format flags -#[derive(Debug, Clone, Parser)] -pub struct FormatFlags { - /// Output in JSON format - #[clap(long, short = 'j', conflicts_with_all = ["tree", "table"])] - pub json: bool, - - /// Output in tree format - #[clap(long, short = 't', conflicts_with_all = ["json", "table"])] - pub tree: bool, - - /// Output in table format (default) - #[clap(long, conflicts_with_all = ["json", "tree"])] - pub table: bool, -} - -impl FormatFlags { - /// Convert flags to OutputFormat, defaulting to Table if none specified - pub fn to_format(&self) -> crate::output::OutputFormat { - if self.json { - output::OutputFormat::Json - } else if self.tree { - output::OutputFormat::Tree - } else { - // Default to table if no format specified or if --table is explicitly set - output::OutputFormat::Table - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Inspect Format Flags -//////////////////////////////////////////////////////////////////////////////// - -/// Helper struct for inspect output format flags -#[derive(Debug, Clone, Parser)] -pub struct InspectFormatFlags { - /// Output in JSON format (compact, not pretty) - #[clap(long, short = 'j', conflicts_with_all = ["normal", "detail"])] - pub json: bool, - - /// Output in normal format (default, excludes per-VS list) - #[clap(long, conflicts_with_all = ["json", "detail"])] - pub normal: bool, - - /// Output in detailed format (includes per-VS breakdown) - #[clap(long, conflicts_with_all = ["json", "normal"])] - pub detail: bool, -} - -impl InspectFormatFlags { - /// Convert flags to InspectOutputFormat, defaulting to Normal if none - /// specified - pub fn to_format(&self) -> crate::output::InspectOutputFormat { - if self.json { - output::InspectOutputFormat::Json - } else if self.detail { - output::InspectOutputFormat::Detail - } else { - // Default to normal if no format specified or if --normal is explicitly set - output::InspectOutputFormat::Normal - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Commands -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub enum Mode { - /// Update balancer configuration from YAML file - Update(UpdateCmd), - /// Manage real servers - Reals(RealsCmd), - /// Manage virtual services - Vs(VsCmd), - /// Show balancer configuration - Config(ConfigCmd), - /// List all balancer configurations - List(ListCmd), - /// Show configuration statistics - Stats(StatsCmd), - /// Show information about sessions - Info(InfoCmd), - /// Show active sessions - Sessions(SessionsCmd), - /// Show balancing graph with state and weights of reals - Graph(GraphCmd), - /// Show memory usage inspection - Inspect(InspectCmd), - Metrics(MetricsCmd), -} - -//////////////////////////////////////////////////////////////////////////////// -// Update Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct UpdateCmd { - /// Name of the module config - #[arg(long, short = 'n')] - pub name: String, - - /// Path to the YAML configuration file - #[arg(long, short = 'c')] - pub config: String, - - #[clap(flatten)] - pub format: FormatFlags, -} - -//////////////////////////////////////////////////////////////////////////////// -// Reals Commands -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct RealsCmd { - #[clap(subcommand)] - pub mode: RealsMode, -} - -#[derive(Debug, Clone, Parser)] -pub enum RealsMode { - /// Enable a real server (buffered) - Enable(EnableRealCmd), - /// Disable a real server (buffered) - Disable(DisableRealCmd), - /// Flush buffered real updates - Flush(FlushRealUpdatesCmd), -} - -#[derive(Debug, Clone, Parser)] -pub struct EnableRealCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - /// Virtual service in format "ip:port/proto", "[ipv6]:port/proto", or - /// "ipv6:port/proto" (e.g., "192.168.1.1:80/tcp", - /// "[2001:db8::1]:443/tcp", or "2001:db8::1:443/tcp") - #[arg(long)] - pub vs: String, - - /// List of real server IPs to enable - #[arg(long, required = true, num_args = 1..)] - pub reals: Vec, - - /// Optional new weight for the real servers - #[arg(long)] - pub weight: Option, - - /// Flush buffered updates immediately after enabling - #[arg(long, default_value_t = false)] - pub flush: bool, -} - -/// Helper function to parse VS identifier from string -/// Supports three formats: -/// - IPv4: "192.168.1.1:80/tcp" -/// - IPv6 with brackets: "[2001:db8::1]:443/tcp" -/// - IPv6 without brackets: "2001:db8::1:443/tcp" -fn parse_vs_identifier(vs_str: &str) -> Result<(std::net::IpAddr, u16, balancerpb::TransportProto), String> { - // Split by '/' to separate address:port from protocol - let vs_parts: Vec<&str> = vs_str.split('/').collect(); - if vs_parts.len() != 2 { - return Err(format!( - "invalid --vs format: '{}'. Expected format: 'ip:port/proto', '[ipv6]:port/proto', or 'ipv6:port/proto'", - vs_str - )); - } - - let addr_port = vs_parts[0]; - let proto_str = vs_parts[1]; - - // Parse protocol (case-insensitive) - let proto = match proto_str.to_uppercase().as_str() { - "TCP" => balancerpb::TransportProto::Tcp, - "UDP" => balancerpb::TransportProto::Udp, - _ => { - return Err(format!( - "invalid proto: '{}'. Expected 'tcp' or 'udp' (case-insensitive)", - proto_str - )); - } - }; - - // Parse IP and port, handling IPv4, IPv6 with brackets, and IPv6 without - // brackets - let (ip_str, port_str) = if addr_port.starts_with('[') { - // IPv6 bracket notation: [ipv6]:port - let bracket_end = addr_port.find(']').ok_or_else(|| { - format!( - "invalid IPv6 bracket notation: '{}'. Expected format: '[ipv6]:port'", - addr_port - ) - })?; - - let ip_part = &addr_port[1..bracket_end]; // Extract IP without brackets - let remaining = &addr_port[bracket_end + 1..]; - - if !remaining.starts_with(':') { - return Err(format!( - "invalid IPv6 bracket notation: '{}'. Expected ':' after ']'", - addr_port - )); - } - - let port_part = &remaining[1..]; // Skip the ':' - (ip_part, port_part) - } else { - // IPv4 or IPv6 without brackets - // Split from the right to get the last component (port) - let addr_port_parts: Vec<&str> = addr_port.rsplitn(2, ':').collect(); - if addr_port_parts.len() != 2 { - return Err(format!( - "invalid address:port format: '{}'. Expected format: 'ip:port' or '[ipv6]:port'", - addr_port - )); - } - - let port_part = addr_port_parts[0]; - let ip_part = addr_port_parts[1]; - - // Try to parse the port to validate it's actually a port number - // This helps distinguish IPv6 addresses from port numbers - if port_part.parse::().is_err() { - return Err(format!( - "invalid port in '{}'. Last component after ':' must be a valid port number", - addr_port - )); - } - - (ip_part, port_part) - }; - - let virtual_port: u16 = port_str - .parse() - .map_err(|e| format!("invalid port '{}': {}", port_str, e))?; - - let virtual_ip: std::net::IpAddr = ip_str.parse().map_err(|e| format!("invalid IP '{}': {}", ip_str, e))?; - - Ok((virtual_ip, virtual_port, proto)) -} - -impl TryFrom for balancerpb::UpdateRealsRequest { - type Error = String; - - fn try_from(cmd: EnableRealCmd) -> Result { - // Parse the --vs option - let (virtual_ip, virtual_port, proto) = parse_vs_identifier(&cmd.vs)?; - - // Create updates for all real IPs - let mut updates = Vec::new(); - for real_ip_str in &cmd.reals { - let real_ip: std::net::IpAddr = real_ip_str - .parse() - .map_err(|e| format!("invalid real IP '{}': {}", real_ip_str, e))?; - - let real_id = balancerpb::RealIdentifier { - vs: Some(balancerpb::VsIdentifier { - addr: Some(balancerpb::Addr { - bytes: match virtual_ip { - std::net::IpAddr::V4(ip) => ip.octets().to_vec(), - std::net::IpAddr::V6(ip) => ip.octets().to_vec(), - }, - }), - port: virtual_port as u32, - proto: proto as i32, - }), - real: Some(balancerpb::RelativeRealIdentifier { - ip: Some(balancerpb::Addr { - bytes: match real_ip { - std::net::IpAddr::V4(ip) => ip.octets().to_vec(), - std::net::IpAddr::V6(ip) => ip.octets().to_vec(), - }, - }), - port: 0, - }), - }; - - updates.push(balancerpb::RealUpdate { - real_id: Some(real_id), - enable: Some(true), - weight: cmd.weight, - }); - } - - Ok(Self { - name: cmd.name, - updates, - buffer: true, // Always buffer - }) - } -} - -#[derive(Debug, Clone, Parser)] -pub struct DisableRealCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - /// Virtual service in format "ip:port/proto", "[ipv6]:port/proto", or - /// "ipv6:port/proto" (e.g., "192.168.1.1:80/tcp", - /// "[2001:db8::1]:443/tcp", or "2001:db8::1:443/tcp") - #[arg(long)] - pub vs: String, - - /// List of real server IPs to disable - #[arg(long, required = true, num_args = 1..)] - pub reals: Vec, - - /// Flush buffered updates immediately after disabling - #[arg(long, default_value_t = false)] - pub flush: bool, -} - -impl TryFrom for balancerpb::UpdateRealsRequest { - type Error = String; - - fn try_from(cmd: DisableRealCmd) -> Result { - // Parse the --vs option - let (virtual_ip, virtual_port, proto) = parse_vs_identifier(&cmd.vs)?; - - // Create updates for all real IPs - let mut updates = Vec::new(); - for real_ip_str in &cmd.reals { - let real_ip: std::net::IpAddr = real_ip_str - .parse() - .map_err(|e| format!("invalid real IP '{}': {}", real_ip_str, e))?; - - let real_id = balancerpb::RealIdentifier { - vs: Some(balancerpb::VsIdentifier { - addr: Some(balancerpb::Addr { - bytes: match virtual_ip { - std::net::IpAddr::V4(ip) => ip.octets().to_vec(), - std::net::IpAddr::V6(ip) => ip.octets().to_vec(), - }, - }), - port: virtual_port as u32, - proto: proto as i32, - }), - real: Some(balancerpb::RelativeRealIdentifier { - ip: Some(balancerpb::Addr { - bytes: match real_ip { - std::net::IpAddr::V4(ip) => ip.octets().to_vec(), - std::net::IpAddr::V6(ip) => ip.octets().to_vec(), - }, - }), - port: 0, - }), - }; - - updates.push(balancerpb::RealUpdate { - real_id: Some(real_id), - enable: Some(false), - weight: None, - }); - } - - Ok(Self { - name: cmd.name, - updates, - buffer: true, // Always buffer - }) - } -} - -#[derive(Debug, Clone, Parser)] -pub struct FlushRealUpdatesCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, -} - -impl From for balancerpb::FlushRealUpdatesRequest { - fn from(cmd: FlushRealUpdatesCmd) -> Self { - Self { name: cmd.name } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Config Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct ConfigCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl From<&ConfigCmd> for balancerpb::ShowConfigRequest { - fn from(cmd: &ConfigCmd) -> Self { - Self { name: cmd.name.clone() } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// List Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct ListCmd { - #[clap(flatten)] - pub format: FormatFlags, -} - -//////////////////////////////////////////////////////////////////////////////// -// Stats Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct StatsCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - /// Device name (optional) - #[arg(long, short = 'd')] - pub device: Option, - - /// Pipeline name (optional) - #[arg(long, short = 'p')] - pub pipeline: Option, - - /// Function name (optional) - #[arg(long, short = 'f')] - pub function: Option, - - /// Chain name (optional) - #[arg(long, short = 'c')] - pub chain: Option, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl From<&StatsCmd> for balancerpb::ShowStatsRequest { - fn from(cmd: &StatsCmd) -> Self { - Self { - name: cmd.name.clone(), - r#ref: Some(balancerpb::PacketHandlerRef { - device: cmd.device.clone(), - pipeline: cmd.pipeline.clone(), - function: cmd.function.clone(), - chain: cmd.chain.clone(), - }), - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// State Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct InfoCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl From<&InfoCmd> for balancerpb::ShowInfoRequest { - fn from(cmd: &InfoCmd) -> Self { - Self { name: cmd.name.clone() } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Sessions Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct SessionsCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl From<&SessionsCmd> for balancerpb::ShowSessionsRequest { - fn from(cmd: &SessionsCmd) -> Self { - Self { name: cmd.name.clone() } - } -} - -#[derive(Debug, Clone, Parser)] -pub struct GraphCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl From<&GraphCmd> for balancerpb::ShowGraphRequest { - fn from(cmd: &GraphCmd) -> Self { - Self { name: cmd.name.clone() } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// VS Commands -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct VsCmd { - #[clap(subcommand)] - pub mode: VsMode, -} - -#[derive(Debug, Clone, Parser)] -pub enum VsMode { - /// Update or add virtual services from YAML file - Update(UpdateVsCmd), - /// Delete virtual services by identifier - Delete(DeleteVsCmd), -} - -#[derive(Debug, Clone, Parser)] -pub struct UpdateVsCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - /// Path to the YAML configuration file containing virtual services - #[arg(long, short = 'c')] - pub config: String, - - #[clap(flatten)] - pub format: FormatFlags, -} - -#[derive(Debug, Clone, Parser)] -pub struct DeleteVsCmd { - /// Name of the module config (optional, auto-selects if only one exists) - #[arg(long, short = 'n')] - pub name: Option, - - /// Virtual services to delete in format "ip:port/proto", - /// "[ipv6]:port/proto", or "ipv6:port/proto" (e.g., "192.168.1.1:80/ - /// tcp", "[2001:db8::1]:443/tcp", or "2001:db8::1:443/tcp") - #[arg(long, required = true, num_args = 1..)] - pub vs: Vec, - - #[clap(flatten)] - pub format: FormatFlags, -} - -impl TryFrom for balancerpb::DeleteVsRequest { - type Error = String; - - fn try_from(cmd: DeleteVsCmd) -> Result { - let mut vs_list = Vec::new(); - - for vs_str in &cmd.vs { - let (ip, port, proto) = parse_vs_identifier(vs_str)?; - - // Create a minimal VirtualService with only the identifier - // Other fields are ignored for delete operation - vs_list.push(balancerpb::VirtualService { - id: Some(balancerpb::VsIdentifier { - addr: Some(balancerpb::Addr { - bytes: match ip { - std::net::IpAddr::V4(ip) => ip.octets().to_vec(), - std::net::IpAddr::V6(ip) => ip.octets().to_vec(), - }, - }), - port: port as u32, - proto: proto as i32, - }), - scheduler: 0, // Ignored for delete - allowed_srcs: vec![], - reals: vec![], - flags: None, - peers: vec![], - }); - } - - Ok(Self { name: cmd.name, vs: vs_list }) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Inspect Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct InspectCmd { - #[clap(flatten)] - pub format: InspectFormatFlags, -} - -impl From<&InspectCmd> for balancerpb::ShowInspectRequest { - fn from(_cmd: &InspectCmd) -> Self { - Self {} - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Metrics Command -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Parser)] -pub struct MetricsCmd {} - -impl From<&MetricsCmd> for balancerpb::GetMetricsRequest { - fn from(_cmd: &MetricsCmd) -> Self { - Self {} - } -} diff --git a/modules/balancer/cli/src/config.rs b/modules/balancer/cli/src/config.rs new file mode 100644 index 000000000..56e49b1f1 --- /dev/null +++ b/modules/balancer/cli/src/config.rs @@ -0,0 +1,414 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +use serde::{Deserialize, Serialize}; +use yanet_cli_balancer::{balancerpb, filterpb}; + +use crate::ip_to_bytes; + +// ─── YAML Config Types ─────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BalancerConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub packet_handler: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub state: Option, +} + +impl BalancerConfig { + pub fn from_yaml_file(path: &str) -> Result> { + let file = std::fs::File::open(path)?; + let config = serde_yaml::from_reader(file)?; + Ok(config) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PacketHandlerConfig { + pub vs: Vec, + pub source_address_v4: String, + pub source_address_v6: String, + pub decap_addresses: Vec, + pub sessions_timeouts: SessionsTimeouts, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VirtualService { + pub addr: String, + pub port: u32, + pub proto: Proto, + pub scheduler: Scheduler, + pub flags: VsFlags, + #[serde(default)] + pub allowed_srcs: Vec, + pub reals: Vec, + #[serde(default)] + pub peers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Real { + pub ip: String, + #[serde(default)] + pub port: u32, + pub weight: u32, + pub src_addr: String, + pub src_mask: String, +} + +#[derive(Debug, Clone, Serialize)] +pub enum Scheduler { + Sh, + Wrr, + Wlc, +} + +impl<'de> Deserialize<'de> for Scheduler { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "sh" => Ok(Scheduler::Sh), + "wrr" => Ok(Scheduler::Wrr), + "wlc" => Ok(Scheduler::Wlc), + _ => Err(serde::de::Error::custom(format!( + "invalid scheduler: '{}'. Expected: sh, wrr or wlc", + s + ))), + } + } +} + +#[derive(Debug, Clone, Serialize)] +pub enum Proto { + Tcp, + Udp, +} + +impl<'de> Deserialize<'de> for Proto { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + match s.to_uppercase().as_str() { + "TCP" => Ok(Proto::Tcp), + "UDP" => Ok(Proto::Udp), + _ => Err(serde::de::Error::custom(format!( + "invalid protocol: '{}'. Expected: TCP, UDP", + s + ))), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct VsFlags { + #[serde(default)] + pub gre: bool, + #[serde(default)] + pub fix_mss: bool, + #[serde(default)] + pub ops: bool, + #[serde(default)] + pub pure_l3: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionsTimeouts { + pub tcp_syn_ack: u32, + pub tcp_syn: u32, + pub tcp_fin: u32, + pub tcp: u32, + pub udp: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub session_table: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub wlc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_period_ms: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionTableConfig { + pub capacity: u64, + pub max_load_factor: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WlcConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub power: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_weight: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum AllowedSrcEntry { + Simple(String), + Structured { + network: String, + #[serde(skip_serializing_if = "Option::is_none")] + ports: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tag: Option, + }, +} + +// ─── Network Parsing ───────────────────────────────────────────────────────── + +fn parse_network(network: &str) -> Result<(IpAddr, Vec), String> { + let parts: Vec<&str> = network.split('/').collect(); + if parts.len() != 2 { + return Err(format!("invalid network format: '{}'", network)); + } + + let addr: IpAddr = parts[0] + .parse() + .map_err(|e| format!("invalid IP in '{}': {}", network, e))?; + + if let Ok(prefix_len) = parts[1].parse::() { + let max = if addr.is_ipv4() { 32 } else { 128 }; + if prefix_len > max { + return Err(format!("prefix length {} exceeds max {}", prefix_len, max)); + } + Ok((addr, prefix_to_mask(addr.is_ipv4(), prefix_len))) + } else { + let mask: IpAddr = parts[1] + .parse() + .map_err(|e| format!("invalid netmask in '{}': {}", network, e))?; + match (addr, mask) { + (IpAddr::V4(_), IpAddr::V4(m)) => Ok((addr, m.octets().to_vec())), + (IpAddr::V6(_), IpAddr::V6(m)) => Ok((addr, m.octets().to_vec())), + _ => Err("IP version mismatch between address and mask".to_string()), + } + } +} + +fn prefix_to_mask(is_ipv4: bool, prefix_len: u32) -> Vec { + if is_ipv4 { + let mask: u32 = if prefix_len == 0 { 0 } else { !0u32 << (32 - prefix_len) }; + mask.to_be_bytes().to_vec() + } else { + let mut mask = [0u8; 16]; + let full_bytes = (prefix_len / 8) as usize; + let remaining_bits = prefix_len % 8; + for byte in mask.iter_mut().take(full_bytes) { + *byte = 0xFF; + } + if remaining_bits > 0 && full_bytes < 16 { + mask[full_bytes] = !0u8 << (8 - remaining_bits); + } + mask.to_vec() + } +} + +fn parse_ports(ports_str: &str) -> Result, String> { + let mut ranges = Vec::new(); + for part in ports_str.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + if part.contains('-') { + let ps: Vec<&str> = part.split('-').collect(); + if ps.len() != 2 { + return Err(format!("invalid port range: '{}'", part)); + } + let from: u32 = ps[0].trim().parse().map_err(|e| format!("invalid port: {}", e))?; + let to: u32 = ps[1].trim().parse().map_err(|e| format!("invalid port: {}", e))?; + ranges.push(filterpb::PortRange { from, to }); + } else { + let port: u32 = part.parse().map_err(|e| format!("invalid port: {}", e))?; + ranges.push(filterpb::PortRange { from: port, to: port }); + } + } + Ok(ranges) +} + +// ─── Conversion to Proto ───────────────────────────────────────────────────── + +impl TryFrom for balancerpb::BalancerConfig { + type Error = String; + + fn try_from(config: BalancerConfig) -> Result { + Ok(Self { + packet_handler: config.packet_handler.map(TryInto::try_into).transpose()?, + state: config.state.map(Into::into), + }) + } +} + +impl TryFrom for balancerpb::PacketHandlerConfig { + type Error = String; + + fn try_from(config: PacketHandlerConfig) -> Result { + let vs: Result, String> = config.vs.into_iter().map(TryInto::try_into).collect(); + + let source_v4: Ipv4Addr = config + .source_address_v4 + .parse() + .map_err(|e| format!("invalid source IPv4: {}", e))?; + let source_v6: Ipv6Addr = config + .source_address_v6 + .parse() + .map_err(|e| format!("invalid source IPv6: {}", e))?; + + let decap: Result, String> = config + .decap_addresses + .into_iter() + .map(|s| { + let addr: IpAddr = s.parse().map_err(|e| format!("invalid decap IP '{}': {}", s, e))?; + Ok(ip_to_bytes(addr)) + }) + .collect(); + + Ok(Self { + vs: vs?, + source_address_v4: source_v4.octets().to_vec(), + source_address_v6: source_v6.octets().to_vec(), + decap_addresses: decap?, + sessions_timeouts: Some(config.sessions_timeouts.into()), + }) + } +} + +impl TryFrom for balancerpb::VirtualService { + type Error = String; + + fn try_from(vs: VirtualService) -> Result { + let addr: IpAddr = vs.addr.parse().map_err(|e| format!("invalid VS IP: {}", e))?; + let proto = match vs.proto { + Proto::Tcp => balancerpb::TransportProto::Tcp, + Proto::Udp => balancerpb::TransportProto::Udp, + }; + let scheduler = match vs.scheduler { + Scheduler::Sh => balancerpb::VsScheduler::Sh, + Scheduler::Wrr => balancerpb::VsScheduler::Wrr, + Scheduler::Wlc => balancerpb::VsScheduler::Wlc, + }; + + let allowed_srcs: Result, String> = vs + .allowed_srcs + .iter() + .map(|entry| match entry { + AllowedSrcEntry::Simple(network_str) => { + let (addr, mask) = parse_network(network_str)?; + Ok(balancerpb::AllowedSources { + nets: vec![filterpb::IpNet { addr: ip_to_bytes(addr), mask }], + ports: vec![], + tag: None, + }) + } + AllowedSrcEntry::Structured { network, ports, tag } => { + let (addr, mask) = parse_network(network)?; + let port_ranges = match ports { + Some(s) => parse_ports(s)?, + None => vec![], + }; + Ok(balancerpb::AllowedSources { + nets: vec![filterpb::IpNet { addr: ip_to_bytes(addr), mask }], + ports: port_ranges, + tag: tag.clone(), + }) + } + }) + .collect(); + + let peers: Result, String> = vs + .peers + .iter() + .map(|p| { + let ip: IpAddr = p.parse().map_err(|e| format!("invalid peer IP '{}': {}", p, e))?; + Ok(ip_to_bytes(ip)) + }) + .collect(); + + let reals: Result, String> = vs.reals.into_iter().map(TryInto::try_into).collect(); + + Ok(Self { + id: Some(balancerpb::VsIdentifier { + addr: ip_to_bytes(addr), + port: vs.port, + proto: proto as i32, + }), + scheduler: scheduler as i32, + allowed_srcs: allowed_srcs?, + reals: reals?, + flags: Some(vs.flags.into()), + peers: peers?, + }) + } +} + +impl From for balancerpb::VsFlags { + fn from(f: VsFlags) -> Self { + Self { + gre: f.gre, + fix_mss: f.fix_mss, + ops: f.ops, + pure_l3: f.pure_l3, + } + } +} + +impl TryFrom for balancerpb::Real { + type Error = String; + + fn try_from(real: Real) -> Result { + let ip: IpAddr = real.ip.parse().map_err(|e| format!("invalid real IP: {}", e))?; + let src_addr: IpAddr = real.src_addr.parse().map_err(|e| format!("invalid src_addr: {}", e))?; + let src_mask: IpAddr = real.src_mask.parse().map_err(|e| format!("invalid src_mask: {}", e))?; + + Ok(Self { + id: Some(balancerpb::RelativeRealIdentifier { ip: ip_to_bytes(ip), port: real.port }), + weight: real.weight, + src: Some(filterpb::IpNet { + addr: ip_to_bytes(src_addr), + mask: ip_to_bytes(src_mask), + }), + }) + } +} + +impl From for balancerpb::SessionsTimeouts { + fn from(t: SessionsTimeouts) -> Self { + Self { + tcp_syn_ack: t.tcp_syn_ack, + tcp_syn: t.tcp_syn, + tcp_fin: t.tcp_fin, + tcp: t.tcp, + udp: t.udp, + } + } +} + +impl From for balancerpb::StateConfig { + fn from(config: StateConfig) -> Self { + Self { + session_table_capacity: config.session_table.as_ref().map(|st| st.capacity), + session_table_max_load_factor: config.session_table.as_ref().map(|st| st.max_load_factor), + wlc: config.wlc.map(Into::into), + refresh_period: config.refresh_period_ms.map(|ms| prost_types::Duration { + seconds: (ms / 1000) as i64, + nanos: ((ms % 1000) * 1_000_000) as i32, + }), + } + } +} + +impl From for balancerpb::WlcConfig { + fn from(c: WlcConfig) -> Self { + Self { + power: c.power, + max_weight: c.max_weight, + } + } +} diff --git a/modules/balancer/cli/src/display.rs b/modules/balancer/cli/src/display.rs new file mode 100644 index 000000000..97279cede --- /dev/null +++ b/modules/balancer/cli/src/display.rs @@ -0,0 +1,720 @@ +use tabled::Tabled; +use yanet_cli_balancer::balancerpb; +use ync::display::print_table; + +use crate::{bytes_to_ip, format_ip_port}; + +// ─── Compact (IPVS-style) Output ──────────────────────────────────────────── + +pub fn print_compact(state: &balancerpb::BalancerState) { + println!("Balancer: {}", state.balancer_name); + println!("Active Sessions: {}", format_number(state.active_sessions)); + println!(); + const LINE_WIDTH: usize = 112; + + println!("{:<46}{:<8}Flags", "VirtualService", "Sched",); + println!( + " -> {:<38}{:<10}{:<10}{:<12}{:<18}{:<18}", + "RemoteAddress:Port", "Enabled", "Weight", "Conns", "Pkts", "Bytes", + ); + println!("{}", "\u{2500}".repeat(LINE_WIDTH)); + + for (i, vs) in state.virtual_services.iter().enumerate() { + let Some(id) = &vs.id else { continue }; + let ip = match bytes_to_ip(&id.addr) { + Ok(ip) => ip, + Err(_) => continue, + }; + if i > 0 { + println!("{}", "\u{2500}".repeat(LINE_WIDTH)); + } + let proto = proto_str(id.proto); + let scheduler = scheduler_str(vs.scheduler); + let flags = flags_str(vs.flags.as_ref()); + let vs_str = format!("{}/{}", format_ip_port(ip, id.port), proto); + + println!("{:<46}{:<8}{}", vs_str, scheduler, flags); + + for real in &vs.reals { + let Some(rid) = &real.id else { continue }; + let rip = match bytes_to_ip(&rid.ip) { + Ok(ip) => ip, + Err(_) => continue, + }; + let real_addr = format_ip_port(rip, rid.port); + let rs = real.real_stats.as_ref(); + let enabled = if real.enabled { "true" } else { "false" }; + println!( + " -> {:<38}{:<10}{:<10}{:<12}{:<18}{:<18}", + real_addr, + enabled, + format_number(real.weight), + format_number(real.active_sessions), + format_number(rs.map_or(0, |s| s.packets)), + format_number(rs.map_or(0, |s| s.bytes)), + ); + } + } +} + +// ─── Module Stats ────────────────────────────────────────────────────────── + +fn print_module_stats(state: &balancerpb::BalancerState) { + println!("Module:"); + + let mut rows: Vec = Vec::new(); + + if let Some(c) = &state.common_stats { + rows.push(StatsRow::new( + "Common", + "Incoming Pkts", + format_number(c.incoming_packets), + )); + rows.push(StatsRow::new("", "Incoming Bytes", format_number(c.incoming_bytes))); + rows.push(StatsRow::new( + "", + "Unexpected Proto", + format_number(c.unexpected_network_proto), + )); + rows.push(StatsRow::new("", "Decap Success", format_number(c.decap_successful))); + rows.push(StatsRow::new("", "Decap Failed", format_number(c.decap_failed))); + rows.push(StatsRow::new("", "Outgoing Pkts", format_number(c.outgoing_packets))); + rows.push(StatsRow::new("", "Outgoing Bytes", format_number(c.outgoing_bytes))); + rows.push(StatsRow::empty()); + } + + if let Some(l) = &state.l4_stats { + rows.push(StatsRow::new("L4", "Incoming Pkts", format_number(l.incoming_packets))); + rows.push(StatsRow::new("", "Outgoing Pkts", format_number(l.outgoing_packets))); + rows.push(StatsRow::new("", "Select VS Fail", format_number(l.select_vs_failed))); + rows.push(StatsRow::new( + "", + "Select Real Fail", + format_number(l.select_real_failed), + )); + rows.push(StatsRow::new("", "Invalid Pkts", format_number(l.invalid_packets))); + rows.push(StatsRow::empty()); + } + + if let Some(icmp) = &state.icmp_ipv4_stats { + push_icmp_rows(&mut rows, "ICMPv4", icmp); + rows.push(StatsRow::empty()); + } + + if let Some(icmp) = &state.icmp_ipv6_stats { + push_icmp_rows(&mut rows, "ICMPv6", icmp); + } + + // Remove trailing empty row. + if rows + .last() + .is_some_and(|r| r.category.is_empty() && r.metric.is_empty()) + { + rows.pop(); + } + + print_table(rows); + println!(); +} + +fn push_icmp_rows(rows: &mut Vec, category: &str, icmp: &balancerpb::IcmpStats) { + rows.push(StatsRow::new( + category, + "Incoming Pkts", + format_number(icmp.incoming_packets), + )); + rows.push(StatsRow::new( + "", + "Src Not Allowed", + format_number(icmp.src_not_allowed), + )); + rows.push(StatsRow::new("", "Echo Responses", format_number(icmp.echo_responses))); + rows.push(StatsRow::new( + "", + "Payload Short IP", + format_number(icmp.payload_too_short_ip), + )); + rows.push(StatsRow::new( + "", + "Unmatch Src Orig", + format_number(icmp.unmatching_src_from_original), + )); + rows.push(StatsRow::new( + "", + "Payload Short Port", + format_number(icmp.payload_too_short_port), + )); + rows.push(StatsRow::new( + "", + "Unexpected Trans", + format_number(icmp.unexpected_transport), + )); + rows.push(StatsRow::new( + "", + "Unrecognized VS", + format_number(icmp.unrecognized_vs), + )); + rows.push(StatsRow::new( + "", + "Forwarded Pkts", + format_number(icmp.forwarded_packets), + )); + rows.push(StatsRow::new( + "", + "Broadcasted Pkts", + format_number(icmp.broadcasted_packets), + )); + rows.push(StatsRow::new("", "Clones Sent", format_number(icmp.packet_clones_sent))); + rows.push(StatsRow::new( + "", + "Clones Received", + format_number(icmp.packet_clones_received), + )); + rows.push(StatsRow::new( + "", + "Clone Failures", + format_number(icmp.packet_clone_failures), + )); +} + +// ─── Table View Output ───────────────────────────────────────────────────── + +pub struct ShowOptions { + pub stats: bool, + pub acl: bool, + pub peers: bool, + pub decap: bool, +} + +pub fn print_table_view(states: &[balancerpb::BalancerState], opts: &ShowOptions) { + for (i, state) in states.iter().enumerate() { + if i > 0 { + println!(); + } + print_table_view_state(state, opts); + } +} + +fn print_table_view_state(state: &balancerpb::BalancerState, opts: &ShowOptions) { + println!("Balancer: {}", state.balancer_name); + if let Some(r) = &state.r#ref { + print_ref_inline(r); + } + + if opts.stats { + println!("Active Sessions: {}", format_number(state.active_sessions)); + println!( + "Last Packet: {}", + state + .last_packet_timestamp + .as_ref() + .map_or_else(|| "N/A".to_string(), format_timestamp), + ); + println!(); + } + + if opts.decap { + print_decap(state); + } + + if opts.stats { + print_module_stats(state); + } + + for vs in &state.virtual_services { + print_table_view_vs(vs, opts); + } +} + +fn print_table_view_vs(vs: &balancerpb::VsState, opts: &ShowOptions) { + let Some(id) = &vs.id else { return }; + let ip = match bytes_to_ip(&id.addr) { + Ok(ip) => ip, + Err(_) => return, + }; + let proto = proto_str(id.proto).to_uppercase(); + let addr_port = format_ip_port(ip, id.port); + let scheduler = scheduler_str(vs.scheduler); + let flags = flags_str(vs.flags.as_ref()); + + println!("VS {}/{}:", addr_port, proto); + println!(" Scheduler: {}", scheduler); + if !flags.is_empty() { + println!(" Flags: {}", flags); + } + + if opts.stats { + println!(" Active Sessions: {}", format_number(vs.active_sessions)); + if let Some(ts) = &vs.last_packet_timestamp { + println!(" Last Packet: {}", format_timestamp(ts)); + } + if let Some(stats) = &vs.stats { + print_vs_stats(stats); + } + } + + if opts.peers { + print_vs_peers(vs); + } + + if opts.acl { + print_vs_acl(vs, opts.stats); + } + + // Reals table. + let real_rows: Vec<_> = vs + .reals + .iter() + .filter_map(|real| { + let rid = real.id.as_ref()?; + let rip = bytes_to_ip(&rid.ip).ok()?; + let real_addr = format_ip_port(rip, rid.port); + + if opts.stats { + let rs = real.real_stats.as_ref(); + Some(RealTableRow::Stats(RealStatsRow { + real: real_addr, + enabled: if real.enabled { + "true".to_string() + } else { + "false".to_string() + }, + weight: format_number(real.weight), + effective_weight: format_number(real.effective_weight), + packets: format_number(rs.map_or(0, |s| s.packets)), + bytes: format_number(rs.map_or(0, |s| s.bytes)), + active_sessions: format_number(real.active_sessions), + created_sessions: format_number(real.real_stats.map_or(0, |s| s.created_sessions)), + last_packet: real + .last_packet_timestamp + .as_ref() + .map_or_else(|| "-".to_string(), format_timestamp), + disabled_pkts: format_number(rs.map_or(0, |s| s.packets_real_disabled)), + icmp_pkts: format_number(rs.map_or(0, |s| s.error_icmp_packets)), + })) + } else { + Some(RealTableRow::Basic(RealBasicRow { + real: real_addr, + weight: format_number(real.weight), + effective_weight: format_number(real.effective_weight), + enabled: if real.enabled { + "true".to_string() + } else { + "false".to_string() + }, + })) + } + }) + .collect(); + + if !real_rows.is_empty() { + match &real_rows[0] { + RealTableRow::Stats(_) => { + let rows: Vec = real_rows + .into_iter() + .filter_map(|r| match r { + RealTableRow::Stats(s) => Some(s), + _ => None, + }) + .collect(); + print_table(rows); + } + RealTableRow::Basic(_) => { + let rows: Vec = real_rows + .into_iter() + .filter_map(|r| match r { + RealTableRow::Basic(b) => Some(b), + _ => None, + }) + .collect(); + print_table(rows); + } + } + } + println!(); +} + +fn print_vs_stats(stats: &balancerpb::VsStats) { + println!(" Incoming Packets: {}", format_number(stats.incoming_packets)); + println!(" Incoming Bytes: {}", format_number(stats.incoming_bytes)); + println!(" Outgoing Packets: {}", format_number(stats.outgoing_packets)); + println!(" Outgoing Bytes: {}", format_number(stats.outgoing_bytes)); + println!(" Created Sessions: {}", format_number(stats.created_sessions)); + println!( + " Packet Src Not Allowed: {}", + format_number(stats.packet_src_not_allowed) + ); + println!(" No Reals: {}", format_number(stats.no_reals)); + println!( + " Session Table Overflow: {}", + format_number(stats.session_table_overflow) + ); + println!(" Echo ICMP Packets: {}", format_number(stats.echo_icmp_packets)); + println!(" Error ICMP Packets: {}", format_number(stats.error_icmp_packets)); + println!(" Real Is Disabled: {}", format_number(stats.real_is_disabled)); + println!(" Real Is Removed: {}", format_number(stats.real_is_removed)); + println!( + " Not Rescheduled Packets: {}", + format_number(stats.not_rescheduled_packets) + ); + println!( + " Broadcasted ICMP Packets: {}", + format_number(stats.broadcasted_icmp_packets) + ); +} + +fn print_vs_acl(vs: &balancerpb::VsState, with_stats: bool) { + if vs.allowed_srcs_config.is_empty() { + return; + } + + // Build a map from tag -> passes for stats lookup. + let stats_map: std::collections::HashMap<&str, u64> = + vs.allowed_sources.iter().map(|s| (s.tag.as_str(), s.passes)).collect(); + + println!(" Allowed Sources:"); + for src in &vs.allowed_srcs_config { + let tag = src.tag.as_deref().unwrap_or(""); + if !tag.is_empty() { + if with_stats { + let passes = stats_map.get(tag).copied().unwrap_or(0); + println!(" Tag: {} (passes: {})", tag, format_number(passes)); + } else { + println!(" Tag: {}", tag); + } + } + for net in &src.nets { + let addr = bytes_to_ip(&net.addr) + .map(|ip| ip.to_string()) + .unwrap_or_else(|_| "?".to_string()); + let mask = bytes_to_ip(&net.mask) + .map(|ip| ip.to_string()) + .unwrap_or_else(|_| "?".to_string()); + println!(" Net: {}/{}", addr, mask); + } + for pr in &src.ports { + if pr.from == pr.to { + println!(" Port: {}", pr.from); + } else { + println!(" Ports: {}-{}", pr.from, pr.to); + } + } + } +} + +fn print_vs_peers(vs: &balancerpb::VsState) { + if vs.peers.is_empty() { + return; + } + println!(" Peers:"); + for peer in &vs.peers { + if let Ok(ip) = bytes_to_ip(peer) { + println!(" {}", ip); + } + } +} + +fn print_decap(state: &balancerpb::BalancerState) { + if !state.source_ipv4.is_empty() { + if let Ok(ip) = bytes_to_ip(&state.source_ipv4) { + println!("Source IPv4: {}", ip); + } + } + if !state.source_ipv6.is_empty() { + if let Ok(ip) = bytes_to_ip(&state.source_ipv6) { + println!("Source IPv6: {}", ip); + } + } + if !state.decap_addresses.is_empty() { + println!("Decap Addresses:"); + for addr in &state.decap_addresses { + if let Ok(ip) = bytes_to_ip(addr) { + println!(" {}", ip); + } + } + } + println!(); +} + +enum RealTableRow { + Basic(RealBasicRow), + Stats(RealStatsRow), +} + +#[derive(Tabled)] +struct RealBasicRow { + #[tabled(rename = "Real")] + real: String, + #[tabled(rename = "Enabled")] + enabled: String, + #[tabled(rename = "Wght")] + weight: String, + #[tabled(rename = "Eff Wght")] + effective_weight: String, +} + +// ─── Sessions Output ──────────────────────────────────────────────────────── + +pub fn print_sessions_header() { + println!( + "{:<40} {:<40} {:<50} {:<8} {:<8} {:<8}", + "VS", "Real", "Client", "Expires", "Timeout", "Age" + ); +} + +pub fn print_session(session: &balancerpb::Session) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let vs = format_vs_id(session.vs_id.as_ref()); + let real = format_real_id(session.real_id.as_ref()); + let client = format_client(session); + let expires = format_expires(session, now); + let age = format_age(session, now); + let timeout = format_timeout(session); + + println!( + "{:<40} {:<40} {:<50} {:<8} {:<8} {:<8}", + vs, real, client, expires, timeout, age + ); +} + +fn format_vs_id(vs_id: Option<&balancerpb::VsIdentifier>) -> String { + vs_id + .and_then(|id| { + bytes_to_ip(&id.addr) + .ok() + .map(|ip| format!("{}/{}", format_ip_port(ip, id.port), proto_str(id.proto))) + }) + .unwrap_or_else(|| "-".to_string()) +} + +fn format_real_id(real_id: Option<&balancerpb::RelativeRealIdentifier>) -> String { + real_id + .and_then(|r| bytes_to_ip(&r.ip).ok().map(|ip| format_ip_port(ip, r.port))) + .unwrap_or_else(|| "-".to_string()) +} + +fn format_client(session: &balancerpb::Session) -> String { + bytes_to_ip(&session.client_addr) + .ok() + .map(|ip| format_ip_port(ip, session.client_port)) + .unwrap_or_else(|| "-".to_string()) +} + +fn format_expires(session: &balancerpb::Session, now: i64) -> String { + match (session.last_packet_timestamp.as_ref(), session.timeout.as_ref()) { + (Some(last_packet), Some(timeout)) => { + let remaining = (last_packet.seconds + timeout.seconds - now).max(0); + format!("{}", remaining) + } + _ => "-".to_string(), + } +} + +fn format_timeout(session: &balancerpb::Session) -> String { + session + .timeout + .as_ref() + .map_or_else(|| "-".to_string(), |d| format!("{}", d.seconds)) +} + +fn format_age(session: &balancerpb::Session, now: i64) -> String { + session + .create_timestamp + .as_ref() + .map_or_else(|| "-".to_string(), |ts| format!("{}", (now - ts.seconds).max(0))) +} + +// ─── Tabled Row Types ─────────────────────────────────────────────────────── + +#[derive(Tabled)] +struct StatsRow { + #[tabled(rename = "Category")] + category: String, + #[tabled(rename = "Metric")] + metric: String, + #[tabled(rename = "Value")] + value: String, +} + +impl StatsRow { + fn new(category: &str, metric: &str, value: String) -> Self { + Self { + category: category.to_string(), + metric: metric.to_string(), + value, + } + } + + fn empty() -> Self { + Self { + category: String::new(), + metric: String::new(), + value: String::new(), + } + } +} + +#[derive(Tabled)] +struct RealStatsRow { + #[tabled(rename = "Real")] + real: String, + #[tabled(rename = "Enabled")] + enabled: String, + #[tabled(rename = "Wght")] + weight: String, + #[tabled(rename = "Eff Wght")] + effective_weight: String, + #[tabled(rename = "Pkts")] + packets: String, + #[tabled(rename = "Bytes")] + bytes: String, + #[tabled(rename = "Last Pkt")] + last_packet: String, + #[tabled(rename = "Dis Pkts")] + disabled_pkts: String, + #[tabled(rename = "ICMP Err")] + icmp_pkts: String, + #[tabled(rename = "Sess Act")] + active_sessions: String, + #[tabled(rename = "Sess Crt")] + created_sessions: String, +} + +// ─── JSON Prettification ──────────────────────────────────────────────────── + +/// Recursively walk a JSON value and prettify it for human-readable output: +/// - Convert byte arrays (IP addresses) into IP strings. +/// - Convert known enum integer values into short string names. +pub fn prettify_json(value: &mut serde_json::Value) { + match value { + serde_json::Value::Array(arr) => { + if let Some(ip) = try_bytes_to_ip_string(arr) { + *value = serde_json::Value::String(ip); + } else { + for item in arr.iter_mut() { + prettify_json(item); + } + } + } + serde_json::Value::Object(map) => { + prettify_enum(map, "scheduler", |v| { + balancerpb::VsScheduler::try_from(v) + .ok() + .map(|s| s as i32) + .map(scheduler_str) + }); + prettify_enum(map, "proto", |v| { + balancerpb::TransportProto::try_from(v).ok().map(|p| match p { + balancerpb::TransportProto::Tcp => "tcp", + balancerpb::TransportProto::Udp => "udp", + }) + }); + for (_, v) in map.iter_mut() { + prettify_json(v); + } + } + _ => {} + } +} + +fn prettify_enum( + map: &mut serde_json::Map, + key: &str, + to_str: impl FnOnce(i32) -> Option<&'static str>, +) { + if let Some(val) = map.get(key).and_then(|v| v.as_i64()) { + if let Some(name) = to_str(val as i32) { + map.insert(key.to_string(), serde_json::Value::String(name.to_string())); + } + } +} + +fn try_bytes_to_ip_string(arr: &[serde_json::Value]) -> Option { + if arr.len() != 4 && arr.len() != 16 { + return None; + } + let bytes: Vec = arr + .iter() + .map(|v| v.as_u64().and_then(|n| u8::try_from(n).ok())) + .collect::>>()?; + + let ip = crate::bytes_to_ip(&bytes).ok()?; + Some(ip.to_string()) +} + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +fn proto_str(proto: i32) -> &'static str { + match balancerpb::TransportProto::try_from(proto) { + Ok(balancerpb::TransportProto::Tcp) => "TCP", + Ok(balancerpb::TransportProto::Udp) => "UDP", + _ => "???", + } +} + +fn scheduler_str(scheduler: i32) -> &'static str { + match balancerpb::VsScheduler::try_from(scheduler) { + Ok(balancerpb::VsScheduler::Sh) => "sh", + Ok(balancerpb::VsScheduler::Wrr) => "wrr", + Ok(balancerpb::VsScheduler::Wlc) => "wlc", + _ => "???", + } +} + +fn flags_str(flags: Option<&balancerpb::VsFlags>) -> String { + let Some(f) = flags else { + return String::new(); + }; + let mut parts = Vec::new(); + if f.gre { + parts.push("gre"); + } + if f.fix_mss { + parts.push("mss"); + } + if f.ops { + parts.push("ops"); + } + if f.pure_l3 { + parts.push("l3"); + } + parts.join(",") +} + +fn print_ref_inline(r: &balancerpb::PacketHandlerRef) { + let mut parts = Vec::new(); + if let Some(d) = &r.device { + parts.push(format!("Device: {}", d)); + } + if let Some(p) = &r.pipeline { + parts.push(format!("Pipeline: {}", p)); + } + if let Some(f) = &r.function { + parts.push(format!("Function: {}", f)); + } + if let Some(c) = &r.chain { + parts.push(format!("Chain: {}", c)); + } + if !parts.is_empty() { + println!("{}", parts.join(" | ")); + } +} + +pub fn format_number(n: u64) -> String { + n.to_string() +} + +fn format_timestamp(ts: &prost_types::Timestamp) -> String { + if ts.seconds == 0 && ts.nanos == 0 { + return "N/A".to_string(); + } + let ndt = chrono::DateTime::from_timestamp(ts.seconds, ts.nanos as u32); + match ndt { + Some(dt) => dt.format("%Y-%m-%d %H:%M:%S").to_string(), + None => "-".to_string(), + } +} diff --git a/modules/balancer/cli/src/entities.rs b/modules/balancer/cli/src/entities.rs deleted file mode 100644 index a282628cc..000000000 --- a/modules/balancer/cli/src/entities.rs +++ /dev/null @@ -1,789 +0,0 @@ -//! Data structures and helper functions for balancer entities - -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - -use serde::{Deserialize, Serialize}; - -use crate::rpc::balancerpb; - -//////////////////////////////////////////////////////////////////////////////// -// Helper Functions -//////////////////////////////////////////////////////////////////////////////// - -/// Convert IP address to bytes -pub fn ip_to_bytes(ip: IpAddr) -> Vec { - match ip { - IpAddr::V4(ipv4) => ipv4.octets().to_vec(), - IpAddr::V6(ipv6) => ipv6.octets().to_vec(), - } -} - -/// Convert bytes to IP address -pub fn bytes_to_ip(bytes: &[u8]) -> Result { - match bytes.len() { - 4 => { - let arr: [u8; 4] = bytes.try_into().map_err(|_| "invalid IPv4 bytes")?; - Ok(IpAddr::V4(Ipv4Addr::from(arr))) - } - 16 => { - let arr: [u8; 16] = bytes.try_into().map_err(|_| "invalid IPv6 bytes")?; - Ok(IpAddr::V6(Ipv6Addr::from(arr))) - } - _ => Err(format!("invalid IP address length: {}", bytes.len())), - } -} - -/// Convert Addr protobuf message to IP address -pub fn addr_to_ip(addr: &balancerpb::Addr) -> Result { - bytes_to_ip(&addr.bytes) -} - -/// Convert optional Addr protobuf message to IP address -pub fn opt_addr_to_ip(addr: &Option) -> Result { - addr.as_ref() - .ok_or_else(|| "missing address".to_string()) - .and_then(addr_to_ip) -} - -/// Format AllowedSources protobuf message as a string -/// Returns networks in CIDR notation (if possible) or netmask notation with -/// optional port ranges and tag Format: "10.0.0.0/8, 192.168.0.0/16 -/// [80,443,1024-65535] [tag: 100]" -pub fn format_allowed_src(allowed_src: &balancerpb::AllowedSources) -> Result { - if allowed_src.nets.is_empty() { - return Err("no networks specified".to_string()); - } - - // Format all networks - let network_strs: Result, String> = allowed_src - .nets - .iter() - .map(|net| { - let addr = opt_addr_to_ip(&net.addr)?; - let mask_bytes = net.mask.as_ref().ok_or("missing mask")?.bytes.as_slice(); - - // Try to convert to CIDR prefix, fall back to netmask notation if not - // contiguous - match mask_to_prefix(mask_bytes) { - Ok(prefix_len) => Ok(format!("{}/{}", addr, prefix_len)), - Err(_) => { - // Non-contiguous mask - display as netmask notation - let mask_ip = bytes_to_ip(mask_bytes)?; - Ok(format!("{}/{}", addr, mask_ip)) - } - } - }) - .collect(); - - let mut result = network_strs?.join(", "); - - // Add port ranges if present in square brackets - if !allowed_src.ports.is_empty() { - let port_strs: Vec = allowed_src - .ports - .iter() - .map(|pr| { - if pr.from == pr.to { - format!("{}", pr.from) - } else { - format!("{}-{}", pr.from, pr.to) - } - }) - .collect(); - result.push_str(&format!(" [{}]", port_strs.join(","))); - } - - // Add tag - match &allowed_src.tag { - Some(tag) if !tag.is_empty() => { - result.push_str(&format!(" [tag: {}]", tag)); - } - _ => { - result.push_str(" [tag: None]"); - } - } - - Ok(result) -} - -/// Parse CIDR notation (e.g., "192.168.0.0/24") -/// This function is kept for backward compatibility but internally uses -/// parse_network -#[allow(dead_code)] -pub fn parse_cidr(cidr: &str) -> Result<(IpAddr, u32), String> { - let (addr, mask_bytes) = parse_network(cidr)?; - - // Convert mask bytes back to prefix length for backward compatibility - let prefix_len = mask_to_prefix(&mask_bytes)?; - Ok((addr, prefix_len)) -} - -/// Parse network specification in either CIDR or netmask notation -/// Returns (address, mask_bytes) where mask_bytes is the actual netmask -/// -/// Examples: -/// - CIDR: "10.0.0.0/8" -> (10.0.0.0, [255, 0, 0, 0]) -/// - Netmask: "10.0.0.0/255.0.0.0" -> (10.0.0.0, [255, 0, 0, 0]) -pub fn parse_network(network: &str) -> Result<(IpAddr, Vec), String> { - let parts: Vec<&str> = network.split('/').collect(); - if parts.len() != 2 { - return Err(format!( - "invalid network format: '{}'. Expected format: '192.168.0.0/24' or '192.168.0.0/255.255.0.0'", - network - )); - } - - let addr: IpAddr = parts[0] - .parse() - .map_err(|e| format!("invalid IP address in network '{}': {}", network, e))?; - - let mask_part = parts[1]; - - // Try to parse as CIDR prefix length first - if let Ok(prefix_len) = mask_part.parse::() { - // CIDR notation (e.g., "10.0.0.0/8") - let max_prefix = match addr { - IpAddr::V4(_) => 32, - IpAddr::V6(_) => 128, - }; - - if prefix_len > max_prefix { - return Err(format!( - "invalid prefix length: {} (max {} for {})", - prefix_len, - max_prefix, - if addr.is_ipv4() { "IPv4" } else { "IPv6" } - )); - } - - // Convert prefix length to netmask bytes - let mask_bytes = prefix_to_mask(addr.is_ipv4(), prefix_len)?; - Ok((addr, mask_bytes)) - } else { - // Try to parse as netmask (e.g., "10.0.0.0/255.0.0.0") - let mask: IpAddr = mask_part - .parse() - .map_err(|e| format!("invalid netmask in network '{}': {}", network, e))?; - - // Validate that address and mask are same IP version - match (addr, mask) { - (IpAddr::V4(_), IpAddr::V4(mask_v4)) => Ok((addr, mask_v4.octets().to_vec())), - (IpAddr::V6(_), IpAddr::V6(mask_v6)) => Ok((addr, mask_v6.octets().to_vec())), - _ => Err(format!( - "IP version mismatch: address is {} but mask is {}", - if addr.is_ipv4() { "IPv4" } else { "IPv6" }, - if mask.is_ipv4() { "IPv4" } else { "IPv6" } - )), - } - } -} - -/// Convert CIDR prefix length to netmask bytes -fn prefix_to_mask(is_ipv4: bool, prefix_len: u32) -> Result, String> { - if is_ipv4 { - // IPv4: 32 bits - if prefix_len > 32 { - return Err(format!("invalid IPv4 prefix length: {}", prefix_len)); - } - - let mask: u32 = if prefix_len == 0 { 0 } else { !0u32 << (32 - prefix_len) }; - - Ok(mask.to_be_bytes().to_vec()) - } else { - // IPv6: 128 bits - if prefix_len > 128 { - return Err(format!("invalid IPv6 prefix length: {}", prefix_len)); - } - - let mut mask = [0u8; 16]; - let full_bytes = (prefix_len / 8) as usize; - let remaining_bits = prefix_len % 8; - - // Fill complete bytes with 0xFF - for byte in mask.iter_mut().take(full_bytes) { - *byte = 0xFF; - } - - // Fill partial byte - if remaining_bits > 0 && full_bytes < 16 { - mask[full_bytes] = !0u8 << (8 - remaining_bits); - } - - Ok(mask.to_vec()) - } -} - -/// Convert netmask bytes to prefix length -pub fn mask_to_prefix(mask_bytes: &[u8]) -> Result { - let mut prefix_len = 0u32; - let mut seen_zero = false; - - for &byte in mask_bytes { - if seen_zero { - // After seeing a zero bit, all remaining bits must be zero - if byte != 0 { - return Err("invalid netmask: non-contiguous mask bits".to_string()); - } - } else { - // Count leading ones in this byte - let leading_ones = byte.leading_ones(); - prefix_len += leading_ones; - - if leading_ones < 8 { - // Found first zero bit - seen_zero = true; - - // Verify remaining bits in this byte are zero - let remaining_bits = byte & (!0u8 >> leading_ones); - if remaining_bits != 0 { - return Err("invalid netmask: non-contiguous mask bits".to_string()); - } - } - } - } - - Ok(prefix_len) -} - -/// Parse port specification string into PortRange vector -/// Format: "80,443,8000-9000,1024-1030" -/// -/// Examples: -/// - Single port: "80" -> [PortRange { from: 80, to: 80 }] -/// - Range: "1024-65535" -> [PortRange { from: 1024, to: 65535 }] -/// - Multiple: "80,443,8000-9000" -> [PortRange { from: 80, to: 80 }, PortRange -/// { from: 443, to: 443 }, PortRange { from: 8000, to: 9000 }] -pub fn parse_ports(ports_str: &str) -> Result, String> { - let mut ranges = Vec::new(); - - for part in ports_str.split(',') { - let part = part.trim(); - if part.is_empty() { - continue; - } - - if part.contains('-') { - // Parse range "1024-65535" - let parts: Vec<&str> = part.split('-').collect(); - if parts.len() != 2 { - return Err(format!("invalid port range: '{}'. Expected format: 'from-to'", part)); - } - - let from: u32 = parts[0] - .trim() - .parse() - .map_err(|e| format!("invalid port '{}': {}", parts[0], e))?; - let to: u32 = parts[1] - .trim() - .parse() - .map_err(|e| format!("invalid port '{}': {}", parts[1], e))?; - - if !(1..=65535).contains(&from) { - return Err(format!("port {} out of range (1-65535)", from)); - } - if !(1..=65535).contains(&to) { - return Err(format!("port {} out of range (1-65535)", to)); - } - if from > to { - return Err(format!("invalid range {}-{}: from > to", from, to)); - } - - ranges.push(PortRange { from, to }); - } else { - // Parse single port "80" - let port: u32 = part.parse().map_err(|e| format!("invalid port '{}': {}", part, e))?; - - if !(1..=65535).contains(&port) { - return Err(format!("port {} out of range (1-65535)", port)); - } - - ranges.push(PortRange { from: port, to: port }); - } - } - - if ranges.is_empty() { - return Err("no valid port ranges specified".to_string()); - } - - Ok(ranges) -} - -/// Format bytes as human-readable size -pub fn format_bytes(bytes: u64) -> String { - const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"]; - let mut size = bytes as f64; - let mut unit_idx = 0; - - while size >= 1024.0 && unit_idx < UNITS.len() - 1 { - size /= 1024.0; - unit_idx += 1; - } - - if unit_idx == 0 { - format!("{} {}", bytes, UNITS[unit_idx]) - } else { - format!("{:.1} {}", size, UNITS[unit_idx]) - } -} - -/// Format number with thousands separators -pub fn format_number(n: u64) -> String { - let s = n.to_string(); - let mut result = String::new(); - - for (count, c) in s.chars().rev().enumerate() { - if count > 0 && count % 3 == 0 { - result.push(','); - } - result.push(c); - } - - result.chars().rev().collect() -} - -//////////////////////////////////////////////////////////////////////////////// -// Configuration Structures -//////////////////////////////////////////////////////////////////////////////// - -/// Port range for source filtering -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PortRange { - pub from: u32, - pub to: u32, -} - -/// Allowed source entry - supports both simple string and structured format -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum AllowedSrcEntry { - /// Simple network string (backward compatible) - /// Supports both CIDR ("10.0.0.0/8") and netmask ("10.0.0.0/255.0.0.0") - /// notation - Simple(String), - - /// Structured format with optional port restrictions and tag - Structured { - /// Network in CIDR or netmask notation - network: String, - - /// Optional port restrictions (comma-separated ranges) - /// Format: "80,443,8000-9000,1024-1030" - #[serde(skip_serializing_if = "Option::is_none")] - ports: Option, - - /// Optional tag for tracking (empty/None = no tag) - #[serde(skip_serializing_if = "Option::is_none")] - tag: Option, - }, -} - -/// Complete balancer configuration for YAML -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BalancerConfig { - /// Packet processing configuration (optional for UPDATE) - #[serde(skip_serializing_if = "Option::is_none")] - pub packet_handler: Option, - - /// State management configuration (optional for UPDATE) - #[serde(skip_serializing_if = "Option::is_none")] - pub state: Option, -} - -impl BalancerConfig { - /// Load configuration from a YAML file - pub fn from_yaml_file(path: &str) -> Result> { - let file = std::fs::File::open(path)?; - let config = serde_yaml::from_reader(file)?; - Ok(config) - } -} - -/// Configuration containing only virtual services list -/// Used for UpdateVS command -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VsListConfig { - pub vs: Vec, -} - -impl VsListConfig { - /// Load configuration from a YAML file - pub fn from_yaml_file(path: &str) -> Result> { - let file = std::fs::File::open(path)?; - let config = serde_yaml::from_reader(file)?; - Ok(config) - } -} - -/// Packet processing configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PacketHandlerConfig { - pub vs: Vec, - pub source_address_v4: String, - pub source_address_v6: String, - pub decap_addresses: Vec, - pub sessions_timeouts: SessionsTimeouts, -} - -/// Virtual service configuration (flat structure in YAML) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VirtualService { - // Flat fields (no nested 'id') - pub addr: String, - pub port: u32, - pub proto: Proto, - - pub scheduler: Scheduler, - pub flags: VsFlags, - - /// Allowed source networks with optional port restrictions - /// Supports both simple CIDR strings and structured format with ports - /// Empty list = allow NONE (reject all) - /// ["0.0.0.0/0"] = allow all IPv4 - /// ["::/0"] = allow all IPv6 - #[serde(default)] - pub allowed_srcs: Vec, - - pub reals: Vec, - #[serde(default)] - pub peers: Vec, -} - -/// Real server configuration (flat structure in YAML) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Real { - // Flat fields (no nested 'id') - pub ip: String, - pub port: u32, // Reserved for future use - - pub weight: u32, - pub src_addr: String, - pub src_mask: String, -} - -/// Scheduler algorithm with flexible parsing -#[derive(Debug, Clone, Serialize)] -pub enum Scheduler { - SourceHash, - RoundRobin, -} - -impl<'de> Deserialize<'de> for Scheduler { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - match s.to_uppercase().as_str() { - "SOURCE_HASH" | "SH" => Ok(Scheduler::SourceHash), - "ROUND_ROBIN" | "RR" => Ok(Scheduler::RoundRobin), - _ => Err(serde::de::Error::custom(format!( - "invalid scheduler: '{}'. Expected: SOURCE_HASH, source_hash, SH, sh, ROUND_ROBIN, round_robin, RR, rr", - s - ))), - } - } -} - -/// Protocol with flexible parsing -#[derive(Debug, Clone, Serialize)] -pub enum Proto { - Tcp, - Udp, -} - -impl<'de> Deserialize<'de> for Proto { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - match s.to_uppercase().as_str() { - "TCP" => Ok(Proto::Tcp), - "UDP" => Ok(Proto::Udp), - _ => Err(serde::de::Error::custom(format!( - "invalid protocol: '{}'. Expected: TCP, tcp, UDP, udp", - s - ))), - } - } -} - -/// Virtual service flags -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct VsFlags { - #[serde(default)] - pub gre: bool, - #[serde(default)] - pub fix_mss: bool, - #[serde(default)] - pub ops: bool, - #[serde(default)] - pub pure_l3: bool, - #[serde(default)] - pub wlc: bool, // Dynamic weight adjustment flag -} - -/// Session timeouts -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SessionsTimeouts { - pub tcp_syn_ack: u32, - pub tcp_syn: u32, - pub tcp_fin: u32, - pub tcp: u32, - pub udp: u32, - pub default: u32, -} - -/// State management configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StateConfig { - /// Session table configuration (optional in UPDATE) - #[serde(skip_serializing_if = "Option::is_none")] - pub session_table: Option, - - /// WLC configuration (optional in UPDATE) - #[serde(skip_serializing_if = "Option::is_none")] - pub wlc: Option, - - /// Refresh period in milliseconds (optional in UPDATE) - /// Set to 0 to disable periodic refresh - #[serde(skip_serializing_if = "Option::is_none")] - pub refresh_period_ms: Option, -} - -/// Session table configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SessionTableConfig { - pub capacity: u64, - pub max_load_factor: f32, -} - -/// WLC configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WlcConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub power: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_weight: Option, -} - -//////////////////////////////////////////////////////////////////////////////// -// Conversion to protobuf -//////////////////////////////////////////////////////////////////////////////// - -impl TryFrom for balancerpb::BalancerConfig { - type Error = String; - - fn try_from(config: BalancerConfig) -> Result { - Ok(Self { - packet_handler: config.packet_handler.map(TryInto::try_into).transpose()?, - state: config.state.map(Into::into), - }) - } -} - -impl TryFrom for balancerpb::PacketHandlerConfig { - type Error = String; - - fn try_from(config: PacketHandlerConfig) -> Result { - let virtual_services: Result, String> = config.vs.into_iter().map(TryInto::try_into).collect(); - - let source_v4: Ipv4Addr = config - .source_address_v4 - .parse() - .map_err(|e| format!("invalid source IPv4 '{}': {}", config.source_address_v4, e))?; - let source_v6: Ipv6Addr = config - .source_address_v6 - .parse() - .map_err(|e| format!("invalid source IPv6 '{}': {}", config.source_address_v6, e))?; - - let decap: Result, String> = config - .decap_addresses - .into_iter() - .map(|s| { - let addr: IpAddr = s.parse().map_err(|e| format!("invalid decap IP '{}': {}", s, e))?; - Ok(balancerpb::Addr { bytes: ip_to_bytes(addr) }) - }) - .collect(); - - Ok(Self { - vs: virtual_services?, - source_address_v4: Some(balancerpb::Addr { bytes: source_v4.octets().to_vec() }), - source_address_v6: Some(balancerpb::Addr { bytes: source_v6.octets().to_vec() }), - decap_addresses: decap?, - sessions_timeouts: Some(config.sessions_timeouts.into()), - }) - } -} - -impl TryFrom for balancerpb::VirtualService { - type Error = String; - - fn try_from(vs: VirtualService) -> Result { - let addr: IpAddr = vs - .addr - .parse() - .map_err(|e| format!("invalid VS IP address '{}': {}", vs.addr, e))?; - - let proto = match vs.proto { - Proto::Tcp => balancerpb::TransportProto::Tcp, - Proto::Udp => balancerpb::TransportProto::Udp, - }; - - let scheduler = match vs.scheduler { - Scheduler::SourceHash => balancerpb::VsScheduler::SourceHash, - Scheduler::RoundRobin => balancerpb::VsScheduler::RoundRobin, - }; - - // Parse allowed sources with optional port ranges - let allowed_srcs: Result, String> = vs - .allowed_srcs - .iter() - .map(|entry| { - match entry { - AllowedSrcEntry::Simple(network_str) => { - // Simple format - no port restrictions, no tag - let (addr, mask_bytes) = parse_network(network_str)?; - Ok(balancerpb::AllowedSources { - nets: vec![balancerpb::Net { - addr: Some(balancerpb::Addr { bytes: ip_to_bytes(addr) }), - mask: Some(balancerpb::Addr { bytes: mask_bytes }), - }], - ports: vec![], // Empty = all ports allowed - tag: None, // No tag - }) - } - AllowedSrcEntry::Structured { network, ports, tag } => { - // Structured format with optional ports and tag - let (addr, mask_bytes) = parse_network(network)?; - - let port_ranges = if let Some(ports_str) = ports { - parse_ports(ports_str)? - .into_iter() - .map(|pr| balancerpb::PortsRange { from: pr.from, to: pr.to }) - .collect() - } else { - vec![] // No ports specified = all ports allowed - }; - - Ok(balancerpb::AllowedSources { - nets: vec![balancerpb::Net { - addr: Some(balancerpb::Addr { bytes: ip_to_bytes(addr) }), - mask: Some(balancerpb::Addr { bytes: mask_bytes }), - }], - ports: port_ranges, - tag: tag.as_ref().map(|t| t.to_string()), // Convert Option to Option - }) - } - } - }) - .collect(); - - let peers: Result, String> = vs - .peers - .iter() - .map(|p| { - let ip: IpAddr = p.parse().map_err(|e| format!("invalid peer IP '{}': {}", p, e))?; - Ok(balancerpb::Addr { bytes: ip_to_bytes(ip) }) - }) - .collect(); - - let reals: Result, String> = vs.reals.into_iter().map(TryInto::try_into).collect(); - - // Create VsIdentifier from flat fields - let id = Some(balancerpb::VsIdentifier { - addr: Some(balancerpb::Addr { bytes: ip_to_bytes(addr) }), - port: vs.port, - proto: proto as i32, - }); - - Ok(Self { - id, - scheduler: scheduler as i32, - allowed_srcs: allowed_srcs?, - reals: reals?, - flags: Some(vs.flags.into()), - peers: peers?, - }) - } -} - -impl From for balancerpb::VsFlags { - fn from(flags: VsFlags) -> Self { - Self { - gre: flags.gre, - fix_mss: flags.fix_mss, - ops: flags.ops, - pure_l3: flags.pure_l3, - wlc: flags.wlc, - } - } -} - -impl TryFrom for balancerpb::Real { - type Error = String; - - fn try_from(real: Real) -> Result { - let ip: IpAddr = real - .ip - .parse() - .map_err(|e| format!("invalid real IP '{}': {}", real.ip, e))?; - let src_addr: IpAddr = real - .src_addr - .parse() - .map_err(|e| format!("invalid src address '{}': {}", real.src_addr, e))?; - let src_mask: IpAddr = real - .src_mask - .parse() - .map_err(|e| format!("invalid src mask '{}': {}", real.src_mask, e))?; - - // Create RelativeRealIdentifier from flat fields - let id = Some(balancerpb::RelativeRealIdentifier { - ip: Some(balancerpb::Addr { bytes: ip_to_bytes(ip) }), - port: real.port, - }); - - Ok(Self { - id, - weight: real.weight, - src_addr: Some(balancerpb::Addr { bytes: ip_to_bytes(src_addr) }), - src_mask: Some(balancerpb::Addr { bytes: ip_to_bytes(src_mask) }), - }) - } -} - -impl From for balancerpb::SessionsTimeouts { - fn from(timeouts: SessionsTimeouts) -> Self { - Self { - tcp_syn_ack: timeouts.tcp_syn_ack, - tcp_syn: timeouts.tcp_syn, - tcp_fin: timeouts.tcp_fin, - tcp: timeouts.tcp, - udp: timeouts.udp, - default: timeouts.default, - } - } -} - -impl From for balancerpb::StateConfig { - fn from(config: StateConfig) -> Self { - Self { - session_table_capacity: config.session_table.as_ref().map(|st| st.capacity), - session_table_max_load_factor: config.session_table.as_ref().map(|st| st.max_load_factor), - wlc: config.wlc.map(Into::into), - refresh_period: config.refresh_period_ms.map(|ms| prost_types::Duration { - seconds: (ms / 1000) as i64, - nanos: ((ms % 1000) * 1_000_000) as i32, - }), - } - } -} - -impl From for balancerpb::WlcConfig { - fn from(config: WlcConfig) -> Self { - Self { - power: config.power, - max_weight: config.max_weight, - } - } -} diff --git a/modules/balancer/cli/src/json_output.rs b/modules/balancer/cli/src/json_output.rs deleted file mode 100644 index 1f8289f9f..000000000 --- a/modules/balancer/cli/src/json_output.rs +++ /dev/null @@ -1,946 +0,0 @@ -//! JSON output with proper IP address formatting - -use serde::Serialize; - -use crate::{ - entities::{addr_to_ip, opt_addr_to_ip}, - rpc::balancerpb, -}; - -//////////////////////////////////////////////////////////////////////////////// -// ShowConfig JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ShowConfigJson { - pub config: Option, - pub buffered_real_updates: Vec, -} - -#[derive(Serialize)] -pub struct BalancerConfigJson { - pub packet_handler: Option, - pub state: Option, -} - -#[derive(Serialize)] -pub struct PacketHandlerConfigJson { - pub virtual_services: Vec, - pub source_address_v4: String, - pub source_address_v6: String, - pub decap_addresses: Vec, - pub sessions_timeouts: Option, -} - -#[derive(Serialize)] -pub struct VirtualServiceJson { - pub id: Option, - pub scheduler: String, - pub allowed_srcs: Vec, - pub reals: Vec, - pub flags: Option, - pub peers: Vec, -} - -#[derive(Serialize)] -pub struct VsIdentifierJson { - pub addr: String, - pub port: u32, - pub proto: String, -} - -#[derive(Serialize)] -pub struct AllowedSourcesJson { - pub networks: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub ports: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tag: Option, -} - -#[derive(Serialize)] -pub struct NetworkJson { - pub addr: String, - pub mask: String, -} - -#[derive(Serialize)] -pub struct PortRangeJson { - pub from: u32, - pub to: u32, -} - -#[derive(Serialize)] -pub struct RealJson { - pub id: Option, - pub weight: u32, - pub src_addr: String, - pub src_mask: String, -} - -#[derive(Serialize)] -pub struct RelativeRealIdentifierJson { - pub ip: String, - pub port: u32, -} - -#[derive(Serialize)] -pub struct VsFlagsJson { - pub gre: bool, - pub fix_mss: bool, - pub ops: bool, - pub pure_l3: bool, - pub wlc: bool, -} - -#[derive(Serialize)] -pub struct SessionsTimeoutsJson { - pub tcp_syn_ack: u32, - pub tcp_syn: u32, - pub tcp_fin: u32, - pub tcp: u32, - pub udp: u32, - pub default: u32, -} - -#[derive(Serialize)] -pub struct StateConfigJson { - pub session_table_capacity: Option, - pub session_table_max_load_factor: Option, - pub wlc: Option, - pub refresh_period: Option, -} - -#[derive(Serialize)] -pub struct WlcConfigJson { - pub power: Option, - pub max_weight: Option, -} - -#[derive(Serialize)] -pub struct RealUpdateJson { - pub real_id: Option, - pub enable: Option, - pub weight: Option, -} - -#[derive(Serialize)] -pub struct RealIdentifierJson { - pub vs: Option, - pub real: Option, -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowInfo JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ShowInfoResponseJson { - pub name: String, - pub info: Option, -} - -#[derive(Serialize)] -pub struct BalancerInfoJson { - pub active_sessions: u64, - pub last_packet_timestamp: Option, - pub vs: Vec, -} - -#[derive(Serialize)] -pub struct VsInfoJson { - pub id: Option, - pub active_sessions: u64, - pub last_packet_timestamp: Option, - pub reals: Vec, -} - -#[derive(Serialize)] -pub struct RealInfoJson { - pub id: Option, - pub active_sessions: u64, - pub last_packet_timestamp: Option, -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowStats JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ShowStatsResponseJson { - pub entries: Vec, -} - -#[derive(Serialize)] -pub struct StatsEntryJson { - pub name: String, - pub ref_: Option, - pub stats: Option, -} - -#[derive(Serialize)] -pub struct PacketHandlerRefJson { - pub device: Option, - pub pipeline: Option, - pub function: Option, - pub chain: Option, -} - -#[derive(Serialize)] -pub struct BalancerStatsJson { - pub l4: Option, - pub icmpv4: Option, - pub icmpv6: Option, - pub common: Option, - pub vs: Vec, -} - -#[derive(Serialize)] -pub struct L4StatsJson { - pub incoming_packets: u64, - pub select_vs_failed: u64, - pub invalid_packets: u64, - pub select_real_failed: u64, - pub outgoing_packets: u64, -} - -#[derive(Serialize)] -pub struct IcmpStatsJson { - pub incoming_packets: u64, - pub src_not_allowed: u64, - pub echo_responses: u64, - pub payload_too_short_ip: u64, - pub unmatching_src_from_original: u64, - pub payload_too_short_port: u64, - pub unexpected_transport: u64, - pub unrecognized_vs: u64, - pub forwarded_packets: u64, - pub broadcasted_packets: u64, - pub packet_clones_sent: u64, - pub packet_clones_received: u64, - pub packet_clone_failures: u64, -} - -#[derive(Serialize)] -pub struct CommonStatsJson { - pub incoming_packets: u64, - pub incoming_bytes: u64, - pub unexpected_network_proto: u64, - pub decap_successful: u64, - pub decap_failed: u64, - pub outgoing_packets: u64, - pub outgoing_bytes: u64, -} - -#[derive(Serialize)] -pub struct NamedVsStatsJson { - pub vs: Option, - pub stats: Option, - pub reals: Vec, - pub allowed_sources: Vec, -} - -#[derive(Serialize)] -pub struct AllowedSourcesStatsJson { - pub tag: String, - pub passes: u64, -} - -#[derive(Serialize)] -pub struct VsStatsJson { - pub incoming_packets: u64, - pub incoming_bytes: u64, - pub packet_src_not_allowed: u64, - pub no_reals: u64, - pub ops_packets: u64, - pub session_table_overflow: u64, - pub echo_icmp_packets: u64, - pub error_icmp_packets: u64, - pub real_is_disabled: u64, - pub real_is_removed: u64, - pub not_rescheduled_packets: u64, - pub broadcasted_icmp_packets: u64, - pub created_sessions: u64, - pub outgoing_packets: u64, - pub outgoing_bytes: u64, -} - -#[derive(Serialize)] -pub struct NamedRealStatsJson { - pub real: Option, - pub stats: Option, -} - -#[derive(Serialize)] -pub struct RealStatsJson { - pub packets_real_disabled: u64, - pub packets_real_not_present: u64, - pub ops_packets: u64, - pub error_icmp_packets: u64, - pub created_sessions: u64, - pub packets: u64, - pub bytes: u64, -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowSessions JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ShowSessionsResponseJson { - pub sessions: Vec, -} - -#[derive(Serialize)] -pub struct SessionInfoJson { - pub client_addr: String, - pub client_port: u32, - pub vs_id: Option, - pub real_id: Option, - pub create_timestamp: Option, - pub last_packet_timestamp: Option, - pub timeout: Option, -} - -//////////////////////////////////////////////////////////////////////////////// -// ListConfigs JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ListConfigsResponseJson { - pub configs: Vec, -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowGraph JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct ShowGraphResponseJson { - pub graph: Option, -} - -#[derive(Serialize)] -pub struct GraphJson { - pub virtual_services: Vec, -} - -#[derive(Serialize)] -pub struct GraphVsJson { - pub identifier: Option, - pub reals: Vec, -} - -#[derive(Serialize)] -pub struct GraphRealJson { - pub identifier: Option, - pub weight: u32, - pub effective_weight: u32, - pub enabled: bool, -} - -//////////////////////////////////////////////////////////////////////////////// -// UpdateInfo JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct UpdateInfoJson { - pub created: bool, - pub vs_ipv4_matcher_reused: bool, - pub vs_ipv6_matcher_reused: bool, - pub vs_acl_reuses: Vec, -} - -//////////////////////////////////////////////////////////////////////////////// -// Conversion functions -//////////////////////////////////////////////////////////////////////////////// - -fn proto_to_string(proto: i32) -> String { - match balancerpb::TransportProto::try_from(proto) { - Ok(balancerpb::TransportProto::Tcp) => "tcp".to_string(), - Ok(balancerpb::TransportProto::Udp) => "udp".to_string(), - _ => format!("Unknown({})", proto), - } -} - -fn scheduler_to_string(sched: i32) -> String { - match balancerpb::VsScheduler::try_from(sched) { - Ok(balancerpb::VsScheduler::SourceHash) => "source_hash".to_string(), - Ok(balancerpb::VsScheduler::RoundRobin) => "round_robin".to_string(), - _ => format!("Unknown({})", sched), - } -} - -fn format_timestamp(ts: Option<&prost_types::Timestamp>) -> Option { - ts.and_then(|t| { - if t.seconds == 0 && t.nanos == 0 { - return None; - } - use chrono::{DateTime, Utc}; - DateTime::::from_timestamp(t.seconds, t.nanos as u32).map(|dt| dt.to_rfc3339()) - }) -} - -fn format_duration(dur: Option<&prost_types::Duration>) -> Option { - dur.map(|d| format!("{}s", d.seconds)) -} - -fn convert_vs_identifier(id: Option<&balancerpb::VsIdentifier>) -> Option { - id.map(|id| VsIdentifierJson { - addr: opt_addr_to_ip(&id.addr).map(|ip| ip.to_string()).unwrap_or_default(), - port: id.port, - proto: proto_to_string(id.proto), - }) -} - -fn convert_relative_real_identifier( - id: Option<&balancerpb::RelativeRealIdentifier>, -) -> Option { - id.map(|id| RelativeRealIdentifierJson { - ip: opt_addr_to_ip(&id.ip).map(|ip| ip.to_string()).unwrap_or_default(), - port: id.port, - }) -} - -fn convert_real_identifier(id: Option<&balancerpb::RealIdentifier>) -> Option { - id.map(|id| RealIdentifierJson { - vs: convert_vs_identifier(id.vs.as_ref()), - real: convert_relative_real_identifier(id.real.as_ref()), - }) -} - -pub fn convert_show_config(response: &balancerpb::ShowConfigResponse) -> ShowConfigJson { - ShowConfigJson { - config: response.config.as_ref().map(|c| BalancerConfigJson { - packet_handler: c.packet_handler.as_ref().map(|ph| PacketHandlerConfigJson { - virtual_services: ph - .vs - .iter() - .map(|vs| VirtualServiceJson { - id: convert_vs_identifier(vs.id.as_ref()), - scheduler: scheduler_to_string(vs.scheduler), - allowed_srcs: vs - .allowed_srcs - .iter() - .filter_map(|s| { - // Extract all networks - let networks: Vec = s - .nets - .iter() - .filter_map(|net| { - let addr = opt_addr_to_ip(&net.addr).ok()?; - let mask_bytes = net.mask.as_ref()?.bytes.as_slice(); - let mask = crate::entities::bytes_to_ip(mask_bytes).ok()?; - Some(NetworkJson { - addr: addr.to_string(), - mask: mask.to_string(), - }) - }) - .collect(); - - if networks.is_empty() { - return None; - } - - // Extract port ranges if present - let ports = if s.ports.is_empty() { - None - } else { - Some( - s.ports - .iter() - .map(|pr| PortRangeJson { from: pr.from, to: pr.to }) - .collect(), - ) - }; - - Some(AllowedSourcesJson { networks, ports, tag: s.tag.clone() }) - }) - .collect(), - reals: vs - .reals - .iter() - .map(|r| RealJson { - id: convert_relative_real_identifier(r.id.as_ref()), - weight: r.weight, - src_addr: opt_addr_to_ip(&r.src_addr).map(|ip| ip.to_string()).unwrap_or_default(), - src_mask: opt_addr_to_ip(&r.src_mask).map(|ip| ip.to_string()).unwrap_or_default(), - }) - .collect(), - flags: vs.flags.as_ref().map(|f| VsFlagsJson { - gre: f.gre, - fix_mss: f.fix_mss, - ops: f.ops, - pure_l3: f.pure_l3, - wlc: f.wlc, - }), - peers: vs - .peers - .iter() - .filter_map(|p| addr_to_ip(p).ok().map(|ip| ip.to_string())) - .collect(), - }) - .collect(), - source_address_v4: opt_addr_to_ip(&ph.source_address_v4) - .map(|ip| ip.to_string()) - .unwrap_or_default(), - source_address_v6: opt_addr_to_ip(&ph.source_address_v6) - .map(|ip| ip.to_string()) - .unwrap_or_default(), - decap_addresses: ph - .decap_addresses - .iter() - .filter_map(|a| addr_to_ip(a).ok().map(|ip| ip.to_string())) - .collect(), - sessions_timeouts: ph.sessions_timeouts.as_ref().map(|t| SessionsTimeoutsJson { - tcp_syn_ack: t.tcp_syn_ack, - tcp_syn: t.tcp_syn, - tcp_fin: t.tcp_fin, - tcp: t.tcp, - udp: t.udp, - default: t.default, - }), - }), - state: c.state.as_ref().map(|s| StateConfigJson { - session_table_capacity: s.session_table_capacity, - session_table_max_load_factor: s.session_table_max_load_factor, - wlc: s.wlc.as_ref().map(|w| WlcConfigJson { - power: w.power, - max_weight: w.max_weight, - }), - refresh_period: s - .refresh_period - .as_ref() - .map(|p| format!("{}ms", p.seconds * 1000 + p.nanos as i64 / 1_000_000)), - }), - }), - buffered_real_updates: response - .buffered_real_updates - .iter() - .map(|u| RealUpdateJson { - real_id: convert_real_identifier(u.real_id.as_ref()), - enable: u.enable, - weight: u.weight, - }) - .collect(), - } -} - -pub fn convert_list_configs(response: &balancerpb::ListConfigsResponse) -> ListConfigsResponseJson { - ListConfigsResponseJson { configs: response.configs.clone() } -} - -pub fn convert_show_info(response: &balancerpb::ShowInfoResponse) -> ShowInfoResponseJson { - ShowInfoResponseJson { - name: response.name.clone(), - info: response.info.as_ref().map(|i| BalancerInfoJson { - active_sessions: i.active_sessions, - last_packet_timestamp: format_timestamp(i.last_packet_timestamp.as_ref()), - vs: i - .vs - .iter() - .map(|v| VsInfoJson { - id: convert_vs_identifier(v.id.as_ref()), - active_sessions: v.active_sessions, - last_packet_timestamp: format_timestamp(v.last_packet_timestamp.as_ref()), - reals: v - .reals - .iter() - .map(|r| RealInfoJson { - id: convert_real_identifier(r.id.as_ref()), - active_sessions: r.active_sessions, - last_packet_timestamp: format_timestamp(r.last_packet_timestamp.as_ref()), - }) - .collect(), - }) - .collect(), - }), - } -} - -pub fn convert_show_stats(response: &balancerpb::ShowStatsResponse) -> ShowStatsResponseJson { - ShowStatsResponseJson { - entries: response - .entries - .iter() - .map(|e| StatsEntryJson { - name: e.name.clone(), - ref_: e.r#ref.as_ref().map(|r| PacketHandlerRefJson { - device: r.device.clone(), - pipeline: r.pipeline.clone(), - function: r.function.clone(), - chain: r.chain.clone(), - }), - stats: e.stats.as_ref().map(|s| BalancerStatsJson { - l4: s.l4.as_ref().map(|l| L4StatsJson { - incoming_packets: l.incoming_packets, - select_vs_failed: l.select_vs_failed, - invalid_packets: l.invalid_packets, - select_real_failed: l.select_real_failed, - outgoing_packets: l.outgoing_packets, - }), - icmpv4: s.icmpv4.as_ref().map(|i| IcmpStatsJson { - incoming_packets: i.incoming_packets, - src_not_allowed: i.src_not_allowed, - echo_responses: i.echo_responses, - payload_too_short_ip: i.payload_too_short_ip, - unmatching_src_from_original: i.unmatching_src_from_original, - payload_too_short_port: i.payload_too_short_port, - unexpected_transport: i.unexpected_transport, - unrecognized_vs: i.unrecognized_vs, - forwarded_packets: i.forwarded_packets, - broadcasted_packets: i.broadcasted_packets, - packet_clones_sent: i.packet_clones_sent, - packet_clones_received: i.packet_clones_received, - packet_clone_failures: i.packet_clone_failures, - }), - icmpv6: s.icmpv6.as_ref().map(|i| IcmpStatsJson { - incoming_packets: i.incoming_packets, - src_not_allowed: i.src_not_allowed, - echo_responses: i.echo_responses, - payload_too_short_ip: i.payload_too_short_ip, - unmatching_src_from_original: i.unmatching_src_from_original, - payload_too_short_port: i.payload_too_short_port, - unexpected_transport: i.unexpected_transport, - unrecognized_vs: i.unrecognized_vs, - forwarded_packets: i.forwarded_packets, - broadcasted_packets: i.broadcasted_packets, - packet_clones_sent: i.packet_clones_sent, - packet_clones_received: i.packet_clones_received, - packet_clone_failures: i.packet_clone_failures, - }), - common: s.common.as_ref().map(|c| CommonStatsJson { - incoming_packets: c.incoming_packets, - incoming_bytes: c.incoming_bytes, - unexpected_network_proto: c.unexpected_network_proto, - decap_successful: c.decap_successful, - decap_failed: c.decap_failed, - outgoing_packets: c.outgoing_packets, - outgoing_bytes: c.outgoing_bytes, - }), - vs: s - .vs - .iter() - .map(|v| NamedVsStatsJson { - vs: convert_vs_identifier(v.vs.as_ref()), - stats: v.stats.as_ref().map(|st| VsStatsJson { - incoming_packets: st.incoming_packets, - incoming_bytes: st.incoming_bytes, - packet_src_not_allowed: st.packet_src_not_allowed, - no_reals: st.no_reals, - ops_packets: st.ops_packets, - session_table_overflow: st.session_table_overflow, - echo_icmp_packets: st.echo_icmp_packets, - error_icmp_packets: st.error_icmp_packets, - real_is_disabled: st.real_is_disabled, - real_is_removed: st.real_is_removed, - not_rescheduled_packets: st.not_rescheduled_packets, - broadcasted_icmp_packets: st.broadcasted_icmp_packets, - created_sessions: st.created_sessions, - outgoing_packets: st.outgoing_packets, - outgoing_bytes: st.outgoing_bytes, - }), - reals: v - .reals - .iter() - .map(|r| NamedRealStatsJson { - real: convert_real_identifier(r.real.as_ref()), - stats: r.stats.as_ref().map(|st| RealStatsJson { - packets_real_disabled: st.packets_real_disabled, - packets_real_not_present: 0, // Field removed in new proto - ops_packets: st.ops_packets, - error_icmp_packets: st.error_icmp_packets, - created_sessions: st.created_sessions, - packets: st.packets, - bytes: st.bytes, - }), - }) - .collect(), - allowed_sources: v - .allowed_sources - .iter() - .map(|a| AllowedSourcesStatsJson { tag: a.tag.clone(), passes: a.passes }) - .collect(), - }) - .collect(), - }), - }) - .collect(), - } -} - -pub fn convert_show_sessions(response: &balancerpb::ShowSessionsResponse) -> ShowSessionsResponseJson { - ShowSessionsResponseJson { - sessions: response - .sessions - .iter() - .map(|s| SessionInfoJson { - client_addr: opt_addr_to_ip(&s.client_addr) - .map(|ip| ip.to_string()) - .unwrap_or_default(), - client_port: s.client_port, - vs_id: convert_vs_identifier(s.vs_id.as_ref()), - real_id: convert_real_identifier(s.real_id.as_ref()), - create_timestamp: format_timestamp(s.create_timestamp.as_ref()), - last_packet_timestamp: format_timestamp(s.last_packet_timestamp.as_ref()), - timeout: format_duration(s.timeout.as_ref()), - }) - .collect(), - } -} - -pub fn convert_show_graph(response: &balancerpb::ShowGraphResponse) -> ShowGraphResponseJson { - ShowGraphResponseJson { - graph: response.graph.as_ref().map(|g| GraphJson { - virtual_services: g - .virtual_services - .iter() - .map(|vs| GraphVsJson { - identifier: convert_vs_identifier(vs.identifier.as_ref()), - reals: vs - .reals - .iter() - .map(|r| GraphRealJson { - identifier: convert_relative_real_identifier(r.identifier.as_ref()), - weight: r.weight, - effective_weight: r.effective_weight, - enabled: r.enabled, - }) - .collect(), - }) - .collect(), - }), - } -} - -pub fn convert_update_info(info: &balancerpb::UpdateInfo) -> UpdateInfoJson { - UpdateInfoJson { - created: info.created, - vs_ipv4_matcher_reused: info.vs_ipv4_matcher_reused, - vs_ipv6_matcher_reused: info.vs_ipv6_matcher_reused, - vs_acl_reuses: info - .vs_acl_reuses - .iter() - .filter_map(|vs_id| convert_vs_identifier(Some(vs_id))) - .collect(), - } -} - -//////////////////////////////////////////////////////////////////////////////// -// VS Update Info JSON structures (without created field) -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct VsUpdateInfoJson { - pub vs_ipv4_matcher_reused: bool, - pub vs_ipv6_matcher_reused: bool, - pub vs_acl_reuses: Vec, -} - -pub fn convert_vs_update_info(info: &balancerpb::UpdateInfo) -> VsUpdateInfoJson { - VsUpdateInfoJson { - vs_ipv4_matcher_reused: info.vs_ipv4_matcher_reused, - vs_ipv6_matcher_reused: info.vs_ipv6_matcher_reused, - vs_acl_reuses: info - .vs_acl_reuses - .iter() - .filter_map(|vs_id| convert_vs_identifier(Some(vs_id))) - .collect(), - } -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowInspect JSON structures -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Serialize)] -pub struct AgentInspectJson { - pub memory_limit: u64, - pub memory_usage: u64, - pub balancers: Vec, -} - -#[derive(Serialize)] -pub struct BalancerInspectJson { - pub name: String, - pub packet_handler_inspect: PacketHandlerInspectJson, - pub state_inspect: StateInspectJson, - pub other_usage: u64, - pub total_usage: u64, -} - -#[derive(Serialize)] -pub struct PacketHandlerInspectJson { - pub vs_ipv4_inspect: PacketHandlerVsInspectJson, - pub vs_ipv6_inspect: PacketHandlerVsInspectJson, - pub summary_vs_usage: u64, - pub vs_index_usage: u64, - pub reals_index_usage: u64, - pub counters_usage: u64, - pub decap_usage: u64, - pub total_usage: u64, -} - -#[derive(Serialize)] -pub struct PacketHandlerVsInspectJson { - pub matcher_usage: u64, - pub summary_vs_usage: u64, - pub vs_inspects: Vec, - pub announce_usage: u64, - pub index_usage: u64, - pub total_usage: u64, -} - -#[derive(Serialize)] -pub struct NamedVsInspectJson { - pub identifier: Option, - pub inspect: VsInspectJson, -} - -#[derive(Serialize)] -pub struct VsInspectJson { - pub acl_usage: u64, - pub ring_usage: u64, - pub counters_usage: u64, - pub reals_usage: RealsUsageJson, - pub other_usage: u64, - pub total_usage: u64, -} - -#[derive(Serialize)] -pub struct RealsUsageJson { - pub counters_usage: u64, - pub data_usage: u64, - pub total_usage: u64, -} - -#[derive(Serialize)] -pub struct StateInspectJson { - pub session_table_usage: u64, - pub total_usage: u64, -} - -pub fn convert_show_inspect(response: &balancerpb::ShowInspectResponse) -> AgentInspectJson { - let inspect = response.inspect.as_ref(); - - AgentInspectJson { - memory_limit: inspect.map(|i| i.memory_limit).unwrap_or(0), - memory_usage: inspect.map(|i| i.memory_usage).unwrap_or(0), - balancers: inspect - .map(|i| { - i.balancers - .iter() - .map(|b| BalancerInspectJson { - name: b.name.clone(), - packet_handler_inspect: convert_packet_handler_inspect(b.packet_handler_inspect.as_ref()), - state_inspect: convert_state_inspect(b.state_inspect.as_ref()), - other_usage: b.other_usage, - total_usage: b.total_usage, - }) - .collect() - }) - .unwrap_or_default(), - } -} - -fn convert_packet_handler_inspect(ph: Option<&balancerpb::PacketHandlerInspect>) -> PacketHandlerInspectJson { - match ph { - Some(ph) => PacketHandlerInspectJson { - vs_ipv4_inspect: convert_packet_handler_vs_inspect(ph.vs_ipv4_inspect.as_ref()), - vs_ipv6_inspect: convert_packet_handler_vs_inspect(ph.vs_ipv6_inspect.as_ref()), - summary_vs_usage: ph.summary_vs_usage, - vs_index_usage: ph.vs_index_usage, - reals_index_usage: ph.reals_index_usage, - counters_usage: ph.counters_usage, - decap_usage: ph.decap_usage, - total_usage: ph.total_usage, - }, - None => PacketHandlerInspectJson { - vs_ipv4_inspect: convert_packet_handler_vs_inspect(None), - vs_ipv6_inspect: convert_packet_handler_vs_inspect(None), - summary_vs_usage: 0, - vs_index_usage: 0, - reals_index_usage: 0, - counters_usage: 0, - decap_usage: 0, - total_usage: 0, - }, - } -} - -fn convert_packet_handler_vs_inspect(vs: Option<&balancerpb::PacketHandlerVsInspect>) -> PacketHandlerVsInspectJson { - match vs { - Some(vs) => PacketHandlerVsInspectJson { - matcher_usage: vs.matcher_usage, - summary_vs_usage: vs.summary_vs_usage, - vs_inspects: vs - .vs_inspects - .iter() - .map(|nvi| NamedVsInspectJson { - identifier: convert_vs_identifier(nvi.identifier.as_ref()), - inspect: convert_vs_inspect(nvi.inspect.as_ref()), - }) - .collect(), - announce_usage: vs.announce_usage, - index_usage: vs.index_usage, - total_usage: vs.total_usage, - }, - None => PacketHandlerVsInspectJson { - matcher_usage: 0, - summary_vs_usage: 0, - vs_inspects: Vec::new(), - announce_usage: 0, - index_usage: 0, - total_usage: 0, - }, - } -} - -fn convert_vs_inspect(vs: Option<&balancerpb::VsInspect>) -> VsInspectJson { - match vs { - Some(vs) => VsInspectJson { - acl_usage: vs.acl_usage, - ring_usage: vs.ring_usage, - counters_usage: vs.counters_usage, - reals_usage: convert_reals_usage(vs.reals_usage.as_ref()), - other_usage: vs.other_usage, - total_usage: vs.total_usage, - }, - None => VsInspectJson { - acl_usage: 0, - ring_usage: 0, - counters_usage: 0, - reals_usage: convert_reals_usage(None), - other_usage: 0, - total_usage: 0, - }, - } -} - -fn convert_reals_usage(reals: Option<&balancerpb::RealsUsage>) -> RealsUsageJson { - match reals { - Some(reals) => RealsUsageJson { - counters_usage: reals.counters_usage, - data_usage: reals.data_usage, - total_usage: reals.total_usage, - }, - None => RealsUsageJson { - counters_usage: 0, - data_usage: 0, - total_usage: 0, - }, - } -} - -fn convert_state_inspect(state: Option<&balancerpb::StateInspect>) -> StateInspectJson { - match state { - Some(state) => StateInspectJson { - session_table_usage: state.session_table_usage, - total_usage: state.total_usage, - }, - None => StateInspectJson { - session_table_usage: 0, - total_usage: 0, - }, - } -} diff --git a/modules/balancer/cli/src/lib.rs b/modules/balancer/cli/src/lib.rs index 775fa24a1..87f6c514e 100644 --- a/modules/balancer/cli/src/lib.rs +++ b/modules/balancer/cli/src/lib.rs @@ -1,8 +1,16 @@ -//! YANET Balancer CLI library +#[allow(clippy::all, non_snake_case)] +pub mod filterpb { + tonic::include_proto!("filterpb"); +} -pub mod cmd; -pub mod entities; -pub mod json_output; -pub mod output; -pub mod rpc; -pub mod service; +#[allow(clippy::all, non_snake_case)] +pub mod commonpb { + tonic::include_proto!("commonpb"); +} + +#[allow(clippy::all, non_snake_case)] +pub mod balancerpb { + tonic::include_proto!("balancerpb"); +} + +pub use balancerpb::balancer_client::BalancerClient; diff --git a/modules/balancer/cli/src/main.rs b/modules/balancer/cli/src/main.rs index 64afa5455..32e36d883 100644 --- a/modules/balancer/cli/src/main.rs +++ b/modules/balancer/cli/src/main.rs @@ -1,33 +1,364 @@ -mod cmd; -mod entities; -mod json_output; -mod output; -mod rpc; +mod config; +mod display; mod service; use std::error::Error; -use clap::{CommandFactory, Parser}; +use clap::{ArgAction, CommandFactory, Parser}; use clap_complete::CompleteEnv; use ync::logging; -use crate::{cmd::Cmd, service::BalancerService}; +use crate::service::BalancerService; -//////////////////////////////////////////////////////////////////////////////// +/// Balancer module CLI. +#[derive(Debug, Clone, Parser)] +#[command(version, about)] +#[command(flatten_help = true)] +pub struct Cmd { + #[clap(subcommand)] + pub mode: ModeCmd, + #[command(flatten)] + pub connection: ync::client::ConnectionArgs, + /// Be verbose in terms of logging. + #[clap(short, action = ArgAction::Count, global = true)] + pub verbose: u8, +} + +#[derive(Debug, Clone, Parser)] +pub enum ModeCmd { + /// Update balancer configuration from YAML file. + Update(UpdateCmd), + /// List all balancer instances. + List, + /// Show balancer configuration. + Config(ConfigCmd), + /// Show balancer state (IPVS-style). + Show(ShowCmd), + /// Show active sessions (streaming). + Sessions(SessionsCmd), + /// Show balancer metrics (JSON). + Metrics(MetricsCmd), + /// Manage real servers. + Reals(RealsCmd), +} + +// ─── Update ────────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct UpdateCmd { + /// Balancer instance name. + #[arg(long, short = 'n')] + pub name: String, + /// Path to YAML configuration file. + #[arg(long, short = 'c')] + pub config: String, +} + +// ─── Config ────────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct ConfigCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, +} + +// ─── Show ──────────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct ShowCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, + + /// Tabled output: VS info, scheduler, flags, reals with weights. + #[arg(long, short = 't')] + pub table: bool, + + /// Show all counters, active sessions and last packet timestamps. + #[arg(long, short = 's')] + pub stats: bool, + + /// Show allowed sources config per VS (with counters if --stats is + /// present). + #[arg(long, short = 'a')] + pub acl: bool, + + /// Show peers per VS. + #[arg(long)] + pub peers: bool, + + /// Show decap addresses and source IPs. + #[arg(long)] + pub decap: bool, + + /// Enable all output sections (--table --stats --acl --peers --decap). + #[arg(long, short = 'd')] + pub detail: bool, + + #[command(flatten)] + pub filter: FilterFlags, + + /// Filter by device name. + #[arg(long)] + pub device: Option, + /// Filter by pipeline name. + #[arg(long, short = 'p')] + pub pipeline: Option, + /// Filter by function name. + #[arg(long, short = 'f')] + pub function: Option, + /// Filter by chain name. + #[arg(long)] + pub chain: Option, +} + +impl ShowCmd { + /// Whether counters should be requested from the server. + pub fn include_counters(&self) -> bool { + self.stats || self.detail || !self.needs_table() + } + + /// Whether tabled output mode is active. + pub fn needs_table(&self) -> bool { + self.table || self.stats || self.acl || self.peers || self.decap || self.detail + } +} + +// ─── Sessions ──────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct SessionsCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, + + #[command(flatten)] + pub filter: FilterFlags, +} + +// ─── Metrics ──────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct MetricsCmd {} + +// ─── Reals ─────────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct RealsCmd { + #[clap(subcommand)] + pub mode: RealsMode, +} + +#[derive(Debug, Clone, Parser)] +pub enum RealsMode { + /// Enable real servers (buffered). + Enable(EnableRealCmd), + /// Disable real servers (buffered). + Disable(DisableRealCmd), + /// Flush buffered real server updates. + Flush(FlushRealsCmd), +} + +#[derive(Debug, Clone, Parser)] +pub struct EnableRealCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, + /// Virtual service identifier: "ip:port/proto" or "[ipv6]:port/proto". + #[arg(long)] + pub vs: String, + /// Real server IPs to enable. + #[arg(long, required = true, num_args = 1..)] + pub reals: Vec, + /// Optional new weight for the real servers. + #[arg(long)] + pub weight: Option, + /// Flush buffered updates immediately after enabling. + #[arg(long, default_value_t = false)] + pub flush: bool, +} + +#[derive(Debug, Clone, Parser)] +pub struct DisableRealCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, + /// Virtual service identifier: "ip:port/proto" or "[ipv6]:port/proto". + #[arg(long)] + pub vs: String, + /// Real server IPs to disable. + #[arg(long, required = true, num_args = 1..)] + pub reals: Vec, + /// Flush buffered updates immediately after disabling. + #[arg(long, default_value_t = false)] + pub flush: bool, +} + +#[derive(Debug, Clone, Parser)] +pub struct FlushRealsCmd { + /// Balancer instance name (optional, auto-selects if only one exists). + #[arg(long, short = 'n')] + pub name: Option, +} + +// ─── Shared Filter Flags ───────────────────────────────────────────────────── + +#[derive(Debug, Clone, Parser)] +pub struct FilterFlags { + /// Filter by VIP address. + #[arg(long)] + pub vip: Option, + /// Filter by virtual service port. + #[arg(long)] + pub vs_port: Option, + /// Filter by transport protocol (tcp or udp). + #[arg(long)] + pub proto: Option, + /// Filter by real server IP. + #[arg(long)] + pub real_ip: Option, + /// Filter by real server port. + #[arg(long)] + pub real_port: Option, +} + +#[derive(Debug, Clone, clap::ValueEnum)] +pub enum Proto { + Tcp, + Udp, +} + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +use yanet_cli_balancer::balancerpb; + +/// Parse a VS identifier string: "ip:port/proto", "[ipv6]:port/proto", or +/// "ipv6:port/proto". +pub fn parse_vs_identifier(vs_str: &str) -> Result<(std::net::IpAddr, u16, balancerpb::TransportProto), String> { + let vs_parts: Vec<&str> = vs_str.split('/').collect(); + if vs_parts.len() != 2 { + return Err(format!( + "invalid --vs format: '{}'. Expected: 'ip:port/proto', '[ipv6]:port/proto'", + vs_str + )); + } + + let addr_port = vs_parts[0]; + let proto = match vs_parts[1].to_uppercase().as_str() { + "TCP" => balancerpb::TransportProto::Tcp, + "UDP" => balancerpb::TransportProto::Udp, + other => return Err(format!("invalid proto: '{}'. Expected 'tcp' or 'udp'", other)), + }; + + let (ip_str, port_str) = if addr_port.starts_with('[') { + let bracket_end = addr_port + .find(']') + .ok_or_else(|| format!("invalid IPv6 bracket notation: '{}'", addr_port))?; + let ip_part = &addr_port[1..bracket_end]; + let remaining = &addr_port[bracket_end + 1..]; + if !remaining.starts_with(':') { + return Err(format!("expected ':' after ']' in '{}'", addr_port)); + } + (ip_part, &remaining[1..]) + } else { + let parts: Vec<&str> = addr_port.rsplitn(2, ':').collect(); + if parts.len() != 2 { + return Err(format!("invalid address:port format: '{}'", addr_port)); + } + (parts[1], parts[0]) + }; + + let port: u16 = port_str + .parse() + .map_err(|e| format!("invalid port '{}': {}", port_str, e))?; + let ip: std::net::IpAddr = ip_str.parse().map_err(|e| format!("invalid IP '{}': {}", ip_str, e))?; + + Ok((ip, port, proto)) +} + +pub fn ip_to_bytes(ip: std::net::IpAddr) -> Vec { + match ip { + std::net::IpAddr::V4(v4) => v4.octets().to_vec(), + std::net::IpAddr::V6(v6) => v6.octets().to_vec(), + } +} + +pub fn bytes_to_ip(bytes: &[u8]) -> Result { + match bytes.len() { + 4 => { + let arr: [u8; 4] = bytes.try_into().map_err(|_| "invalid IPv4 bytes")?; + Ok(std::net::IpAddr::V4(std::net::Ipv4Addr::from(arr))) + } + 16 => { + let arr: [u8; 16] = bytes.try_into().map_err(|_| "invalid IPv6 bytes")?; + Ok(std::net::IpAddr::V6(std::net::Ipv6Addr::from(arr))) + } + n => Err(format!("invalid IP address length: {}", n)), + } +} + +pub fn format_ip_port(ip: std::net::IpAddr, port: u32) -> String { + match ip { + std::net::IpAddr::V4(_) => { + if port == 0 { + format!("{}", ip) + } else { + format!("{}:{}", ip, port) + } + } + std::net::IpAddr::V6(_) => { + if port == 0 { + format!("{}", ip) + } else { + format!("[{}]:{}", ip, port) + } + } + } +} + +impl FilterFlags { + pub fn to_proto(&self) -> Option { + if self.vip.is_none() + && self.vs_port.is_none() + && self.proto.is_none() + && self.real_ip.is_none() + && self.real_port.is_none() + { + return None; + } + + Some(balancerpb::Filter { + vip: self.vip.as_ref().map(|s| { + let ip: std::net::IpAddr = s.parse().expect("invalid VIP address"); + ip_to_bytes(ip) + }), + vs_port: self.vs_port, + proto: self.proto.as_ref().map(|p| match p { + Proto::Tcp => balancerpb::TransportProto::Tcp as i32, + Proto::Udp => balancerpb::TransportProto::Udp as i32, + }), + real_ip: self.real_ip.as_ref().map(|s| { + let ip: std::net::IpAddr = s.parse().expect("invalid real IP address"); + ip_to_bytes(ip) + }), + real_port: self.real_port, + }) + } +} + +// ─── Entry Point ───────────────────────────────────────────────────────────── async fn run(cmd: Cmd) -> Result<(), Box> { let mut service = BalancerService::connect(&cmd.connection).await?; - service.handle_cmd(cmd.mode).await + service.handle(cmd.mode).await } -//////////////////////////////////////////////////////////////////////////////// - #[tokio::main(flavor = "current_thread")] pub async fn main() { CompleteEnv::with_factory(Cmd::command).complete(); let cmd = Cmd::parse(); - - logging::init(cmd.verbosity as usize).expect("failed to initialize logging"); + logging::init(cmd.verbose as usize).expect("failed to initialize logging"); if let Err(err) = run(cmd).await { log::error!("{err}"); diff --git a/modules/balancer/cli/src/output.rs b/modules/balancer/cli/src/output.rs deleted file mode 100644 index 842f16d4d..000000000 --- a/modules/balancer/cli/src/output.rs +++ /dev/null @@ -1,2690 +0,0 @@ -//! Output formatting for different display formats (JSON, Tree, Table) - -use std::{error::Error, net::IpAddr}; - -use chrono::{DateTime, Utc}; -use colored::Colorize; -use ptree::TreeBuilder; -use tabled::Tabled; -use ync::display::print_table; - -use crate::{ - entities::{addr_to_ip, format_bytes, format_number, opt_addr_to_ip}, - json_output, - rpc::balancerpb, -}; - -//////////////////////////////////////////////////////////////////////////////// -// Output Format Enum -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Copy)] -pub enum OutputFormat { - Json, - Tree, - Table, -} - -#[derive(Debug, Clone, Copy)] -pub enum InspectOutputFormat { - Json, - Normal, - Detail, -} - -//////////////////////////////////////////////////////////////////////////////// -// Helper Functions -//////////////////////////////////////////////////////////////////////////////// - -fn format_real(ip: IpAddr, port: u16) -> String { - if port == 0 { - format!("{}", ip) - } else { - format_ip_port(ip, port) - } -} - -/// Format IP address with port, using brackets for IPv6 -fn format_ip_port(ip: IpAddr, port: u16) -> String { - match ip { - IpAddr::V4(_) => format!("{}:{}", ip, port), - IpAddr::V6(_) => format!("[{}]:{}", ip, port), - } -} - -/// Format IP address with port and protocol, using brackets for IPv6 -fn format_vs(ip: IpAddr, port: u32, proto: i32) -> String { - let proto_str = proto_to_string(proto); - match ip { - IpAddr::V4(_) => format!("{}:{}/{}", ip, port, proto_str), - IpAddr::V6(_) => format!("[{}]:{}/{}", ip, port, proto_str), - } -} - -/// Print a boxed header with title and optional subtitle (can be multi-line) -fn print_boxed_header(title: &str, subtitle: Option<&str>) { - let title_len = title.len(); - - // Handle multi-line subtitles - let subtitle_lines: Vec<&str> = subtitle.map(|s| s.lines().collect()).unwrap_or_default(); - let max_subtitle_len = subtitle_lines.iter().map(|line| line.len()).max().unwrap_or(0); - let max_len = title_len.max(max_subtitle_len); - let box_width = max_len + 4; // 2 spaces padding on each side - - // Top border - println!("{}", format!("╔{}╗", "═".repeat(box_width)).cyan().bold()); - - // Title line (centered) - let title_padding = (box_width - title_len) / 2; - print!("{}", "║".cyan().bold()); - print!("{}", " ".repeat(title_padding)); - print!("{}", title.white().bold()); - print!("{}", " ".repeat(box_width - title_len - title_padding)); - println!("{}", "║".cyan().bold()); - - // Subtitle lines if present - if !subtitle_lines.is_empty() { - println!("{}", format!("╟{}╢", "─".repeat(box_width)).cyan()); - for line in subtitle_lines { - let line_len = line.len(); - let line_padding = (box_width - line_len) / 2; - print!("{}", "║".cyan()); - print!("{}", " ".repeat(line_padding)); - print!("{}", line.bright_white()); - print!("{}", " ".repeat(box_width - line_len - line_padding)); - println!("{}", "║".cyan()); - } - } - - // Bottom border - println!("{}", format!("╚{}╝", "═".repeat(box_width)).cyan().bold()); -} - -fn proto_to_string(proto: i32) -> String { - match balancerpb::TransportProto::try_from(proto) { - Ok(balancerpb::TransportProto::Tcp) => "TCP".to_string(), - Ok(balancerpb::TransportProto::Udp) => "UDP".to_string(), - _ => format!("Unknown({})", proto), - } -} - -fn scheduler_to_string(sched: i32) -> String { - match balancerpb::VsScheduler::try_from(sched) { - Ok(balancerpb::VsScheduler::SourceHash) => "source_hash".to_string(), - Ok(balancerpb::VsScheduler::RoundRobin) => "round_robin".to_string(), - _ => format!("Unknown({})", sched), - } -} - -fn format_timestamp(ts: Option<&prost_types::Timestamp>) -> String { - match ts { - Some(ts) if ts.seconds == 0 && ts.nanos == 0 => "N/A".to_string(), - Some(ts) => { - let dt = DateTime::::from_timestamp(ts.seconds, ts.nanos as u32).unwrap_or_default(); - dt.format("%Y-%m-%d %H:%M:%S").to_string() - } - None => "N/A".to_string(), - } -} - -fn format_duration(dur: Option<&prost_types::Duration>) -> String { - match dur { - Some(dur) => format!("{}s", dur.seconds), - None => "N/A".to_string(), - } -} - -fn format_flags(flags: Option<&balancerpb::VsFlags>) -> String { - match flags { - Some(flags) => { - let mut parts = Vec::new(); - if flags.gre { - parts.push("gre"); - } - if flags.fix_mss { - parts.push("mss"); - } - if flags.ops { - parts.push("ops"); - } - if flags.pure_l3 { - parts.push("l3"); - } - if flags.wlc { - parts.push("wlc"); - } - if parts.is_empty() { - "none".to_string() - } else { - parts.join(",") - } - } - None => "none".to_string(), - } -} - -//////////////////////////////////////////////////////////////////////////////// -// UpdateInfo Output -//////////////////////////////////////////////////////////////////////////////// - -/// Print update information after configuration update -pub fn print_update_info(update_info: &balancerpb::UpdateInfo, format: OutputFormat) -> Result<(), Box> { - match format { - OutputFormat::Json => print_update_info_json(update_info), - OutputFormat::Tree => print_update_info_tree(update_info), - OutputFormat::Table => print_update_info_table(update_info), - } -} - -fn print_update_info_json(update_info: &balancerpb::UpdateInfo) -> Result<(), Box> { - let json = json_output::convert_update_info(update_info); - println!("{}", serde_json::to_string(&json)?); - Ok(()) -} - -fn print_update_info_tree(update_info: &balancerpb::UpdateInfo) -> Result<(), Box> { - let mut tree = TreeBuilder::new("Configuration Update".to_string()); - - // Operation type - let operation = if update_info.created { - "Created (new configuration)" - } else { - "Updated (existing configuration)" - }; - tree.add_empty_child(format!("Operation: {}", operation)); - - // Filter reuse status (only relevant for updates) - if !update_info.created { - tree.begin_child("Filter Reuse Status".to_string()); - - let ipv4_status = if update_info.vs_ipv4_matcher_reused { - "Reused (not recompiled)" - } else { - "Recompiled" - }; - tree.add_empty_child(format!("IPv4 VS Matcher: {}", ipv4_status)); - - let ipv6_status = if update_info.vs_ipv6_matcher_reused { - "Reused (not recompiled)" - } else { - "Recompiled" - }; - tree.add_empty_child(format!("IPv6 VS Matcher: {}", ipv6_status)); - - tree.end_child(); - - // ACL reuse information - if !update_info.vs_acl_reuses.is_empty() { - tree.begin_child(format!( - "ACL Filters Reused ({} virtual services)", - update_info.vs_acl_reuses.len() - )); - for vs_id in &update_info.vs_acl_reuses { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.add_empty_child(format_vs(ip, vs_id.port, vs_id.proto)); - } - } - tree.end_child(); - } else { - tree.add_empty_child("ACL Filters Reused: None (all ACLs recompiled)".to_string()); - } - } - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_update_info_table(update_info: &balancerpb::UpdateInfo) -> Result<(), Box> { - println!(); - println!("{}", "═".repeat(60).cyan().bold()); - println!("{}", " Configuration Update Summary".white().bold()); - println!("{}", "═".repeat(60).cyan().bold()); - println!(); - - // Operation type - let operation = if update_info.created { - "Created (new configuration)".bright_green().bold() - } else { - "Updated (existing configuration)".bright_blue().bold() - }; - println!("{} {}", "Operation:".bright_cyan().bold(), operation); - println!(); - - // Filter reuse status (only relevant for updates) - if !update_info.created { - println!("{}", "Filter Reuse Status:".bright_cyan().bold()); - - let ipv4_status = if update_info.vs_ipv4_matcher_reused { - "✓ Reused (not recompiled)".bright_green() - } else { - "✗ Recompiled".bright_yellow() - }; - println!(" IPv4 VS Matcher: {}", ipv4_status); - - let ipv6_status = if update_info.vs_ipv6_matcher_reused { - "✓ Reused (not recompiled)".bright_green() - } else { - "✗ Recompiled".bright_yellow() - }; - println!(" IPv6 VS Matcher: {}", ipv6_status); - - println!(); - - // ACL reuse information - if !update_info.vs_acl_reuses.is_empty() { - println!( - "{} {}", - "ACL Filters Reused:".bright_cyan().bold(), - format!("({} virtual services)", update_info.vs_acl_reuses.len()).bright_white() - ); - - for vs_id in &update_info.vs_acl_reuses { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - println!(" • {}", format_vs(ip, vs_id.port, vs_id.proto).bright_yellow(),); - } - } - } else { - println!( - "{} {}", - "ACL Filters Reused:".bright_cyan().bold(), - "None (all ACLs recompiled)".bright_yellow() - ); - } - - println!(); - } - - println!("{}", "═".repeat(60).cyan().bold()); - println!(); - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// VS Update/Delete Info Output -//////////////////////////////////////////////////////////////////////////////// - -#[derive(Debug, Clone, Copy)] -pub enum VsOperation { - Update, - Delete, -} - -/// Print VS update/delete information (without created flag) -pub fn print_vs_update_info( - update_info: &balancerpb::UpdateInfo, - format: OutputFormat, - operation: VsOperation, -) -> Result<(), Box> { - match format { - OutputFormat::Json => print_vs_update_info_json(update_info), - OutputFormat::Tree => print_vs_update_info_tree(update_info, operation), - OutputFormat::Table => print_vs_update_info_table(update_info, operation), - } -} - -fn print_vs_update_info_json(update_info: &balancerpb::UpdateInfo) -> Result<(), Box> { - let json = json_output::convert_vs_update_info(update_info); - println!("{}", serde_json::to_string(&json)?); - Ok(()) -} - -fn print_vs_update_info_tree( - update_info: &balancerpb::UpdateInfo, - operation: VsOperation, -) -> Result<(), Box> { - let title = match operation { - VsOperation::Update => "VS Update Result", - VsOperation::Delete => "VS Delete Result", - }; - let mut tree = TreeBuilder::new(title.to_string()); - - // Filter reuse status - tree.begin_child("Filter Reuse Status".to_string()); - - let ipv4_status = if update_info.vs_ipv4_matcher_reused { - "Reused (not recompiled)" - } else { - "Recompiled" - }; - tree.add_empty_child(format!("IPv4 VS Matcher: {}", ipv4_status)); - - let ipv6_status = if update_info.vs_ipv6_matcher_reused { - "Reused (not recompiled)" - } else { - "Recompiled" - }; - tree.add_empty_child(format!("IPv6 VS Matcher: {}", ipv6_status)); - - tree.end_child(); - - // ACL reuse information (only for update operations) - if matches!(operation, VsOperation::Update) { - if !update_info.vs_acl_reuses.is_empty() { - tree.begin_child(format!( - "ACL Filters Reused ({} virtual services)", - update_info.vs_acl_reuses.len() - )); - for vs_id in &update_info.vs_acl_reuses { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.add_empty_child(format_vs(ip, vs_id.port, vs_id.proto)); - } - } - tree.end_child(); - } else { - tree.add_empty_child("ACL Filters Reused: None".to_string()); - } - } - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_vs_update_info_table( - update_info: &balancerpb::UpdateInfo, - operation: VsOperation, -) -> Result<(), Box> { - println!(); - println!("{}", "═".repeat(60).cyan().bold()); - let title = match operation { - VsOperation::Update => " VS Update Summary", - VsOperation::Delete => " VS Delete Result", - }; - println!("{}", title.white().bold()); - println!("{}", "═".repeat(60).cyan().bold()); - println!(); - - // Filter reuse status - println!("{}", "Filter Reuse Status:".bright_cyan().bold()); - - let ipv4_status = if update_info.vs_ipv4_matcher_reused { - "✓ Reused (not recompiled)".bright_green() - } else { - "✗ Recompiled".bright_yellow() - }; - println!(" IPv4 VS Matcher: {}", ipv4_status); - - let ipv6_status = if update_info.vs_ipv6_matcher_reused { - "✓ Reused (not recompiled)".bright_green() - } else { - "✗ Recompiled".bright_yellow() - }; - println!(" IPv6 VS Matcher: {}", ipv6_status); - - println!(); - - // ACL reuse information (only for update operations) - if matches!(operation, VsOperation::Update) { - if !update_info.vs_acl_reuses.is_empty() { - println!( - "{} {}", - "ACL Filters Reused:".bright_cyan().bold(), - format!("({} virtual services)", update_info.vs_acl_reuses.len()).bright_white() - ); - - for vs_id in &update_info.vs_acl_reuses { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - println!(" • {}", format_vs(ip, vs_id.port, vs_id.proto).bright_yellow()); - } - } - } else { - println!( - "{} {}", - "ACL Filters Reused:".bright_cyan().bold(), - "None".bright_yellow() - ); - } - - println!(); - } - - println!("{}", "═".repeat(60).cyan().bold()); - println!(); - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowConfig Output -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_config( - response: &balancerpb::ShowConfigResponse, - format: OutputFormat, -) -> Result<(), Box> { - match format { - OutputFormat::Json => print_show_config_json(response), - OutputFormat::Tree => print_show_config_tree(response), - OutputFormat::Table => print_show_config_table(response), - } -} - -fn print_show_config_json(response: &balancerpb::ShowConfigResponse) -> Result<(), Box> { - let json = json_output::convert_show_config(response); - println!("{}", serde_json::to_string(&json)?); - Ok(()) -} - -fn print_show_config_tree(response: &balancerpb::ShowConfigResponse) -> Result<(), Box> { - let mut tree = TreeBuilder::new("Balancer Configuration".to_string()); - - tree.begin_child(format!("Config: {}", response.name)); - - if let Some(config) = &response.config { - if let Some(packet_handler) = &config.packet_handler { - tree.begin_child("Packet Handler".to_string()); - - // Source addresses - tree.begin_child("Source Addresses".to_string()); - if let Ok(ipv4) = opt_addr_to_ip(&packet_handler.source_address_v4) { - tree.add_empty_child(format!("IPv4: {}", ipv4)); - } - if let Ok(ipv6) = opt_addr_to_ip(&packet_handler.source_address_v6) { - tree.add_empty_child(format!("IPv6: {}", ipv6)); - } - tree.end_child(); - - // Decap addresses - if !packet_handler.decap_addresses.is_empty() { - tree.begin_child("Decap Addresses".to_string()); - for addr in &packet_handler.decap_addresses { - if let Ok(ip) = addr_to_ip(addr) { - tree.add_empty_child(ip.to_string()); - } - } - tree.end_child(); - } - - // Timeouts - if let Some(timeouts) = &packet_handler.sessions_timeouts { - tree.begin_child("Session Timeouts".to_string()); - tree.add_empty_child(format!("TCP: {}s", timeouts.tcp)); - tree.add_empty_child(format!("TCP SYN: {}s", timeouts.tcp_syn)); - tree.add_empty_child(format!("TCP SYN-ACK: {}s", timeouts.tcp_syn_ack)); - tree.add_empty_child(format!("TCP FIN: {}s", timeouts.tcp_fin)); - tree.add_empty_child(format!("UDP: {}s", timeouts.udp)); - tree.add_empty_child(format!("Default: {}s", timeouts.default)); - tree.end_child(); - } - - // Virtual services - tree.begin_child(format!("Virtual Services ({})", packet_handler.vs.len())); - for (idx, vs) in packet_handler.vs.iter().enumerate() { - if let Some(vs_id) = &vs.id { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.begin_child(format!("[{}]", idx).cyan().to_string()); - tree.add_empty_child(format!("VS: {}", format_vs(ip, vs_id.port, vs_id.proto))); - tree.add_empty_child(format!("Scheduler: {}", scheduler_to_string(vs.scheduler))); - tree.add_empty_child(format!("Flags: {}", format_flags(vs.flags.as_ref()))); - - if !vs.allowed_srcs.is_empty() { - tree.begin_child("Allowed Sources".to_string()); - for subnet in &vs.allowed_srcs { - if let Ok(formatted) = crate::entities::format_allowed_src(subnet) { - tree.add_empty_child(format!("- {}", formatted)); - } - } - tree.end_child(); - } - - if !vs.peers.is_empty() { - tree.begin_child("Peers".to_string()); - for peer in &vs.peers { - if let Ok(ip) = addr_to_ip(peer) { - tree.add_empty_child(ip.to_string()); - } - } - tree.end_child(); - } - - tree.begin_child(format!("Reals ({})", vs.reals.len())); - for (ridx, real) in vs.reals.iter().enumerate() { - if let Some(real_id) = &real.id { - if let Ok(dst) = opt_addr_to_ip(&real_id.ip) { - tree.begin_child(format!("[{}]", ridx).cyan().to_string()); - tree.add_empty_child(format!("Real: {}", format_real(dst, real_id.port as u16))); - tree.add_empty_child(format!("Weight: {}", real.weight)); - tree.end_child(); - } - } - } - tree.end_child(); - - tree.end_child(); - } - } - } - tree.end_child(); - - tree.end_child(); - } - - if let Some(state_config) = &config.state { - tree.begin_child("State".to_string()); - if let Some(capacity) = state_config.session_table_capacity { - tree.add_empty_child(format!("Session Table Capacity: {}", format_number(capacity))); - } - if let Some(period) = &state_config.refresh_period { - tree.add_empty_child(format!( - "Refresh Period: {}ms", - period.seconds * 1000 + period.nanos as i64 / 1_000_000 - )); - } - if let Some(load_factor) = state_config.session_table_max_load_factor { - tree.add_empty_child(format!("Max Load Factor: {:.2}", load_factor)); - } - if let Some(wlc) = &state_config.wlc { - tree.begin_child("WLC".to_string()); - if let Some(power) = wlc.power { - tree.add_empty_child(format!("Power: {}", power)); - } - if let Some(max_weight) = wlc.max_weight { - tree.add_empty_child(format!("Max Weight: {}", max_weight)); - } - tree.end_child(); - } - tree.end_child(); - } - } - - if !response.buffered_real_updates.is_empty() { - tree.begin_child(format!( - "Buffered Real Updates ({})", - response.buffered_real_updates.len() - )); - for (idx, update) in response.buffered_real_updates.iter().enumerate() { - if let Some(real_id) = &update.real_id { - if let (Some(vs_id), Some(rel_real)) = (&real_id.vs, &real_id.real) { - let vip = opt_addr_to_ip(&vs_id.addr) - .ok() - .map(|ip| ip.to_string()) - .unwrap_or_default(); - let rip = opt_addr_to_ip(&rel_real.ip) - .ok() - .map(|ip| ip.to_string()) - .unwrap_or_default(); - let action = if update.enable.unwrap_or(false) { - "enable" - } else { - "disable" - }; - tree.begin_child(format!("[{}]", idx).cyan().to_string()); - tree.add_empty_child(format!("Action: {}", action)); - tree.add_empty_child(format!( - "VS: {}", - format_vs( - vip.parse() - .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)), - vs_id.port, - vs_id.proto - ) - )); - tree.add_empty_child(format!( - "Real: {}", - format_real( - rip.parse() - .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)), - rel_real.port as u16 - ) - )); - tree.end_child(); - } - } - } - tree.end_child(); - } - - tree.end_child(); - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_show_config_table(response: &balancerpb::ShowConfigResponse) -> Result<(), Box> { - // Print header - let subtitle = Some(format!("Config: {}", response.name)); - print_boxed_header("BALANCER CONFIGURATION", subtitle.as_deref()); - println!(); - - if let Some(config) = &response.config { - if let Some(packet_handler) = &config.packet_handler { - // Decap addresses (one per line, green color for list items) - println!("{}", "Decap Addresses:".bright_cyan().bold()); - if !packet_handler.decap_addresses.is_empty() { - for addr in &packet_handler.decap_addresses { - if let Ok(ip) = addr_to_ip(addr) { - println!(" {}", ip.to_string().bright_green()); - } - } - } else { - println!(" {}", "None".bright_green()); - } - println!(); - - // Source addresses - println!("{}", "Source Addresses:".bright_cyan().bold()); - if let Ok(ipv4) = opt_addr_to_ip(&packet_handler.source_address_v4) { - println!(" IPv4: {}", ipv4.to_string().bright_green()); - } - if let Ok(ipv6) = opt_addr_to_ip(&packet_handler.source_address_v6) { - println!(" IPv6: {}", ipv6.to_string().bright_green()); - } - println!(); - - // Session timeouts (one per line) - if let Some(timeouts) = &packet_handler.sessions_timeouts { - println!("{}", "Session Timeouts:".bright_cyan().bold()); - println!(" TCP: {}", format!("{}s", timeouts.tcp).bright_green()); - println!(" TCP SYN: {}", format!("{}s", timeouts.tcp_syn).bright_green()); - println!(" TCP SYN-ACK: {}", format!("{}s", timeouts.tcp_syn_ack).bright_green()); - println!(" TCP FIN: {}", format!("{}s", timeouts.tcp_fin).bright_green()); - println!(" UDP: {}", format!("{}s", timeouts.udp).bright_green()); - println!(" Default: {}", format!("{}s", timeouts.default).bright_green()); - println!(); - } - } - - // State (one value per line) - if let Some(state_config) = &config.state { - println!("{}", "State:".bright_cyan().bold()); - - if let Some(capacity) = state_config.session_table_capacity { - println!(" Session Table Capacity: {}", format_number(capacity).bright_green()); - } - - if let Some(period) = &state_config.refresh_period { - let refresh_period_ms = period.seconds * 1000 + period.nanos as i64 / 1_000_000; - println!( - " Refresh Period: {}", - format!("{}ms", refresh_period_ms).bright_green() - ); - } - - if let Some(load_factor) = state_config.session_table_max_load_factor { - println!(" Max Load Factor: {}", format!("{:.2}", load_factor).bright_green()); - } - - if let Some(wlc) = &state_config.wlc { - if let Some(power) = wlc.power { - println!(" WLC Power: {}", power.to_string().bright_green()); - } - if let Some(max_weight) = wlc.max_weight { - println!(" WLC Max Weight: {}", max_weight.to_string().bright_green()); - } - } - println!(); - } - } - - // Virtual services (hierarchical display similar to info/stats) - if let Some(config) = &response.config { - if let Some(packet_handler) = &config.packet_handler { - if !packet_handler.vs.is_empty() { - // Print details for each VS - for vs in &packet_handler.vs { - if let Some(vs_id) = &vs.id { - if let Ok(vs_ip) = opt_addr_to_ip(&vs_id.addr) { - println!( - "{}:", - format!("VS {}", format_vs(vs_ip, vs_id.port, vs_id.proto)) - .bright_yellow() - .bold() - ); - - // Display VS properties on separate lines - println!(" Scheduler: {}", scheduler_to_string(vs.scheduler).bright_green()); - println!(" Flags: {}", format_flags(vs.flags.as_ref()).bright_green()); - - // Peers - one per line with dash prefix - if !vs.peers.is_empty() { - println!(" Peers:"); - for peer in &vs.peers { - if let Ok(ip) = addr_to_ip(peer) { - println!(" - {}", ip.to_string().bright_green()); - } - } - } else { - println!(" Peers: {}", "none".bright_green()); - } - - // Allowed sources - one per line with dash prefix - if !vs.allowed_srcs.is_empty() { - println!(" Allowed Sources:"); - for src in &vs.allowed_srcs { - if let Ok(formatted) = crate::entities::format_allowed_src(src) { - println!(" - {}", formatted.bright_green()); - } - } - } else { - println!(" Allowed Sources: {}", "none".bright_green()); - } - - // Reals table - if !vs.reals.is_empty() { - #[derive(Tabled)] - struct RealRow { - #[tabled(rename = "Real")] - real: String, - #[tabled(rename = "Weight")] - weight: String, - #[tabled(rename = "Source")] - source: String, - #[tabled(rename = "Source Mask")] - mask: String, - } - - let real_rows: Vec = vs - .reals - .iter() - .filter_map(|real| { - real.id.as_ref().map(|real_id| RealRow { - real: opt_addr_to_ip(&real_id.ip) - .map(|ip| format_real(ip, real_id.port as u16)) - .unwrap_or_default(), - weight: real.weight.to_string(), - source: opt_addr_to_ip(&real.src_addr) - .map(|ip| ip.to_string()) - .unwrap_or_default(), - mask: opt_addr_to_ip(&real.src_mask) - .map(|ip| ip.to_string()) - .unwrap_or_default(), - }) - }) - .collect(); - - print_table(real_rows); - } - - println!(); - } - } - } - } - } - } - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ListConfigs Output -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_list_configs( - response: &balancerpb::ListConfigsResponse, - format: OutputFormat, -) -> Result<(), Box> { - match format { - OutputFormat::Json => { - let json = json_output::convert_list_configs(response); - println!("{}", serde_json::to_string(&json)?); - } - OutputFormat::Tree => { - let mut tree = TreeBuilder::new(format!("Balancer Configs ({})", response.configs.len())); - - for config_name in &response.configs { - tree.add_empty_child(config_name.clone()); - } - - let tree = tree.build(); - ptree::print_tree(&tree)?; - } - OutputFormat::Table => { - #[derive(Tabled)] - struct ConfigRow { - #[tabled(rename = "Config Name")] - name: String, - } - - let rows: Vec = response - .configs - .iter() - .map(|name| ConfigRow { name: name.clone() }) - .collect(); - - print_table(rows); - } - } - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowInfo Output (Hierarchical) -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_info(response: &balancerpb::ShowInfoResponse, format: OutputFormat) -> Result<(), Box> { - match format { - OutputFormat::Json => { - let json = json_output::convert_show_info(response); - println!("{}", serde_json::to_string(&json)?); - } - OutputFormat::Tree => print_show_info_tree(response)?, - OutputFormat::Table => print_show_info_table(response)?, - } - Ok(()) -} - -fn print_show_info_tree(response: &balancerpb::ShowInfoResponse) -> Result<(), Box> { - let mut tree = TreeBuilder::new("Balancer State Info".to_string()); - - tree.begin_child(format!("Config: {}", response.name)); - - if let Some(info) = &response.info { - tree.add_empty_child(format!("Active Sessions: {}", format_number(info.active_sessions))); - tree.add_empty_child(format!( - "Last Packet: {}", - format_timestamp(info.last_packet_timestamp.as_ref()) - )); - - // Virtual services (hierarchical - reals nested under VS) - if !info.vs.is_empty() { - tree.begin_child(format!("Virtual Services ({})", info.vs.len())); - for (vs_idx, vs_info) in info.vs.iter().enumerate() { - if let Some(vs_id) = &vs_info.id { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.begin_child(format!("[{}]", vs_idx).cyan().to_string()); - tree.add_empty_child(format!("VS: {}", format_vs(ip, vs_id.port, vs_id.proto))); - tree.add_empty_child(format!("Active Sessions: {}", format_number(vs_info.active_sessions))); - tree.add_empty_child(format!( - "Last Packet: {}", - format_timestamp(vs_info.last_packet_timestamp.as_ref()) - )); - - // Reals under this VS - if !vs_info.reals.is_empty() { - tree.begin_child(format!("Reals ({})", vs_info.reals.len())); - for (real_idx, real_info) in vs_info.reals.iter().enumerate() { - if let Some(real_id) = &real_info.id { - if let Some(rel_real) = &real_id.real { - if let Ok(real_ip) = opt_addr_to_ip(&rel_real.ip) { - tree.begin_child(format!("[{}]", real_idx).cyan().to_string()); - tree.add_empty_child(format!( - "Real: {}", - format_real(real_ip, rel_real.port as u16) - )); - tree.add_empty_child(format!( - "Active Sessions: {}", - format_number(real_info.active_sessions) - )); - tree.add_empty_child(format!( - "Last Packet: {}", - format_timestamp(real_info.last_packet_timestamp.as_ref()) - )); - tree.end_child(); - } - } - } - } - tree.end_child(); - } - - tree.end_child(); - } - } - } - tree.end_child(); - } - } - - tree.end_child(); - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_show_info_table(response: &balancerpb::ShowInfoResponse) -> Result<(), Box> { - // Print header - let subtitle = Some(format!("Config: {}", response.name)); - print_boxed_header("BALANCER INFO", subtitle.as_deref()); - - println!(); - - if let Some(info) = &response.info { - println!( - "Active Sessions: {}", - format_number(info.active_sessions).bright_green() - ); - println!( - "Last Packet: {}", - format_timestamp(info.last_packet_timestamp.as_ref()).bright_green() - ); - } - - println!(); - - if let Some(info) = &response.info { - // VS table (hierarchical display - reals nested under VS) - if !info.vs.is_empty() { - for vs_info in &info.vs { - if let Some(vs_id) = &vs_info.id { - if let Ok(vs_ip) = opt_addr_to_ip(&vs_id.addr) { - println!( - "{}:", - format!("VS {}", format_vs(vs_ip, vs_id.port, vs_id.proto)) - .bright_yellow() - .bold() - ); - println!( - " Active Sessions: {}", - format_number(vs_info.active_sessions).bright_green() - ); - println!( - " Last Packet: {}", - format_timestamp(vs_info.last_packet_timestamp.as_ref()).bright_green() - ); - - if !vs_info.reals.is_empty() { - #[derive(Tabled)] - struct RealInfoRow { - #[tabled(rename = "Real")] - real: String, - #[tabled(rename = "Active Sessions")] - sessions: String, - #[tabled(rename = "Last Packet")] - last_packet: String, - } - - let rows: Vec = vs_info - .reals - .iter() - .filter_map(|real_info| { - real_info.id.as_ref().and_then(|real_id| { - real_id.real.as_ref().and_then(|rel_real| { - opt_addr_to_ip(&rel_real.ip).ok().map(|real_ip| RealInfoRow { - real: format_real(real_ip, rel_real.port as u16), - sessions: format_number(real_info.active_sessions), - last_packet: format_timestamp(real_info.last_packet_timestamp.as_ref()), - }) - }) - }) - }) - .collect(); - - print_table(rows); - } - println!(); - } - } - } - } - } - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowStats Output (Hierarchical) -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_stats(response: &balancerpb::ShowStatsResponse, format: OutputFormat) -> Result<(), Box> { - match format { - OutputFormat::Json => { - let json = json_output::convert_show_stats(response); - println!("{}", serde_json::to_string(&json)?); - } - OutputFormat::Tree => print_show_stats_tree(response)?, - OutputFormat::Table => print_show_stats_table(response)?, - } - Ok(()) -} - -fn print_show_stats_tree(response: &balancerpb::ShowStatsResponse) -> Result<(), Box> { - let mut tree = TreeBuilder::new("Balancer Statistics".to_string()); - - for (entry_idx, entry) in response.entries.iter().enumerate() { - tree.begin_child(format!("[{}]", entry_idx).cyan().to_string()); - tree.add_empty_child(format!("Config: {}", entry.name)); - - if let Some(ref_info) = &entry.r#ref { - if let Some(device) = &ref_info.device { - tree.add_empty_child(format!("Device: {}", device)); - } - if let Some(pipeline) = &ref_info.pipeline { - tree.add_empty_child(format!("Pipeline: {}", pipeline)); - } - if let Some(function) = &ref_info.function { - tree.add_empty_child(format!("Function: {}", function)); - } - if let Some(chain) = &ref_info.chain { - tree.add_empty_child(format!("Chain: {}", chain)); - } - } - - if let Some(stats) = &entry.stats { - // Module stats (split into components) - tree.begin_child("Module".to_string()); - - if let Some(common) = &stats.common { - tree.begin_child("Common".to_string()); - tree.add_empty_child(format!("Incoming Packets: {}", format_number(common.incoming_packets))); - tree.add_empty_child(format!("Incoming Bytes: {}", format_bytes(common.incoming_bytes))); - tree.add_empty_child(format!( - "Unexpected Network Proto: {}", - format_number(common.unexpected_network_proto) - )); - tree.add_empty_child(format!("Decap Successful: {}", format_number(common.decap_successful))); - tree.add_empty_child(format!("Decap Failed: {}", format_number(common.decap_failed))); - tree.add_empty_child(format!("Outgoing Packets: {}", format_number(common.outgoing_packets))); - tree.add_empty_child(format!("Outgoing Bytes: {}", format_bytes(common.outgoing_bytes))); - tree.end_child(); - } - - if let Some(l4) = &stats.l4 { - tree.begin_child("L4".to_string()); - tree.add_empty_child(format!("Incoming Packets: {}", format_number(l4.incoming_packets))); - tree.add_empty_child(format!("Outgoing Packets: {}", format_number(l4.outgoing_packets))); - tree.add_empty_child(format!("Select VS Failed: {}", format_number(l4.select_vs_failed))); - tree.add_empty_child(format!("Select Real Failed: {}", format_number(l4.select_real_failed))); - tree.add_empty_child(format!("Invalid Packets: {}", format_number(l4.invalid_packets))); - tree.end_child(); - } - - if let Some(icmpv4) = &stats.icmpv4 { - tree.begin_child("ICMPv4".to_string()); - tree.add_empty_child(format!("Incoming Packets: {}", format_number(icmpv4.incoming_packets))); - tree.add_empty_child(format!("Src Not Allowed: {}", format_number(icmpv4.src_not_allowed))); - tree.add_empty_child(format!("Echo Responses: {}", format_number(icmpv4.echo_responses))); - tree.add_empty_child(format!( - "Payload Too Short IP: {}", - format_number(icmpv4.payload_too_short_ip) - )); - tree.add_empty_child(format!( - "Unmatching Src From Original: {}", - format_number(icmpv4.unmatching_src_from_original) - )); - tree.add_empty_child(format!( - "Payload Too Short Port: {}", - format_number(icmpv4.payload_too_short_port) - )); - tree.add_empty_child(format!( - "Unexpected Transport: {}", - format_number(icmpv4.unexpected_transport) - )); - tree.add_empty_child(format!("Unrecognized VS: {}", format_number(icmpv4.unrecognized_vs))); - tree.add_empty_child(format!( - "Forwarded Packets: {}", - format_number(icmpv4.forwarded_packets) - )); - tree.add_empty_child(format!( - "Broadcasted Packets: {}", - format_number(icmpv4.broadcasted_packets) - )); - tree.add_empty_child(format!( - "Packet Clones Sent: {}", - format_number(icmpv4.packet_clones_sent) - )); - tree.add_empty_child(format!( - "Packet Clones Received: {}", - format_number(icmpv4.packet_clones_received) - )); - tree.add_empty_child(format!( - "Packet Clone Failures: {}", - format_number(icmpv4.packet_clone_failures) - )); - tree.end_child(); - } - - if let Some(icmpv6) = &stats.icmpv6 { - tree.begin_child("ICMPv6".to_string()); - tree.add_empty_child(format!("Incoming Packets: {}", format_number(icmpv6.incoming_packets))); - tree.add_empty_child(format!("Src Not Allowed: {}", format_number(icmpv6.src_not_allowed))); - tree.add_empty_child(format!("Echo Responses: {}", format_number(icmpv6.echo_responses))); - tree.add_empty_child(format!( - "Payload Too Short IP: {}", - format_number(icmpv6.payload_too_short_ip) - )); - tree.add_empty_child(format!( - "Unmatching Src From Original: {}", - format_number(icmpv6.unmatching_src_from_original) - )); - tree.add_empty_child(format!( - "Payload Too Short Port: {}", - format_number(icmpv6.payload_too_short_port) - )); - tree.add_empty_child(format!( - "Unexpected Transport: {}", - format_number(icmpv6.unexpected_transport) - )); - tree.add_empty_child(format!("Unrecognized VS: {}", format_number(icmpv6.unrecognized_vs))); - tree.add_empty_child(format!( - "Forwarded Packets: {}", - format_number(icmpv6.forwarded_packets) - )); - tree.add_empty_child(format!( - "Broadcasted Packets: {}", - format_number(icmpv6.broadcasted_packets) - )); - tree.add_empty_child(format!( - "Packet Clones Sent: {}", - format_number(icmpv6.packet_clones_sent) - )); - tree.add_empty_child(format!( - "Packet Clones Received: {}", - format_number(icmpv6.packet_clones_received) - )); - tree.add_empty_child(format!( - "Packet Clone Failures: {}", - format_number(icmpv6.packet_clone_failures) - )); - tree.end_child(); - } - - tree.end_child(); // Module - - // VS stats - if !stats.vs.is_empty() { - tree.begin_child(format!("Virtual Services ({})", stats.vs.len())); - for (vs_idx, vs) in stats.vs.iter().enumerate() { - if let Some(vs_id) = &vs.vs { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.begin_child(format!("[{}]", vs_idx).cyan().to_string()); - tree.add_empty_child(format!("VS: {}", format_vs(ip, vs_id.port, vs_id.proto))); - if let Some(s) = &vs.stats { - tree.add_empty_child(format!( - "Incoming: {} pkts, {}", - format_number(s.incoming_packets), - format_bytes(s.incoming_bytes) - )); - tree.add_empty_child(format!( - "Outgoing: {} pkts, {}", - format_number(s.outgoing_packets), - format_bytes(s.outgoing_bytes) - )); - tree.add_empty_child(format!( - "Created Sessions: {}", - format_number(s.created_sessions) - )); - tree.add_empty_child(format!( - "Packet Src Not Allowed: {}", - format_number(s.packet_src_not_allowed) - )); - tree.add_empty_child(format!("No Reals: {}", format_number(s.no_reals))); - tree.add_empty_child(format!("OPS Packets: {}", format_number(s.ops_packets))); - tree.add_empty_child(format!( - "Session Table Overflow: {}", - format_number(s.session_table_overflow) - )); - tree.add_empty_child(format!( - "Echo ICMP Packets: {}", - format_number(s.echo_icmp_packets) - )); - tree.add_empty_child(format!( - "Error ICMP Packets: {}", - format_number(s.error_icmp_packets) - )); - tree.add_empty_child(format!( - "Real Is Disabled: {}", - format_number(s.real_is_disabled) - )); - tree.add_empty_child(format!("Real Is Removed: {}", format_number(s.real_is_removed))); - tree.add_empty_child(format!( - "Not Rescheduled Packets: {}", - format_number(s.not_rescheduled_packets) - )); - tree.add_empty_child(format!( - "Broadcasted ICMP Packets: {}", - format_number(s.broadcasted_icmp_packets) - )); - } - - // Real stats - if !vs.reals.is_empty() { - tree.begin_child(format!("Reals ({})", vs.reals.len())); - for (real_idx, real) in vs.reals.iter().enumerate() { - if let Some(real_id) = &real.real { - if let Some(rel_real) = &real_id.real { - if let Ok(real_ip) = opt_addr_to_ip(&rel_real.ip) { - tree.begin_child(format!("[{}]", real_idx).cyan().to_string()); - tree.add_empty_child(format!( - "Real: {}", - format_real(real_ip, rel_real.port as u16) - )); - if let Some(s) = &real.stats { - tree.add_empty_child(format!( - "Packets: {}", - format_number(s.packets) - )); - tree.add_empty_child(format!("Bytes: {}", format_bytes(s.bytes))); - tree.add_empty_child(format!( - "Created Sessions: {}", - format_number(s.created_sessions) - )); - tree.add_empty_child(format!( - "Packets Real Disabled: {}", - format_number(s.packets_real_disabled) - )); - tree.add_empty_child(format!( - "OPS Packets: {}", - format_number(s.ops_packets) - )); - tree.add_empty_child(format!( - "Error ICMP Packets: {}", - format_number(s.error_icmp_packets) - )); - } - tree.end_child(); - } - } - } - } - tree.end_child(); - } - - // Allowed sources stats - if !vs.allowed_sources.is_empty() { - tree.begin_child("Allowed Sources".to_string()); - for allowed_src in &vs.allowed_sources { - let tag_str = if allowed_src.tag.is_empty() { - "None".to_string() - } else { - allowed_src.tag.clone() - }; - tree.add_empty_child(format!( - "Tag {}: {} passes", - tag_str, - format_number(allowed_src.passes) - )); - } - tree.end_child(); - } - - tree.end_child(); // VS - } - } - } - tree.end_child(); - } - } - - tree.end_child(); // entry - } - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_show_stats_table(response: &balancerpb::ShowStatsResponse) -> Result<(), Box> { - if response.entries.is_empty() { - print_boxed_header("BALANCER STATISTICS", Some("Entries: 0")); - println!(); - return Ok(()); - } - - for (idx, entry) in response.entries.iter().enumerate() { - if idx > 0 { - println!("{}", "─".repeat(80).bright_black()); - println!(); - } - - // Print header with topology info and config name as two-line subtitle - let subtitle = if let Some(ref_info) = &entry.r#ref { - format!( - "Config: {} | Device: {}\nPipeline: {} | Function: {} | Chain: {}", - entry.name, - ref_info.device.as_deref().unwrap_or("N/A"), - ref_info.pipeline.as_deref().unwrap_or("N/A"), - ref_info.function.as_deref().unwrap_or("N/A"), - ref_info.chain.as_deref().unwrap_or("N/A"), - ) - } else { - format!("Config: {}", entry.name) - }; - print_boxed_header("BALANCER STATISTICS", Some(&subtitle)); - println!(); - - let Some(stats) = &entry.stats else { - continue; - }; - - println!("{}", "Module:".bright_yellow().bold()); - - #[derive(Tabled)] - struct ModuleStatsRow { - #[tabled(rename = "Category")] - category: String, - #[tabled(rename = "Metric")] - metric: String, - #[tabled(rename = "Value")] - value: String, - } - - let mut rows = Vec::new(); - - if let Some(common) = &stats.common { - rows.push(ModuleStatsRow { - category: "Common".to_string(), - metric: "Incoming Pkts".to_string(), - value: format_number(common.incoming_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Incoming Bytes".to_string(), - value: format_bytes(common.incoming_bytes), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unexpected Proto".to_string(), - value: format_number(common.unexpected_network_proto), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Decap Success".to_string(), - value: format_number(common.decap_successful), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Decap Failed".to_string(), - value: format_number(common.decap_failed), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Outgoing Pkts".to_string(), - value: format_number(common.outgoing_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Outgoing Bytes".to_string(), - value: format_bytes(common.outgoing_bytes), - }); - } - - // Separator between Common and L4 - if stats.common.is_some() && stats.l4.is_some() { - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "".to_string(), - value: "".to_string(), - }); - } - - if let Some(l4) = &stats.l4 { - rows.push(ModuleStatsRow { - category: "L4".to_string(), - metric: "Incoming Pkts".to_string(), - value: format_number(l4.incoming_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Outgoing Pkts".to_string(), - value: format_number(l4.outgoing_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Select VS Fail".to_string(), - value: format_number(l4.select_vs_failed), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Select Real Fail".to_string(), - value: format_number(l4.select_real_failed), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Invalid Pkts".to_string(), - value: format_number(l4.invalid_packets), - }); - } - - // Separator between L4 and ICMPv4 - if stats.l4.is_some() && stats.icmpv4.is_some() { - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "".to_string(), - value: "".to_string(), - }); - } - - if let Some(icmpv4) = &stats.icmpv4 { - rows.push(ModuleStatsRow { - category: "ICMPv4".to_string(), - metric: "Incoming Pkts".to_string(), - value: format_number(icmpv4.incoming_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Src Not Allowed".to_string(), - value: format_number(icmpv4.src_not_allowed), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Echo Responses".to_string(), - value: format_number(icmpv4.echo_responses), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Payload Short IP".to_string(), - value: format_number(icmpv4.payload_too_short_ip), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unmatch Src Orig".to_string(), - value: format_number(icmpv4.unmatching_src_from_original), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Payload Short Port".to_string(), - value: format_number(icmpv4.payload_too_short_port), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unexpected Trans".to_string(), - value: format_number(icmpv4.unexpected_transport), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unrecognized VS".to_string(), - value: format_number(icmpv4.unrecognized_vs), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Forwarded Pkts".to_string(), - value: format_number(icmpv4.forwarded_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Broadcasted Pkts".to_string(), - value: format_number(icmpv4.broadcasted_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clones Sent".to_string(), - value: format_number(icmpv4.packet_clones_sent), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clones Received".to_string(), - value: format_number(icmpv4.packet_clones_received), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clone Failures".to_string(), - value: format_number(icmpv4.packet_clone_failures), - }); - } - - // Separator between ICMPv4 and ICMPv6 - if stats.icmpv4.is_some() && stats.icmpv6.is_some() { - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "".to_string(), - value: "".to_string(), - }); - } - - if let Some(icmpv6) = &stats.icmpv6 { - rows.push(ModuleStatsRow { - category: "ICMPv6".to_string(), - metric: "Incoming Pkts".to_string(), - value: format_number(icmpv6.incoming_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Src Not Allowed".to_string(), - value: format_number(icmpv6.src_not_allowed), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Echo Responses".to_string(), - value: format_number(icmpv6.echo_responses), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Payload Short IP".to_string(), - value: format_number(icmpv6.payload_too_short_ip), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unmatch Src Orig".to_string(), - value: format_number(icmpv6.unmatching_src_from_original), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Payload Short Port".to_string(), - value: format_number(icmpv6.payload_too_short_port), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unexpected Trans".to_string(), - value: format_number(icmpv6.unexpected_transport), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Unrecognized VS".to_string(), - value: format_number(icmpv6.unrecognized_vs), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Forwarded Pkts".to_string(), - value: format_number(icmpv6.forwarded_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Broadcasted Pkts".to_string(), - value: format_number(icmpv6.broadcasted_packets), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clones Sent".to_string(), - value: format_number(icmpv6.packet_clones_sent), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clones Received".to_string(), - value: format_number(icmpv6.packet_clones_received), - }); - rows.push(ModuleStatsRow { - category: "".to_string(), - metric: "Clone Failures".to_string(), - value: format_number(icmpv6.packet_clone_failures), - }); - } - - print_table(rows); - println!(); - - // VS Stats (hierarchical display - reals nested under VS, similar to info) - if !stats.vs.is_empty() { - for vs in &stats.vs { - if let Some(vs_id) = &vs.vs { - if let Ok(vs_ip) = opt_addr_to_ip(&vs_id.addr) { - println!( - "{}:", - format!("VS {}", format_vs(vs_ip, vs_id.port, vs_id.proto)) - .bright_yellow() - .bold() - ); - - // Display VS-level stats - one metric per line - if let Some(s) = &vs.stats { - println!( - " Incoming Packets: {}", - format_number(s.incoming_packets).bright_green() - ); - println!(" Incoming Bytes: {}", format_bytes(s.incoming_bytes).bright_green()); - println!( - " Outgoing Packets: {}", - format_number(s.outgoing_packets).bright_green() - ); - println!(" Outgoing Bytes: {}", format_bytes(s.outgoing_bytes).bright_green()); - println!( - " Created Sessions: {}", - format_number(s.created_sessions).bright_green() - ); - println!(" OPS Packets: {}", format_number(s.ops_packets).bright_green()); - println!( - " Packet Src Not Allowed: {}", - format_number(s.packet_src_not_allowed).bright_green() - ); - println!( - " Session Table Overflow: {}", - format_number(s.session_table_overflow).bright_green() - ); - println!( - " Not Rescheduled Packets: {}", - format_number(s.not_rescheduled_packets).bright_green() - ); - println!( - " Real Is Disabled: {}", - format_number(s.real_is_disabled).bright_green() - ); - println!(" Real Is Removed: {}", format_number(s.real_is_removed).bright_green()); - println!(" No Reals: {}", format_number(s.no_reals).bright_green()); - println!( - " Echo ICMP Packets: {}", - format_number(s.echo_icmp_packets).bright_green() - ); - println!( - " Error ICMP Packets: {}", - format_number(s.error_icmp_packets).bright_green() - ); - println!( - " Broadcasted ICMP Packets: {}", - format_number(s.broadcasted_icmp_packets).bright_green() - ); - } - - // Display allowed sources stats - if !vs.allowed_sources.is_empty() { - println!(" {}:", "Allowed Sources".bright_cyan().bold()); - for allowed_src in &vs.allowed_sources { - let tag_str = if allowed_src.tag.is_empty() { - "None".to_string() - } else { - allowed_src.tag.clone() - }; - println!( - " Tag {}: {}", - tag_str, - format_number(allowed_src.passes).bright_green() - ); - } - } else { - println!( - " {}: {}", - "Allowed Sources".bright_cyan().bold(), - "None".bright_green() - ); - } - - // Display reals table for this VS - if !vs.reals.is_empty() { - #[derive(Tabled)] - struct RealStatsRow { - #[tabled(rename = "Real")] - real: String, - #[tabled(rename = "Packets")] - packets: String, - #[tabled(rename = "Bytes")] - bytes: String, - #[tabled(rename = "Created Sessions")] - sessions: String, - #[tabled(rename = "Disabled Pkts")] - disabled: String, - #[tabled(rename = "OPS Pkts")] - ops: String, - #[tabled(rename = "ICMP Pkts")] - error_icmp: String, - } - - let real_rows: Vec = vs - .reals - .iter() - .filter_map(|real| { - real.real.as_ref().and_then(|real_id| { - real_id.real.as_ref().and_then(|rel_real| { - opt_addr_to_ip(&rel_real.ip).ok().map(|real_ip| { - let s = real.stats.as_ref(); - RealStatsRow { - real: format_real(real_ip, rel_real.port as u16), - packets: s - .map(|s| format_number(s.packets)) - .unwrap_or_else(|| "0".to_string()), - bytes: s - .map(|s| format_bytes(s.bytes)) - .unwrap_or_else(|| "0 B".to_string()), - sessions: s - .map(|s| format_number(s.created_sessions)) - .unwrap_or_else(|| "0".to_string()), - disabled: s - .map(|s| format_number(s.packets_real_disabled)) - .unwrap_or_else(|| "0".to_string()), - ops: s - .map(|s| format_number(s.ops_packets)) - .unwrap_or_else(|| "0".to_string()), - error_icmp: s - .map(|s| format_number(s.error_icmp_packets)) - .unwrap_or_else(|| "0".to_string()), - } - }) - }) - }) - }) - .collect(); - - print_table(real_rows); - } - println!(); - } - } - } - } - } - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowSessions Output -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_sessions( - response: &balancerpb::ShowSessionsResponse, - format: OutputFormat, -) -> Result<(), Box> { - match format { - OutputFormat::Json => { - let json = json_output::convert_show_sessions(response); - println!("{}", serde_json::to_string(&json)?); - } - OutputFormat::Tree => print_show_sessions_tree(response)?, - OutputFormat::Table => print_show_sessions_table(response)?, - } - Ok(()) -} - -fn print_show_sessions_tree(response: &balancerpb::ShowSessionsResponse) -> Result<(), Box> { - let mut tree = TreeBuilder::new("Active Sessions".to_string()); - - tree.begin_child(format!("Config: {}", response.name)); - - tree.add_empty_child(format!( - "Total Sessions: {}", - format_number(response.sessions.len() as u64) - )); - - for (idx, session) in response.sessions.iter().enumerate() { - if let (Ok(client), Some(vs_id), Some(real_id)) = - (opt_addr_to_ip(&session.client_addr), &session.vs_id, &session.real_id) - { - if let (Ok(vs_ip), Some(rel_real)) = (opt_addr_to_ip(&vs_id.addr), &real_id.real) { - if let Ok(real_ip) = opt_addr_to_ip(&rel_real.ip) { - tree.begin_child(format!("[{}]", idx).cyan().to_string()); - tree.add_empty_child(format!( - "Client: {}", - format_ip_port(client, session.client_port as u16) - )); - tree.add_empty_child(format!("VS: {}", format_vs(vs_ip, vs_id.port, vs_id.proto))); - tree.add_empty_child(format!("Real: {}", format_real(real_ip, rel_real.port as u16))); - tree.add_empty_child(format!( - "Created: {}", - format_timestamp(session.create_timestamp.as_ref()) - )); - tree.add_empty_child(format!( - "Last Packet: {}", - format_timestamp(session.last_packet_timestamp.as_ref()) - )); - tree.add_empty_child(format!("Timeout: {}", format_duration(session.timeout.as_ref()))); - tree.end_child(); - } - } - } - } - - tree.end_child(); - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_show_sessions_table(response: &balancerpb::ShowSessionsResponse) -> Result<(), Box> { - // Print header - let subtitle = Some(format!("Config: {}", response.name)); - print_boxed_header("BALANCER SESSIONS", subtitle.as_deref()); - - println!(); - println!( - " Total Sessions: {}", - format_number(response.sessions.len() as u64).bright_green() - ); - - if !response.sessions.is_empty() { - #[derive(Tabled)] - struct SessionRow { - #[tabled(rename = "Client")] - client: String, - #[tabled(rename = "VS")] - vs: String, - #[tabled(rename = "Real")] - real: String, - #[tabled(rename = "Proto")] - proto: String, - #[tabled(rename = "Created At")] - created_at: String, - #[tabled(rename = "Last Packet")] - last_packet: String, - #[tabled(rename = "Timeout")] - timeout: String, - } - - let rows: Vec = response - .sessions - .iter() - .filter_map(|session| { - if let (Ok(client_ip), Some(vs_id), Some(real_id)) = - (opt_addr_to_ip(&session.client_addr), &session.vs_id, &session.real_id) - { - if let (Ok(vs_ip), Some(rel_real)) = (opt_addr_to_ip(&vs_id.addr), &real_id.real) { - if let Ok(real_ip) = opt_addr_to_ip(&rel_real.ip) { - return Some(SessionRow { - client: format_ip_port(client_ip, session.client_port as u16), - vs: format_ip_port(vs_ip, vs_id.port as u16), - real: format_real(real_ip, rel_real.port as u16), - proto: proto_to_string(vs_id.proto), - created_at: format_timestamp(session.create_timestamp.as_ref()), - last_packet: format_timestamp(session.last_packet_timestamp.as_ref()), - timeout: format_duration(session.timeout.as_ref()), - }); - } - } - } - None - }) - .collect(); - - print_table(rows); - } - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowGraph Output -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_graph(response: &balancerpb::ShowGraphResponse, format: OutputFormat) -> Result<(), Box> { - match format { - OutputFormat::Json => { - let json = json_output::convert_show_graph(response); - println!("{}", serde_json::to_string(&json)?); - } - OutputFormat::Tree => print_show_graph_tree(response)?, - OutputFormat::Table => print_show_graph_table(response)?, - } - Ok(()) -} - -fn print_show_graph_tree(response: &balancerpb::ShowGraphResponse) -> Result<(), Box> { - let mut tree: TreeBuilder = TreeBuilder::new("Balancer Graph".to_string()); - - tree.begin_child(format!("Config: {}", response.name)); - - if let Some(graph) = &response.graph { - if !graph.virtual_services.is_empty() { - tree.begin_child(format!("Virtual Services ({})", graph.virtual_services.len())); - - for (vs_idx, vs) in graph.virtual_services.iter().enumerate() { - if let Some(vs_id) = &vs.identifier { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - tree.begin_child(format!("[{}]", vs_idx).cyan().to_string()); - tree.add_empty_child(format!("VS: {}", format_vs(ip, vs_id.port, vs_id.proto))); - - if !vs.reals.is_empty() { - tree.begin_child(format!("Reals ({})", vs.reals.len())); - for (real_idx, real) in vs.reals.iter().enumerate() { - if let Some(real_id) = &real.identifier { - if let Ok(real_ip) = opt_addr_to_ip(&real_id.ip) { - let status = if real.enabled { "enabled" } else { "disabled" }; - tree.begin_child(format!("[{}]", real_idx).cyan().to_string()); - tree.add_empty_child(format!( - "Real: {}", - format_real(real_ip, real_id.port as u16) - )); - tree.add_empty_child(format!("Weight: {}", real.weight)); - tree.add_empty_child(format!("Effective Weight: {}", real.effective_weight)); - tree.add_empty_child(format!("Status: {}", status)); - tree.end_child(); - } - } - } - tree.end_child(); - } - tree.end_child(); - } - } - } - tree.end_child(); - } - } - - tree.end_child(); - - let tree = tree.build(); - ptree::print_tree(&tree)?; - Ok(()) -} - -fn print_show_graph_table(response: &balancerpb::ShowGraphResponse) -> Result<(), Box> { - // Print header - let subtitle = Some(format!("Config: {}", response.name)); - print_boxed_header("BALANCER GRAPH", subtitle.as_deref()); - - println!(); - - if let Some(graph) = &response.graph { - // Display each VS with its reals in a hierarchical format (similar to info - // output) - if !graph.virtual_services.is_empty() { - for vs in &graph.virtual_services { - if let Some(vs_id) = &vs.identifier { - if let Ok(vs_ip) = opt_addr_to_ip(&vs_id.addr) { - println!( - "{}:", - format!("VS {}", format_vs(vs_ip, vs_id.port, vs_id.proto)) - .bright_yellow() - .bold() - ); - - if !vs.reals.is_empty() { - #[derive(Tabled)] - struct RealGraphRow { - #[tabled(rename = "Real")] - real: String, - #[tabled(rename = "Weight")] - weight: String, - #[tabled(rename = "Effective Weight")] - effective_weight: String, - #[tabled(rename = "Status")] - status: String, - } - - let rows: Vec = vs - .reals - .iter() - .filter_map(|real| { - real.identifier.as_ref().and_then(|real_id| { - opt_addr_to_ip(&real_id.ip).ok().map(|real_ip| RealGraphRow { - real: format_real(real_ip, real_id.port as u16), - weight: real.weight.to_string(), - effective_weight: real.effective_weight.to_string(), - status: if real.enabled { - "enabled".to_string() - } else { - "disabled".to_string() - }, - }) - }) - }) - .collect(); - - print_table(rows); - } - println!(); - } - } - } - } - } - - Ok(()) -} - -//////////////////////////////////////////////////////////////////////////////// -// ShowInspect Output -//////////////////////////////////////////////////////////////////////////////// - -pub fn print_show_inspect( - response: &balancerpb::ShowInspectResponse, - format: InspectOutputFormat, -) -> Result<(), Box> { - match format { - InspectOutputFormat::Json => print_show_inspect_json(response), - InspectOutputFormat::Normal => print_show_inspect_normal(response), - InspectOutputFormat::Detail => print_show_inspect_detail(response), - } -} - -fn print_show_inspect_json(response: &balancerpb::ShowInspectResponse) -> Result<(), Box> { - let json = json_output::convert_show_inspect(response); - println!("{}", serde_json::to_string(&json)?); // Compact, not pretty - Ok(()) -} - -fn print_show_inspect_normal(response: &balancerpb::ShowInspectResponse) -> Result<(), Box> { - let inspect = match &response.inspect { - Some(i) => i, - None => return Ok(()), - }; - - // Print header - print_boxed_header("BALANCER MEMORY INSPECTION", None); - println!(); - - // Agent-level memory - let usage_percent = if inspect.memory_limit > 0 { - (inspect.memory_usage as f64 / inspect.memory_limit as f64) * 100.0 - } else { - 0.0 - }; - - println!("{}", "Agent Memory:".bright_cyan().bold()); - println!(" Limit: {}", format_bytes(inspect.memory_limit).bright_green()); - println!( - " Usage: {} ({:.1}%)", - format_bytes(inspect.memory_usage).bright_green(), - usage_percent - ); - println!(); - - // Per-balancer information - for (balancer_idx, balancer) in inspect.balancers.iter().enumerate() { - if balancer_idx > 0 { - println!("{}", "─".repeat(80).bright_black()); - println!(); - } - - println!( - "{} {}", - format!("Balancer \"{}\":", balancer.name).bright_yellow().bold(), - format_bytes(balancer.total_usage).bright_green() - ); - - // Packet handler breakdown - if let Some(ph) = &balancer.packet_handler_inspect { - let ph_percent = if balancer.total_usage > 0 { - (ph.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "Packet Handler:".bright_cyan().bold(), - format_bytes(ph.total_usage).bright_green(), - ph_percent - ); - - // IPv4 section - if let Some(ipv4) = &ph.vs_ipv4_inspect { - let ipv4_percent = if balancer.total_usage > 0 { - (ipv4.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "IPv4 VS Section:".bright_white(), - format_bytes(ipv4.total_usage).bright_green(), - ipv4_percent - ); - let matcher_pct = if balancer.total_usage > 0 { - (ipv4.matcher_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Matcher: {} ({:.1}%)", - format_bytes(ipv4.matcher_usage).bright_green(), - matcher_pct - ); - let announce_pct = if balancer.total_usage > 0 { - (ipv4.announce_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Announce: {} ({:.1}%)", - format_bytes(ipv4.announce_usage).bright_green(), - announce_pct - ); - let index_pct = if balancer.total_usage > 0 { - (ipv4.index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Index: {} ({:.1}%)", - format_bytes(ipv4.index_usage).bright_green(), - index_pct - ); - } - - // IPv6 section - if let Some(ipv6) = &ph.vs_ipv6_inspect { - let ipv6_percent = if balancer.total_usage > 0 { - (ipv6.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "IPv6 VS Section:".bright_white(), - format_bytes(ipv6.total_usage).bright_green(), - ipv6_percent - ); - let matcher_pct = if balancer.total_usage > 0 { - (ipv6.matcher_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Matcher: {} ({:.1}%)", - format_bytes(ipv6.matcher_usage).bright_green(), - matcher_pct - ); - let announce_pct = if balancer.total_usage > 0 { - (ipv6.announce_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Announce: {} ({:.1}%)", - format_bytes(ipv6.announce_usage).bright_green(), - announce_pct - ); - let index_pct = if balancer.total_usage > 0 { - (ipv6.index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Index: {} ({:.1}%)", - format_bytes(ipv6.index_usage).bright_green(), - index_pct - ); - } - - // Other packet handler components - println!(); - let vs_index_pct = if balancer.total_usage > 0 { - (ph.vs_index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " VS Index: {} ({:.1}%)", - format_bytes(ph.vs_index_usage).bright_green(), - vs_index_pct - ); - let reals_index_pct = if balancer.total_usage > 0 { - (ph.reals_index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Reals Index: {} ({:.1}%)", - format_bytes(ph.reals_index_usage).bright_green(), - reals_index_pct - ); - let counters_pct = if balancer.total_usage > 0 { - (ph.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(ph.counters_usage).bright_green(), - counters_pct - ); - let decap_pct = if balancer.total_usage > 0 { - (ph.decap_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Decap: {} ({:.1}%)", - format_bytes(ph.decap_usage).bright_green(), - decap_pct - ); - } - - // State breakdown - if let Some(state) = &balancer.state_inspect { - let state_percent = if balancer.total_usage > 0 { - (state.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "State:".bright_cyan().bold(), - format_bytes(state.total_usage).bright_green(), - state_percent - ); - let session_pct = if balancer.total_usage > 0 { - (state.session_table_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Session Table: {} ({:.1}%)", - format_bytes(state.session_table_usage).bright_green(), - session_pct - ); - } - - // Other usage - let other_percent = if balancer.total_usage > 0 { - (balancer.other_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "Other:".bright_cyan().bold(), - format_bytes(balancer.other_usage).cyan(), - other_percent - ); - println!(); - } - - Ok(()) -} - -fn print_show_inspect_detail(response: &balancerpb::ShowInspectResponse) -> Result<(), Box> { - let inspect = match &response.inspect { - Some(i) => i, - None => return Ok(()), - }; - - // Print header - print_boxed_header("BALANCER MEMORY INSPECTION (DETAILED)", None); - println!(); - - // Agent-level memory - let usage_percent = if inspect.memory_limit > 0 { - (inspect.memory_usage as f64 / inspect.memory_limit as f64) * 100.0 - } else { - 0.0 - }; - - println!("{}", "Agent Memory:".bright_cyan().bold()); - println!(" Limit: {}", format_bytes(inspect.memory_limit).bright_green()); - println!( - " Usage: {} ({:.1}%)", - format_bytes(inspect.memory_usage).bright_green(), - usage_percent - ); - println!(); - - // Per-balancer information - for (balancer_idx, balancer) in inspect.balancers.iter().enumerate() { - if balancer_idx > 0 { - println!("{}", "─".repeat(80).bright_black()); - println!(); - } - - println!( - "{} {}", - format!("Balancer \"{}\":", balancer.name).bright_yellow().bold(), - format_bytes(balancer.total_usage).bright_green() - ); - - // Packet handler breakdown - if let Some(ph) = &balancer.packet_handler_inspect { - let ph_percent = if balancer.total_usage > 0 { - (ph.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "Packet Handler:".bright_cyan().bold(), - format_bytes(ph.total_usage).bright_green(), - ph_percent - ); - - // IPv4 section with VS details - if let Some(ipv4) = &ph.vs_ipv4_inspect { - let ipv4_percent = if balancer.total_usage > 0 { - (ipv4.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "IPv4 VS Section:".bright_white(), - format_bytes(ipv4.total_usage).bright_green(), - ipv4_percent - ); - let matcher_pct = if balancer.total_usage > 0 { - (ipv4.matcher_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Matcher: {} ({:.1}%)", - format_bytes(ipv4.matcher_usage).bright_green(), - matcher_pct - ); - let announce_pct = if balancer.total_usage > 0 { - (ipv4.announce_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Announce: {} ({:.1}%)", - format_bytes(ipv4.announce_usage).bright_green(), - announce_pct - ); - let index_pct = if balancer.total_usage > 0 { - (ipv4.index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Index: {} ({:.1}%)", - format_bytes(ipv4.index_usage).bright_green(), - index_pct - ); - - // Per-VS breakdown - if !ipv4.vs_inspects.is_empty() { - println!(); - println!( - " {} ({}):", - "Virtual Services".bright_white(), - ipv4.vs_inspects.len() - ); - for (idx, vs_inspect) in ipv4.vs_inspects.iter().enumerate() { - if let Some(vs_id) = &vs_inspect.identifier { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - if let Some(inspect) = &vs_inspect.inspect { - let vs_percent = if balancer.total_usage > 0 { - (inspect.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " [{}] {} {} ({:.1}%)", - idx, - format_vs(ip, vs_id.port, vs_id.proto), - format_bytes(inspect.total_usage).bright_green(), - vs_percent - ); - let acl_pct = if balancer.total_usage > 0 { - (inspect.acl_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " ACL: {} ({:.1}%)", - format_bytes(inspect.acl_usage).bright_green(), - acl_pct - ); - let ring_pct = if balancer.total_usage > 0 { - (inspect.ring_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Ring: {} ({:.1}%)", - format_bytes(inspect.ring_usage).bright_green(), - ring_pct - ); - let counters_pct = if balancer.total_usage > 0 { - (inspect.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(inspect.counters_usage).bright_green(), - counters_pct - ); - if let Some(reals) = &inspect.reals_usage { - let reals_pct = if balancer.total_usage > 0 { - (reals.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Reals: {} ({:.1}%)", - format_bytes(reals.total_usage).bright_green(), - reals_pct - ); - let reals_counters_pct = if balancer.total_usage > 0 { - (reals.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(reals.counters_usage).bright_green(), - reals_counters_pct - ); - let reals_data_pct = if balancer.total_usage > 0 { - (reals.data_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Data: {} ({:.1}%)", - format_bytes(reals.data_usage).bright_green(), - reals_data_pct - ); - } - let other_pct = if balancer.total_usage > 0 { - (inspect.other_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Other: {} ({:.1}%)", - format_bytes(inspect.other_usage).bright_green(), - other_pct - ); - } - } - } - } - } - } - - // IPv6 section with VS details - if let Some(ipv6) = &ph.vs_ipv6_inspect { - let ipv6_percent = if balancer.total_usage > 0 { - (ipv6.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "IPv6 VS Section:".bright_white(), - format_bytes(ipv6.total_usage).bright_green(), - ipv6_percent - ); - let matcher_pct = if balancer.total_usage > 0 { - (ipv6.matcher_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Matcher: {} ({:.1}%)", - format_bytes(ipv6.matcher_usage).bright_green(), - matcher_pct - ); - let announce_pct = if balancer.total_usage > 0 { - (ipv6.announce_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Announce: {} ({:.1}%)", - format_bytes(ipv6.announce_usage).bright_green(), - announce_pct - ); - let index_pct = if balancer.total_usage > 0 { - (ipv6.index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Index: {} ({:.1}%)", - format_bytes(ipv6.index_usage).bright_green(), - index_pct - ); - - // Per-VS breakdown - if !ipv6.vs_inspects.is_empty() { - println!(); - println!( - " {} ({}):", - "Virtual Services".bright_white(), - ipv6.vs_inspects.len() - ); - for (idx, vs_inspect) in ipv6.vs_inspects.iter().enumerate() { - if let Some(vs_id) = &vs_inspect.identifier { - if let Ok(ip) = opt_addr_to_ip(&vs_id.addr) { - if let Some(inspect) = &vs_inspect.inspect { - let vs_percent = if balancer.total_usage > 0 { - (inspect.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " [{}] {} {} ({:.1}%)", - idx, - format_vs(ip, vs_id.port, vs_id.proto), - format_bytes(inspect.total_usage).bright_green(), - vs_percent - ); - let acl_pct = if balancer.total_usage > 0 { - (inspect.acl_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " ACL: {} ({:.1}%)", - format_bytes(inspect.acl_usage).bright_green(), - acl_pct - ); - let ring_pct = if balancer.total_usage > 0 { - (inspect.ring_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Ring: {} ({:.1}%)", - format_bytes(inspect.ring_usage).bright_green(), - ring_pct - ); - let counters_pct = if balancer.total_usage > 0 { - (inspect.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(inspect.counters_usage).bright_green(), - counters_pct - ); - if let Some(reals) = &inspect.reals_usage { - let reals_pct = if balancer.total_usage > 0 { - (reals.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Reals: {} ({:.1}%)", - format_bytes(reals.total_usage).bright_green(), - reals_pct - ); - let reals_counters_pct = if balancer.total_usage > 0 { - (reals.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(reals.counters_usage).bright_green(), - reals_counters_pct - ); - let reals_data_pct = if balancer.total_usage > 0 { - (reals.data_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Data: {} ({:.1}%)", - format_bytes(reals.data_usage).bright_green(), - reals_data_pct - ); - } - let other_pct = if balancer.total_usage > 0 { - (inspect.other_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Other: {} ({:.1}%)", - format_bytes(inspect.other_usage).bright_green(), - other_pct - ); - } - } - } - } - } - } - - // Other packet handler components - println!(); - let vs_index_pct = if balancer.total_usage > 0 { - (ph.vs_index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " VS Index: {} ({:.1}%)", - format_bytes(ph.vs_index_usage).bright_green(), - vs_index_pct - ); - let reals_index_pct = if balancer.total_usage > 0 { - (ph.reals_index_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Reals Index: {} ({:.1}%)", - format_bytes(ph.reals_index_usage).bright_green(), - reals_index_pct - ); - let counters_pct = if balancer.total_usage > 0 { - (ph.counters_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Counters: {} ({:.1}%)", - format_bytes(ph.counters_usage).bright_green(), - counters_pct - ); - let decap_pct = if balancer.total_usage > 0 { - (ph.decap_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Decap: {} ({:.1}%)", - format_bytes(ph.decap_usage).bright_green(), - decap_pct - ); - } - - // State breakdown - if let Some(state) = &balancer.state_inspect { - let state_percent = if balancer.total_usage > 0 { - (state.total_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "State:".bright_cyan().bold(), - format_bytes(state.total_usage).bright_green(), - state_percent - ); - let session_pct = if balancer.total_usage > 0 { - (state.session_table_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!( - " Session Table: {} ({:.1}%)", - format_bytes(state.session_table_usage).bright_green(), - session_pct - ); - } - - // Other usage - let other_percent = if balancer.total_usage > 0 { - (balancer.other_usage as f64 / balancer.total_usage as f64) * 100.0 - } else { - 0.0 - }; - println!(); - println!( - " {} {} ({:.1}%)", - "Other:".bright_cyan().bold(), - format_bytes(balancer.other_usage).cyan(), - other_percent - ); - println!(); - } - - Ok(()) -} diff --git a/modules/balancer/cli/src/rpc.rs b/modules/balancer/cli/src/rpc.rs deleted file mode 100644 index b3413dd55..000000000 --- a/modules/balancer/cli/src/rpc.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! gRPC proto module definitions - -#[allow(clippy::all, non_snake_case)] -pub mod commonpb { - tonic::include_proto!("commonpb"); -} - -#[allow(clippy::all, non_snake_case)] -pub mod balancerpb { - tonic::include_proto!("balancerpb"); -} - -pub use balancerpb::balancer_service_client::BalancerServiceClient; diff --git a/modules/balancer/cli/src/service.rs b/modules/balancer/cli/src/service.rs index 5c43c6751..690bb7564 100644 --- a/modules/balancer/cli/src/service.rs +++ b/modules/balancer/cli/src/service.rs @@ -1,329 +1,327 @@ -//! gRPC service client implementation - use std::error::Error; +use ptree::TreeBuilder; use tonic::codec::CompressionEncoding; +use yanet_cli_balancer::balancerpb::{ + self, FlushRealsRequest, GetConfigRequest, GetMetricsRequest, GetStateRequest, ListBalancersRequest, + ListSessionsRequest, PacketHandlerRef, RealUpdate, SetConfigRequest, UpdateRealsRequest, + balancer_client::BalancerClient, +}; use ync::client::{ConnectionArgs, LayeredChannel}; use crate::{ - cmd::*, - entities::{BalancerConfig, VsListConfig}, - output, - rpc::{BalancerServiceClient, balancerpb}, + ConfigCmd, DisableRealCmd, EnableRealCmd, FlushRealsCmd, MetricsCmd, ModeCmd, SessionsCmd, ShowCmd, UpdateCmd, + config::BalancerConfig, display, ip_to_bytes, parse_vs_identifier, }; -//////////////////////////////////////////////////////////////////////////////// -// Logging macros with custom target -//////////////////////////////////////////////////////////////////////////////// - -macro_rules! info { - ($($arg:tt)*) => { - log::info!(target: "yanet_cli_balancer", $($arg)*) - }; -} - -//////////////////////////////////////////////////////////////////////////////// -// Service -//////////////////////////////////////////////////////////////////////////////// - pub struct BalancerService { - client: BalancerServiceClient, + client: BalancerClient, } impl BalancerService { - /// Connect to the gRPC endpoint pub async fn connect(connection: &ConnectionArgs) -> Result> { let channel = ync::client::connect(connection).await?; - let client = BalancerServiceClient::new(channel) + let client = BalancerClient::new(channel) .send_compressed(CompressionEncoding::Gzip) .accept_compressed(CompressionEncoding::Gzip); Ok(Self { client }) } - /// Handle the command - pub async fn handle_cmd(&mut self, mode: Mode) -> Result<(), Box> { - log::trace!("Handling command: {:?}", mode); - + pub async fn handle(&mut self, mode: ModeCmd) -> Result<(), Box> { match mode { - Mode::Update(cmd) => self.update_config(cmd).await, - Mode::Reals(cmd) => self.handle_reals(cmd).await, - Mode::Vs(cmd) => self.handle_vs(cmd).await, - Mode::Config(cmd) => self.config(cmd).await, - Mode::List(cmd) => self.list(cmd).await, - Mode::Stats(cmd) => self.stats(cmd).await, - Mode::Info(cmd) => self.info(cmd).await, - Mode::Sessions(cmd) => self.sessions(cmd).await, - Mode::Graph(cmd) => self.graph(cmd).await, - Mode::Inspect(cmd) => self.inspect(cmd).await, - Mode::Metrics(cmd) => self.metrics(cmd).await, + ModeCmd::Update(cmd) => self.update(cmd).await, + ModeCmd::List => self.list().await, + ModeCmd::Config(cmd) => self.config(cmd).await, + ModeCmd::Show(cmd) => self.show(cmd).await, + ModeCmd::Sessions(cmd) => self.sessions(cmd).await, + ModeCmd::Metrics(cmd) => self.metrics(cmd).await, + ModeCmd::Reals(cmd) => match cmd.mode { + crate::RealsMode::Enable(cmd) => self.enable_real(cmd).await, + crate::RealsMode::Disable(cmd) => self.disable_real(cmd).await, + crate::RealsMode::Flush(cmd) => self.flush_reals(cmd).await, + }, } } - /// Update balancer configuration - async fn update_config(&mut self, cmd: UpdateCmd) -> Result<(), Box> { - info!("Loading configuration from: {}", cmd.config); + async fn update(&mut self, cmd: UpdateCmd) -> Result<(), Box> { + let yaml_config = BalancerConfig::from_yaml_file(&cmd.config)?; + let proto_config: balancerpb::BalancerConfig = yaml_config.try_into()?; - let config = BalancerConfig::from_yaml_file(&cmd.config)?; - let balancer_config: balancerpb::BalancerConfig = config.try_into()?; - - let request = balancerpb::UpdateConfigRequest { + let request = SetConfigRequest { name: cmd.name.clone(), - config: Some(balancer_config), + config: Some(proto_config), }; - - log::debug!("Sending UpdateConfig request"); - let response = self.client.update_config(request).await?.into_inner(); - - info!("Successfully updated configuration for '{}'", cmd.name); - - // Display update information if available - if let Some(update_info) = &response.update_info { - output::print_update_info(update_info, cmd.format.to_format())?; + log::trace!("set config request: {request:?}"); + + let response = self.client.set_config(request).await?.into_inner(); + log::debug!("set config response: {response:?}"); + + log::info!("Balancer '{}' updated successfully", response.name); + if let Some(reuse) = &response.reuse { + log::info!( + "Reuse: ipv4_vs_matcher={}, ipv6_vs_matcher={}, ipv4_decap={}, ipv6_decap={}", + reuse.ipv4_vs_matcher_reused, + reuse.ipv6_vs_matcher_reused, + reuse.ipv4_decap_filter_reused, + reuse.ipv6_decap_filter_reused, + ); + for vs_reuse in &reuse.vs_reuse_reports { + if let Some(id) = &vs_reuse.vs_identifier { + let ip = crate::bytes_to_ip(&id.addr) + .map(|a| a.to_string()) + .unwrap_or_else(|_| "?".to_string()); + let proto = match balancerpb::TransportProto::try_from(id.proto) { + Ok(balancerpb::TransportProto::Tcp) => "tcp", + Ok(balancerpb::TransportProto::Udp) => "udp", + _ => "?", + }; + log::info!( + " VS {}:{}/{}: acl_reused={}, selector_reused={}", + ip, + id.port, + proto, + vs_reuse.acl_reused, + vs_reuse.selector_reused, + ); + } + } + } + if response.session_table_capacity > 0 { + log::info!("Session table capacity: {}", response.session_table_capacity); } Ok(()) } - /// Handle reals commands - async fn handle_reals(&mut self, cmd: RealsCmd) -> Result<(), Box> { - match cmd.mode { - RealsMode::Enable(cmd) => self.enable_real(cmd).await, - RealsMode::Disable(cmd) => self.disable_real(cmd).await, - RealsMode::Flush(cmd) => self.flush_real_updates(cmd).await, - } - } - - /// Enable a real server - async fn enable_real(&mut self, cmd: EnableRealCmd) -> Result<(), Box> { - let flush = cmd.flush; - let name = cmd.name.clone(); - - let name_display = name.as_deref().unwrap_or(""); - info!( - "Enabling {} real(s) of VS {} for '{}'", - cmd.reals.len(), - cmd.vs, - name_display - ); - - let request: balancerpb::UpdateRealsRequest = cmd.try_into()?; + async fn list(&mut self) -> Result<(), Box> { + let request = ListBalancersRequest {}; + log::trace!("list balancers request: {request:?}"); - log::debug!("Sending UpdateReals request"); - self.client.update_reals(request).await?; + let response = self.client.list_balancers(request).await?.into_inner(); + log::debug!("list balancers response: {response:?}"); - info!("Successfully buffered real enable"); - - // If flush flag is set, immediately flush the updates - if flush { - let name_display = name.as_deref().unwrap_or(""); - info!("Flushing buffered real updates for '{}'", name_display); - let flush_request = balancerpb::FlushRealUpdatesRequest { name }; - let response = self.client.flush_real_updates(flush_request).await?.into_inner(); - info!("Successfully flushed {} update(s)", response.updates_flushed); + let mut tree = TreeBuilder::new("Balancers".to_string()); + for name in &response.names { + tree.add_empty_child(name.clone()); } + let tree = tree.build(); + ptree::print_tree(&tree)?; Ok(()) } - /// Disable a real server - async fn disable_real(&mut self, cmd: DisableRealCmd) -> Result<(), Box> { - let flush = cmd.flush; - let name = cmd.name.clone(); - let reals_count = cmd.reals.len(); - - let name_display = name.as_deref().unwrap_or(""); - info!( - "Disabling {} real(s) of VS {} for '{}'", - reals_count, cmd.vs, name_display - ); - - let request: balancerpb::UpdateRealsRequest = cmd.try_into()?; - - log::debug!("Sending UpdateReals request"); - self.client.update_reals(request).await?; + async fn config(&mut self, cmd: ConfigCmd) -> Result<(), Box> { + let request = GetConfigRequest { name: cmd.name }; + log::trace!("get config request: {request:?}"); - info!("Successfully buffered real disable"); + let response = self.client.get_config(request).await?.into_inner(); + log::debug!("get config response: {response:?}"); - // If flush flag is set, immediately flush the updates - if flush { - info!("Flushing buffered real updates"); - let flush_request = balancerpb::FlushRealUpdatesRequest { name }; - let response = self.client.flush_real_updates(flush_request).await?.into_inner(); - info!("Successfully flushed {} update(s)", response.updates_flushed); - } + let mut json_value = serde_json::to_value(&response)?; + display::prettify_json(&mut json_value); + let yaml = serde_yaml::to_string(&json_value)?; + print!("{yaml}"); Ok(()) } - /// Flush buffered real updates - async fn flush_real_updates(&mut self, cmd: FlushRealUpdatesCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - info!("Flushing buffered real updates for '{}'", name_display); + async fn show(&mut self, cmd: ShowCmd) -> Result<(), Box> { + let needs_table = cmd.needs_table(); + let include_counters = cmd.include_counters(); - let request: balancerpb::FlushRealUpdatesRequest = cmd.into(); + let opts = display::ShowOptions { + stats: cmd.stats || cmd.detail, + acl: cmd.acl || cmd.detail, + peers: cmd.peers || cmd.detail, + decap: cmd.decap || cmd.detail, + }; - log::debug!("Sending FlushRealUpdates request"); - let response = self.client.flush_real_updates(request).await?.into_inner(); + let packet_handler_ref = + if cmd.device.is_some() || cmd.pipeline.is_some() || cmd.function.is_some() || cmd.chain.is_some() { + Some(PacketHandlerRef { + device: cmd.device, + pipeline: cmd.pipeline, + function: cmd.function, + chain: cmd.chain, + }) + } else { + None + }; + + let request = GetStateRequest { + name: cmd.name, + packet_handler_ref, + filter: cmd.filter.to_proto(), + include_counters, + }; + log::trace!("get state request: {request:?}"); - info!("Successfully flushed {} update(s)", response.updates_flushed); - Ok(()) - } + let response = self.client.get_state(request).await?.into_inner(); + log::debug!("get state response: {response:?}"); - /// Show balancer configuration - async fn config(&mut self, cmd: ConfigCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - log::debug!("Fetching configuration for '{}'", name_display); + if response.state.is_empty() { + log::info!("No balancer state found"); + return Ok(()); + } - let request: balancerpb::ShowConfigRequest = (&cmd).into(); - let response = self.client.show_config(request).await?.into_inner(); + if needs_table { + display::print_table_view(&response.state, &opts); + } else { + display::print_compact(&response.state[0]); + } - output::print_show_config(&response, cmd.format.to_format())?; Ok(()) } - /// List all balancer configurations - async fn list(&mut self, cmd: ListCmd) -> Result<(), Box> { - log::debug!("Fetching all configurations"); - - let request = balancerpb::ListConfigsRequest {}; - let response = self.client.list_configs(request).await?.into_inner(); - - output::print_list_configs(&response, cmd.format.to_format())?; - Ok(()) - } + async fn sessions(&mut self, cmd: SessionsCmd) -> Result<(), Box> { + let request = ListSessionsRequest { + name: cmd.name, + filter: cmd.filter.to_proto(), + }; + log::trace!("list sessions request: {request:?}"); - /// Show configuration statistics - async fn stats(&mut self, cmd: StatsCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - log::debug!("Fetching statistics for '{}'", name_display); + let mut stream = self.client.list_sessions(request).await?.into_inner(); - let request: balancerpb::ShowStatsRequest = (&cmd).into(); - let response = self.client.show_stats(request).await?.into_inner(); + display::print_sessions_header(); + while let Some(session) = stream.message().await? { + display::print_session(&session); + } - output::print_show_stats(&response, cmd.format.to_format())?; Ok(()) } - /// Show state information - async fn info(&mut self, cmd: InfoCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - log::debug!("Fetching state info for '{}'", name_display); + async fn metrics(&mut self, _cmd: MetricsCmd) -> Result<(), Box> { + let request = GetMetricsRequest {}; + log::trace!("get metrics request: {request:?}"); - let request: balancerpb::ShowInfoRequest = (&cmd).into(); - let response = self.client.show_info(request).await?.into_inner(); - - output::print_show_info(&response, cmd.format.to_format())?; - Ok(()) - } - - /// Show sessions information - async fn sessions(&mut self, cmd: SessionsCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - log::debug!("Fetching sessions info for '{}'", name_display); + let response = self.client.get_metrics(request).await?.into_inner(); + log::debug!("get metrics response: {response:?}"); - let request: balancerpb::ShowSessionsRequest = (&cmd).into(); - let response = self.client.show_sessions(request).await?.into_inner(); + let mut json_value = serde_json::to_value(&response)?; + display::prettify_json(&mut json_value); + let json = serde_json::to_string(&json_value)?; + println!("{json}"); - output::print_show_sessions(&response, cmd.format.to_format())?; Ok(()) } - async fn graph(&mut self, cmd: GraphCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - log::debug!("Fetching graph info for '{}'", name_display); - - let request: balancerpb::ShowGraphRequest = (&cmd).into(); - let response = self.client.show_graph(request).await?.into_inner(); - - output::print_show_graph(&response, cmd.format.to_format())?; - Ok(()) - } + async fn enable_real(&mut self, cmd: EnableRealCmd) -> Result<(), Box> { + let (ip, port, proto) = parse_vs_identifier(&cmd.vs)?; + let vs_id = balancerpb::VsIdentifier { + addr: ip_to_bytes(ip), + port: port as u32, + proto: proto as i32, + }; - /// Show memory usage inspection - async fn inspect(&mut self, cmd: InspectCmd) -> Result<(), Box> { - log::debug!("Fetching memory inspection"); + let updates: Vec = cmd + .reals + .iter() + .map(|r| { + let real_ip: std::net::IpAddr = r.parse().map_err(|e| format!("invalid real IP '{}': {}", r, e))?; + Ok(RealUpdate { + real_id: Some(balancerpb::RealIdentifier { + vs: Some(vs_id.clone()), + real: Some(balancerpb::RelativeRealIdentifier { ip: ip_to_bytes(real_ip), port: 0 }), + }), + enable: Some(true), + weight: cmd.weight, + }) + }) + .collect::, String>>()?; + + let request = UpdateRealsRequest { + name: cmd.name, + updates, + buffer: !cmd.flush, + }; + log::trace!("update reals request: {request:?}"); - let request: balancerpb::ShowInspectRequest = (&cmd).into(); - let response = self.client.show_inspect(request).await?.into_inner(); + let response = self.client.update_reals(request).await?.into_inner(); + log::debug!("update reals response: {response:?}"); - output::print_show_inspect(&response, cmd.format.to_format())?; - Ok(()) - } + if response.updates_buffered > 0 { + log::info!( + "Balancer '{}': {} updates buffered (use 'reals flush' to apply)", + response.name, + response.updates_buffered + ); + } + if response.updates_applied > 0 { + log::info!( + "Balancer '{}': {} updates applied", + response.name, + response.updates_applied + ); + } - /// Metrics - async fn metrics(&mut self, cmd: MetricsCmd) -> Result<(), Box> { - log::debug!("Fetching metrics"); - let request: balancerpb::GetMetricsRequest = (&cmd).into(); - let response = self.client.get_metrics(request).await?.into_inner(); - let s = serde_json::to_string(&response)?; - println!("{}", s); Ok(()) } - /// Handle VS commands - async fn handle_vs(&mut self, cmd: VsCmd) -> Result<(), Box> { - match cmd.mode { - VsMode::Update(cmd) => self.update_vs(cmd).await, - VsMode::Delete(cmd) => self.delete_vs(cmd).await, - } - } - - /// Update virtual services - async fn update_vs(&mut self, cmd: UpdateVsCmd) -> Result<(), Box> { - let name_display = cmd.name.as_deref().unwrap_or(""); - info!("Loading VS configuration from: {}", cmd.config); - - let vs_config = VsListConfig::from_yaml_file(&cmd.config)?; - let vs_count = vs_config.vs.len(); - - // Convert VirtualService entities to protobuf - let vs_list: Result, String> = - vs_config.vs.into_iter().map(TryInto::try_into).collect(); - let vs_list = vs_list?; - - let request = balancerpb::UpdateVsRequest { name: cmd.name.clone(), vs: vs_list }; + async fn disable_real(&mut self, cmd: DisableRealCmd) -> Result<(), Box> { + let (ip, port, proto) = parse_vs_identifier(&cmd.vs)?; + let vs_id = balancerpb::VsIdentifier { + addr: ip_to_bytes(ip), + port: port as u32, + proto: proto as i32, + }; - log::debug!("Sending UpdateVS request for '{}'", name_display); - let response = self.client.update_vs(request).await?.into_inner(); + let updates: Vec = cmd + .reals + .iter() + .map(|r| { + let real_ip: std::net::IpAddr = r.parse().map_err(|e| format!("invalid real IP '{}': {}", r, e))?; + Ok(RealUpdate { + real_id: Some(balancerpb::RealIdentifier { + vs: Some(vs_id.clone()), + real: Some(balancerpb::RelativeRealIdentifier { ip: ip_to_bytes(real_ip), port: 0 }), + }), + enable: Some(false), + weight: None, + }) + }) + .collect::, String>>()?; + + let request = UpdateRealsRequest { + name: cmd.name, + updates, + buffer: !cmd.flush, + }; + log::trace!("update reals request: {request:?}"); - info!( - "Successfully updated {} virtual service(s) for '{}'", - vs_count, response.name - ); + let response = self.client.update_reals(request).await?.into_inner(); + log::debug!("update reals response: {response:?}"); - // Display update information - if let Some(update_info) = &response.info { - output::print_vs_update_info(update_info, cmd.format.to_format(), output::VsOperation::Update)?; + if response.updates_buffered > 0 { + log::info!( + "Balancer '{}': {} updates buffered (use 'reals flush' to apply)", + response.name, + response.updates_buffered + ); + } + if response.updates_applied > 0 { + log::info!( + "Balancer '{}': {} updates applied", + response.name, + response.updates_applied + ); } Ok(()) } - /// Delete virtual services - async fn delete_vs(&mut self, cmd: DeleteVsCmd) -> Result<(), Box> { - // Extract values before moving cmd - let name_for_display = cmd.name.clone(); - let name_display = name_for_display.as_deref().unwrap_or(""); - let vs_count = cmd.vs.len(); - let format = cmd.format.to_format(); + async fn flush_reals(&mut self, cmd: FlushRealsCmd) -> Result<(), Box> { + let request = FlushRealsRequest { name: cmd.name }; + log::trace!("flush reals request: {request:?}"); - info!("Deleting {} virtual service(s) from '{}'", vs_count, name_display); + let response = self.client.flush_reals(request).await?.into_inner(); + log::debug!("flush reals response: {response:?}"); - let request: balancerpb::DeleteVsRequest = cmd.try_into()?; - - log::debug!("Sending DeleteVS request for '{}'", name_display); - let response = self.client.delete_vs(request).await?.into_inner(); - - info!( - "Successfully deleted {} virtual service(s) from '{}'", - vs_count, response.name + log::info!( + "Balancer '{}': {} updates flushed", + response.name, + response.updates_flushed ); - // Display update information - if let Some(update_info) = &response.info { - output::print_vs_update_info(update_info, format, output::VsOperation::Delete)?; - } - Ok(()) } } diff --git a/modules/balancer/controlplane/agent.go b/modules/balancer/controlplane/agent.go new file mode 100644 index 000000000..db54da4a6 --- /dev/null +++ b/modules/balancer/controlplane/agent.go @@ -0,0 +1,87 @@ +package balancer + +import ( + "fmt" + + "github.com/c2h5oh/datasize" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "go.uber.org/zap" +) + +type Agent struct { + agent *yanet.Agent + balancers map[string]*Balancer +} + +func (a *Agent) AsYanetAgent() *yanet.Agent { + return a.agent +} + +func AttachNewAgent( + shm *yanet.SharedMemory, + instanceIdx uint32, + size datasize.ByteSize, +) (*Agent, error) { + agent, err := shm.AgentAttach("balancer", instanceIdx, size) + if err != nil { + return nil, fmt.Errorf("failed to attach balancer agent: %w", err) + } + return &Agent{ + agent: agent, + balancers: make(map[string]*Balancer), + }, nil +} + +func ReattachAgent( + shm *yanet.SharedMemory, + instanceIdx uint32, + size datasize.ByteSize, + log *zap.SugaredLogger, +) (*Agent, error) { + agent, err := shm.AgentReattach("balancer", instanceIdx, size) + if err != nil { + return nil, err + } + + // Restore balancers + balancerAgent := &Agent{ + agent: agent, + balancers: make(map[string]*Balancer), + } + packetHandlers := balancerAgent.list() + for _, ph := range packetHandlers { + balancer := restoreBalancerFromPacketHandler(balancerAgent, ph, log) + name := balancer.handler.name() + balancerAgent.balancers[name] = balancer + } + + return balancerAgent, nil +} + +func (a *Agent) GetBalancer(name string) (*Balancer, bool) { + b, ok := a.balancers[name] + return b, ok +} + +func (a *Agent) PutBalancer(name string, b *Balancer) { + a.balancers[name] = b +} + +func (a *Agent) BalancerNames() []string { + names := make([]string, 0, len(a.balancers)) + for name := range a.balancers { + names = append(names, name) + } + return names +} + +func (a *Agent) Balancers() map[string]*Balancer { + return a.balancers +} + +func (a *Agent) Close() error { + for _, balancer := range a.balancers { + balancer.refresher.Stop() + } + return nil +} diff --git a/modules/balancer/controlplane/api/balancer.c b/modules/balancer/controlplane/api/balancer.c deleted file mode 100644 index c618d3f00..000000000 --- a/modules/balancer/controlplane/api/balancer.c +++ /dev/null @@ -1,416 +0,0 @@ -#include "balancer.h" -#include "api/agent.h" -#include "graph.h" -#include "handler/info.h" -#include "inspect.h" -#include "session.h" -#include "state.h" - -#include "api/counter.h" - -#include - -#include "common/container_of.h" -#include "common/memory.h" -#include "common/memory_address.h" - -#include "lib/controlplane/agent/agent.h" -#include "lib/controlplane/config/cp_module.h" -#include "lib/controlplane/config/zone.h" -#include "lib/controlplane/diag/diag.h" - -#include "handler/handler.h" -#include "handler/inspect.h" -#include "handler/vs.h" -#include "state/session_table.h" -#include "state/state.h" -#include "vs.h" - -#include -#include -#include - -struct balancer_handle {}; - -struct balancer { - struct balancer_handle handle; - struct balancer_state state; - struct packet_handler *handler; - struct diag diag; -}; - -struct balancer * -balancer_handle_deref(struct balancer_handle *handle) { - return container_of(handle, struct balancer, handle); -} - -const char * -balancer_take_error_msg(struct balancer_handle *handle) { - struct balancer *balancer = balancer_handle_deref(handle); - return diag_take_msg(&balancer->diag); -} - -const char * -balancer_name(struct balancer_handle *handle) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - return handler->cp_module.name; -} - -int -balancer_resize_session_table( - struct balancer_handle *handle, size_t new_size, uint32_t now -) { - struct balancer *balancer = balancer_handle_deref(handle); - return DIAG_TRY( - &balancer->diag, - balancer_state_resize_session_table( - &balancer->state, new_size, now - ) - ); -} - -struct balancer_handle * -balancer_create( - struct agent *agent, const char *name, struct balancer_config *config -) { - agent_clean_error(agent); - - struct dp_config *dp_config = ADDR_OF(&agent->dp_config); - - struct memory_context *mctx = &agent->memory_context; - - struct balancer *balancer = - memory_balloc(mctx, sizeof(struct balancer)); - if (balancer == NULL) { - NEW_ERROR("no memory"); - goto error; - } - assert((uintptr_t)balancer % alignof(struct balancer) == 0); - memset(balancer, 0, sizeof(struct balancer)); - - int init_state_result = balancer_state_init( - &balancer->state, - mctx, - dp_config->worker_count, - config->state.table_capacity - ); - if (init_state_result != 0) { - PUSH_ERROR("failed to initialize balancer state"); - memory_bfree(mctx, balancer, sizeof(struct balancer)); - goto error; - } - - struct packet_handler *handler = packet_handler_setup( - agent, name, &config->handler, &balancer->state, NULL, NULL - ); - if (handler == NULL) { - PUSH_ERROR("packet handler"); - balancer_state_free(&balancer->state); - memory_bfree(mctx, balancer, sizeof(struct balancer)); - goto error; - } - - SET_OFFSET_OF(&balancer->handler, handler); - - return &balancer->handle; - -error: - diag_fill(&agent->diag); - - return NULL; -} - -int -balancer_update_packet_handler( - struct balancer_handle *handle, - struct packet_handler_config *config, - struct balancer_update_info *update_info -) { - int ret; - - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *prev_handler = ADDR_OF(&balancer->handler); - - const char *name = prev_handler->cp_module.name; - - struct agent *agent = ADDR_OF(&prev_handler->cp_module.agent); - - // Initialize update_info if provided - if (update_info != NULL) { - memset(update_info, 0, sizeof(*update_info)); - } - - // TODO: pass prev config here - struct packet_handler *handler = packet_handler_setup( - agent, name, config, &balancer->state, prev_handler, update_info - ); - if (handler == NULL) { - PUSH_ERROR("failed to setup packet handler"); - diag_fill(&balancer->diag); - ret = -1; - } else { - diag_reset(&balancer->diag); - SET_OFFSET_OF(&balancer->handler, handler); - packet_handler_free(prev_handler); - ret = 0; - } - - return ret; -} - -int -balancer_update_reals( - struct balancer_handle *handle, size_t count, struct real_update *update -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - return DIAG_TRY( - &balancer->diag, - packet_handler_update_reals(handler, count, update), - "failed to update reals in packet handler" - ); -} - -//////////////////////////////////////////////////////////////////////////////// - -int -balancer_info( - struct balancer_handle *handle, struct balancer_info *info, uint32_t now -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - packet_handler_balancer_info(handler, info, now); - return 0; -} - -void -balancer_active_sessions( - struct balancer_handle *handle, struct balancer_info *info -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - packet_handler_active_sessions(handler, info); -} - -//////////////////////////////////////////////////////////////////////////////// - -int -balancer_stats( - struct balancer_handle *handle, - struct balancer_stats *stats, - struct packet_handler_ref *ref -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - - if (ref->device == NULL) { - NEW_ERROR("device is required"); - goto err; - } - - if (ref->pipeline == NULL) { - NEW_ERROR("pipeline is required"); - goto err; - } - - if (ref->function == NULL) { - NEW_ERROR("function is required"); - goto err; - } - - if (ref->chain == NULL) { - NEW_ERROR("chain is required"); - goto err; - } - - // Reset diagnostics only after all validation passes - diag_reset(&balancer->diag); - - int res = packet_handler_fill_stats(handler, stats, ref); - if (res != 0) { - PUSH_ERROR("invalid balancer reference"); - goto err; - } - - return 0; - -err: - diag_fill(&balancer->diag); - return -1; -} - -//////////////////////////////////////////////////////////////////////////////// - -void -balancer_sessions( - struct balancer_handle *handle, struct sessions *sessions, uint32_t now -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - struct named_session_info *sessions_info; - size_t count = - packet_handler_sessions_info(handler, &sessions_info, now); - *sessions = (struct sessions){.sessions_count = count, - .sessions = sessions_info}; -} - -//////////////////////////////////////////////////////////////////////////////// - -void -balancer_stats_free(struct balancer_stats *stats) { - if (stats->vs_count > 0) { - struct named_vs_stats *first_vs = &stats->vs[0]; - struct named_real_stats *reals = first_vs->reals; - free(reals); - for (size_t vs_idx = 0; vs_idx < stats->vs_count; ++vs_idx) { - struct allowed_sources_stats *allowed_sources = - stats->vs[vs_idx].allowed_sources; - free(allowed_sources); - } - } - free(stats->vs); -} - -void -balancer_sessions_free(struct sessions *sessions) { - free(sessions->sessions); -} - -void -balancer_info_free(struct balancer_info *info) { - if (info->vs_count > 0) { - struct named_vs_info *first_vs = &info->vs[0]; - struct named_real_info *reals = first_vs->reals; - free(reals); - } - free(info->vs); -} - -void -balancer_update_info_free(struct balancer_update_info *update_info) { - if (update_info == NULL) { - return; - } - free(update_info->vs_acl_reused); - update_info->vs_acl_reused = NULL; - update_info->vs_acl_reused_count = 0; -} - -void -balancer_graph_free(struct balancer_graph *graph) { - for (size_t i = 0; i < graph->vs_count; i++) { - free(graph->vs[i].reals); - } - free(graph->vs); -} - -void -balancer_graph(struct balancer_handle *handle, struct balancer_graph *graph) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - - // Allocate VS array - graph->vs_count = handler->vs_count; - graph->vs = calloc(graph->vs_count, sizeof(struct graph_vs)); - if (graph->vs == NULL) { - graph->vs_count = 0; - return; - } - - // Iterate through each virtual service - struct vs *vss = ADDR_OF(&handler->vs); - for (size_t vs_idx = 0; vs_idx < handler->vs_count; vs_idx++) { - struct vs *vs = &vss[vs_idx]; - struct graph_vs *graph_vs = &graph->vs[vs_idx]; - - // Copy VS identifier - graph_vs->identifier = vs->identifier; - - // Allocate reals array for this VS - graph_vs->real_count = vs->reals_count; - graph_vs->reals = - calloc(vs->reals_count, sizeof(struct graph_real)); - if (graph_vs->reals == NULL) { - graph_vs->real_count = 0; - continue; - } - - // Iterate through each real in this VS - const struct real *reals = ADDR_OF(&vs->reals); - for (size_t real_idx = 0; real_idx < vs->reals_count; - real_idx++) { - const struct real *real = &reals[real_idx]; - struct graph_real *graph_real = - &graph_vs->reals[real_idx]; - - // Copy real identifier (relative to VS) - graph_real->identifier = real->identifier.relative; - - // Get weight and enabled flag from struct real - graph_real->weight = real->weight; - graph_real->enabled = real->enabled; - } - } -} - -size_t -balancer_session_table_capacity(struct balancer_handle *handle) { - struct balancer *balancer = balancer_handle_deref(handle); - struct balancer_state *state = &balancer->state; - return session_table_capacity(&state->session_table); -} - -int -balancer_real_ph_idx( - struct balancer_handle *handle, - struct real_identifier *real, - struct real_ph_index *real_idx -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct packet_handler *handler = ADDR_OF(&balancer->handler); - return packet_handler_real_idx(handler, real, real_idx); -} - -void -balancer_inspect( - struct balancer_handle *handle, struct balancer_inspect *inspect -) { - struct balancer *balancer = balancer_handle_deref(handle); - struct balancer_state *state = &balancer->state; - packet_handler_inspect( - ADDR_OF(&balancer->handler), - &inspect->packet_handler_inspect, - state->workers - ); - balancer_state_inspect(state, &inspect->state_inspect); - inspect->other_usage = - sizeof(struct balancer) + sizeof(struct packet_handler); - inspect->total_usage = inspect->other_usage + - inspect->packet_handler_inspect.total_usage + - inspect->state_inspect.total_usage; -} - -void -balancer_inspect_free(struct balancer_inspect *inspect) { - if (inspect == NULL) { - return; - } - - // Free packet handler inspect nested structures - if (inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects != - NULL) { - free(inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects - ); - inspect->packet_handler_inspect.vs_ipv4_inspect.vs_inspects = - NULL; - } - - if (inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects != - NULL) { - free(inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects - ); - inspect->packet_handler_inspect.vs_ipv6_inspect.vs_inspects = - NULL; - } -} \ No newline at end of file diff --git a/modules/balancer/controlplane/api/balancer.h b/modules/balancer/controlplane/api/balancer.h deleted file mode 100644 index ffc68fd5a..000000000 --- a/modules/balancer/controlplane/api/balancer.h +++ /dev/null @@ -1,525 +0,0 @@ -#pragma once - -#include "handler.h" -#include "inspect.h" -#include "real.h" -#include "session.h" -#include "state.h" -#include "stats.h" -#include "vs.h" - -/** - * Diagnostics - * - * Unless otherwise stated, on error each API function records a human-readable - * diagnostic message associated with the balancer handle. Retrieve it via - * balancer_take_error_msg(). - * - * Ownership: The returned message is heap-allocated and must be freed by the - * caller with free(). - * - * For creation-time failures use the diag parameter of balancer_create(). - */ - -/** - * Balancer module configuration. - * - * Combines packet handler configuration and session/state configuration - * required to instantiate a balancer instance. - */ -struct balancer_config { - /** Packet handling/session parameters */ - struct packet_handler_config handler; - - /** Session table sizing/config */ - struct state_config state; -}; - -/** - * Information about balancer configuration update operation. - * - * Provides visibility into filter reuse decisions made during - * packet handler update. Helps understand configuration change - * impact and optimization opportunities. - */ -struct balancer_update_info { - /** - * IPv4 virtual service matcher was reused from previous handler. - * - * When true (non-zero): VS lookup filter for IPv4 was not recompiled - * When false (zero): VS lookup filter for IPv4 was recompiled - */ - int vs_ipv4_matcher_reused; - - /** - * IPv6 virtual service matcher was reused from previous handler. - * - * When true (non-zero): VS lookup filter for IPv6 was not recompiled - * When false (zero): VS lookup filter for IPv6 was recompiled - */ - int vs_ipv6_matcher_reused; - - /** - * Number of virtual services that reused ACL from previous handler. - * - * These VS did not need ACL recompilation because their - * allowed_src rules matched the previous configuration. - */ - size_t vs_acl_reused_count; - - /** - * Array of VS identifiers that reused ACL filters. - * - * Contains identifiers for virtual services where ACL was - * reused (not recompiled) from the previous handler. - * - * Array length is vs_acl_reused_count. - * Allocated by balancer_update_packet_handler(). - * Caller must free with free(). - */ - struct vs_identifier *vs_acl_reused; -}; - -struct agent; - -/** - * Opaque handle to a balancer instance. - * - * The handle is returned by balancer_create() and balancers() and is used with - * all other API calls. Its internals are private to the implementation. - * - * Thread-Safety: Does not allow multithreading access. - * Safe to work concurrently with the controlplane and dataplane. - */ -struct balancer_handle; - -struct diag; - -/** - * Create a new balancer instance and register it. - * - * On success returns a handle to the created balancer. On failure returns - * NULL and records diagnostic information in the provided diag object. - * - * Diagnostics: On error, details are written into 'diag'. After a successful - * creation, subsequent API calls record diagnostics on the balancer and can be - * retrieved via balancer_take_error_msg(). - * - * @param agent Agent that will own the balancer. - * @param name Human-readable balancer name (used for identification). - * @param config Initial configuration. - * @param diag Diagnostics sink for error details (must not be NULL). - * @return Newly created balancer handle on success, or NULL on error. - */ -struct balancer_handle * -balancer_create( - struct agent *agent, const char *name, struct balancer_config *config -); - -/** - * Retrieve the last diagnostic error message for this balancer. - * - * Ownership: The returned string is heap-allocated for the caller; you must - * free() it when no longer needed. Returns NULL if no message is available. - * - * @param handle Balancer handle. - * @return Null-terminated error message string to be freed by caller, or NULL. - */ -const char * -balancer_take_error_msg(struct balancer_handle *handle); - -/** - * Get the name of the balancer instance. - * - * @param handle Balancer handle. - * @return Pointer to the balancer name string (owned by the balancer, do not - * free). - */ -const char * -balancer_name(struct balancer_handle *handle); - -// Update - -/** - * Resize the session table used by the balancer. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(handle). - * - * @param handle Balancer handle. - * @param new_size New number of entries to allocate. - * @param now Current monotonic timestamp used for migration bookkeeping. - * @return 0 on success, -1 on error. - */ -int -balancer_resize_session_table( - struct balancer_handle *handle, size_t new_size, uint32_t now -); - -/** - * Get the current session table capacity. - * - * Returns the current maximum number of concurrent sessions the session - * table can hold. This is the hash table size, not the number of active - * sessions. - * - * The capacity can change over time due to: - * - Manual resizing via balancer_resize_session_table() - * - Automatic resizing when load factor exceeds threshold - * - * @param handle Balancer handle. - * @return Current session table capacity (number of entries). - */ -size_t -balancer_session_table_capacity(struct balancer_handle *handle); - -/** - * Update packet handler configuration. - * - * This call applies changes such as timeouts, VS list or source addresses. - * Returns information about filter reuse decisions in the update_info - * parameter. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(balancer). - * - * @param balancer Balancer handle. - * @param config New packet handler configuration. - * @param update_info Output structure filled with update information. - * May be NULL if caller doesn't need this information. - * @return 0 on success, -1 on error. - */ -int -balancer_update_packet_handler( - struct balancer_handle *balancer, - struct packet_handler_config *config, - struct balancer_update_info *update_info -); - -/** - * Free all allocations inside a balancer_update_info structure. - * - * Releases memory allocated by balancer_update_packet_handler() for the - * vs_acl_reused array. Safe to call with partially-initialized structures; - * ignores NULL pointers. - * - * NOTE: This function does NOT free the balancer_update_info structure itself, - * only the dynamically allocated array inside it. - * - * @param update_info Structure to release. The struct itself is not freed. - */ -void -balancer_update_info_free(struct balancer_update_info *update_info); - -/** - * Apply a batch of real server updates. - * - * Each update may change weight and/or enabled state; to skip a field - * use DONT_UPDATE_REAL_WEIGHT and DONT_UPDATE_REAL_ENABLED. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(balancer). - * - * @param balancer Balancer handle. - * @param count Number of updates in the array. - * @param updates Array of updates. - * @return 0 on success, -1 on error. - */ -int -balancer_update_reals( - struct balancer_handle *balancer, - size_t count, - struct real_update *updates -); - -// Stats - -/** - * Optional reference to narrow statistics to a particular packet handler - * attachment point. - * - * Any field may be NULL to indicate no filtering on that dimension. - */ -struct packet_handler_ref { - const char *device; // Optional device name (NULL for any) - const char *pipeline; // Optional pipeline name (NULL for any) - const char *function; // Optional function name (NULL for any) - const char *chain; // Optional chain name (NULL for any) -}; - -/** - * Read balancer statistics, optionally filtered by packet handler reference. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(balancer). - * - * @param balancer Balancer handle. - * @param stats Output structure to be filled. - * @param ref Optional filter; pass NULL for aggregate stats. - * @return 0 on success, -1 on error. - */ -int -balancer_stats( - struct balancer_handle *balancer, - struct balancer_stats *stats, - struct packet_handler_ref *ref -); - -/** - * Free all allocations inside a balancer_stats structure. - * - * Releases memory allocated by balancer_stats() for the VS and real - * statistics arrays. Safe to call with partially-initialized structures; - * ignores NULL pointers. - * - * NOTE: This function does NOT free the balancer_stats structure itself, - * only the dynamically allocated arrays inside it. - * - * @param stats Structure to release. The struct itself is not freed. - */ -void -balancer_stats_free(struct balancer_stats *stats); - -/** - * Aggregated information about a balancer instance. - * - * Provides a comprehensive snapshot of the balancer's operational state, - * including active session counts, last activity timestamp, and detailed - * information about all virtual services and their real servers. - * - * DATA FRESHNESS: - * - active_sessions: Updated during periodic refresh (if enabled) or on-demand - * - last_packet_timestamp: Real-time from dataplane - * - vs array: Contains per-VS and per-real runtime information - * - * MEMORY MANAGEMENT: - * - balancer_info() allocates the 'vs' array and all nested structures - * - Caller must call balancer_info_free() to release all allocations - * - Safe to call balancer_info_free() on partially-initialized structures - * - * USAGE PATTERN: - * ```c - * struct balancer_info info; - * if (balancer_info(handle, &info, now) == 0) { - * // Use info.active_sessions, info.vs, etc. - * balancer_info_free(&info); - * } - * ``` - */ -struct balancer_info { - /** - * Total number of active sessions across all virtual services. - * - * This is the sum of active sessions for all VSs and represents - * the current load on the balancer. - */ - size_t active_sessions; - - /** - * Timestamp of the most recent packet processed by any VS. - * - * Monotonic timestamp (seconds since boot) representing the last - * activity across the entire balancer instance. This is the maximum - * of all VS last_packet_timestamp values. - * - * Updated in real-time by the dataplane when packets are processed. - */ - uint32_t last_packet_timestamp; - - /** - * Number of virtual services in the 'vs' array. - * - * This matches the number of virtual services configured in the - * packet handler configuration. - */ - size_t vs_count; - - /** - * Array of virtual service runtime information. - * - * Contains detailed information for each VS including: - * - Active session counts per VS - * - Per-real server information (active sessions, last activity) - * - Last packet timestamps - * - * OWNERSHIP: - * - Allocated by balancer_info() - * - Must be freed with balancer_info_free() - * - Array length is vs_count - */ - struct named_vs_info *vs; -}; - -/** - * Query aggregated balancer information. - * - * On success fills the provided structure and allocates arrays inside it. - * Release all memory with balancer_info_free(). - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(balancer). - * - * @param balancer Balancer handle. - * @param info Output structure to be filled. - * @return 0 on success, -1 on error. - */ -int -balancer_info( - struct balancer_handle *balancer, - struct balancer_info *info, - uint32_t now -); - -void -balancer_active_sessions( - struct balancer_handle *balancer, struct balancer_info *info -); - -/** - * Free all allocations inside a balancer_info previously filled by - * balancer_info(). - * - * Safe to call with partially-initialized structures; ignores NULL pointers. - * - * @param info Structure to release. The struct itself is not freed. - */ -void -balancer_info_free(struct balancer_info *info); - -/** - * Enumerate active sessions tracked by the balancer. - * - * Returns a heap-allocated array of named_session_info entries representing - * a point-in-time snapshot. The caller owns the array and must - * balancer_sessions_free() it. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(balancer). - * - * @param balancer Balancer handle. - * @param sessions Output pointer to a heap-allocated array of session infos. - * @return Number of entries on success - */ -void -balancer_sessions( - struct balancer_handle *balancer, - struct sessions *sessions, - uint32_t now -); - -/** - * Free all allocations inside a sessions structure. - * - * Releases memory allocated by balancer_sessions() for the session - * information array. Safe to call with partially-initialized structures; - * ignores NULL pointers. - * - * NOTE: This function does NOT free the sessions structure itself, - * only the dynamically allocated array inside it. - * - * @param sessions Structure to release. The struct itself is not freed. - */ -void -balancer_sessions_free(struct sessions *sessions); - -struct balancer_graph; - -/** - * Retrieve the balancer topology graph. - * - * Returns a snapshot of the complete balancer topology showing all - * virtual services and their real servers with current operational - * states (effective weights, enabled status). - * - * The graph provides visibility into: - * - Current effective weights (may differ from config due to WLC) - * - Real server enabled/disabled states - * - Complete VS-to-real relationships - * - * MEMORY MANAGEMENT: - * - Allocates memory for the graph structure and all nested arrays - * - Caller must free with balancer_graph_free() when done - * - Safe to call balancer_graph_free() on partially-initialized graphs - * - * USAGE PATTERN: - * ```c - * struct balancer_graph graph; - * balancer_graph(handle, &graph); - * // Use graph data... - * balancer_graph_free(&graph); - * ``` - * - * @param handle Balancer handle. - * @param graph Output structure to be filled with graph data. - */ -void -balancer_graph(struct balancer_handle *handle, struct balancer_graph *graph); - -/** - * Free all allocations inside a balancer_graph structure. - * - * Releases memory allocated by balancer_graph() for the virtual service - * and real server arrays. This includes: - * - The top-level VS array (graph->vs) - * - Each VS's real server array (vs->reals) - * - * Safe to call with partially-initialized structures; ignores NULL pointers. - * - * NOTE: This function does NOT free the balancer_graph structure itself, - * only the dynamically allocated arrays inside it. - * - * @param graph Structure to release. The struct itself is not freed. - */ -void -balancer_graph_free(struct balancer_graph *graph); - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Get packet handler indices for a real server. - * - * Translates a real server identifier (VS + real) into packet handler - * internal indices. This is useful for low-level operations that need - * to directly access packet handler data structures. - * - * The returned indices identify: - * - vs_idx: Index of the virtual service in the packet handler's VS array - * - real_idx: Index of the real within that virtual service's real array - * - * These indices can be used to: - * - Access real server configuration in packet handler structures - * - Perform direct updates to packet handler state - * - Map between high-level identifiers and internal indices - * - * USAGE: - * This is primarily an internal API used by the manager layer to - * coordinate between the high-level balancer API and the low-level - * packet handler implementation. - * - * Diagnostics: On error, a message is recorded and retrievable via - * balancer_take_error_msg(handle). - * - * @param handle Balancer handle. - * @param real Real server identifier (VS + real address/port). - * @param real_idx Output structure to be filled with indices. - * @return 0 on success, -1 on error (e.g., real not found). - */ -int -balancer_real_ph_idx( - struct balancer_handle *handle, - struct real_identifier *real, - struct real_ph_index *real_idx -); - -void -balancer_inspect( - struct balancer_handle *handle, struct balancer_inspect *inspect -); - -void -balancer_inspect_free(struct balancer_inspect *inspect); - -void -balancer_active_sessions( - struct balancer_handle *handle, struct balancer_info *info -); \ No newline at end of file diff --git a/modules/balancer/controlplane/api/graph.h b/modules/balancer/controlplane/api/graph.h deleted file mode 100644 index 806708cf4..000000000 --- a/modules/balancer/controlplane/api/graph.h +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once - -#include "real.h" -#include "vs.h" -#include - -/** - * Real server state in the balancer topology graph. - * - * Represents the current operational state of a real server within a - * virtual service, including its effective weight and enabled status. - * This is a snapshot of the runtime state, which may differ from the - * configured state due to dynamic weight adjustments (WLC). - * - * WEIGHT SEMANTICS: - * - This is the EFFECTIVE weight currently used by the scheduler - * - For non-WLC virtual services: weight == configured weight - * - For WLC virtual services: weight may be dynamically adjusted - * - The original configured weight is preserved in the configuration - * - * USE CASES: - * - Monitoring current load distribution - * - Debugging WLC weight adjustments - * - Visualizing balancer topology - * - Detecting disabled or removed reals - */ -struct graph_real { - /** Real server identifier (relative to its virtual service) */ - struct relative_real_identifier identifier; - - /** - * Current effective weight used by the scheduler. - * - * This is the weight currently active in the dataplane for traffic - * distribution. It may differ from the configured weight if: - * - WLC is enabled and has adjusted weights based on session counts - * - Real was recently updated via UpdateReals/UpdateRealsWlc - * - * For WLC-enabled virtual services, this weight is recalculated - * every refresh_period based on active session distribution. - */ - uint16_t weight; - - /** - * Whether the real server is currently enabled. - * - * When false: - * - Real receives no NEW sessions - * - Existing sessions may continue to be forwarded (until timeout) - * - Real is excluded from scheduling decisions - * - Real is excluded from WLC calculations - * - * When true: - * - Real participates in scheduling - * - Real can receive new sessions - * - Real is included in WLC calculations - */ - bool enabled; -}; - -/** - * Virtual service state in the balancer topology graph. - * - * Represents a virtual service and all its associated real servers - * with their current operational states. This provides a complete - * snapshot of the VS topology at a point in time. - * - * MEMORY MANAGEMENT: - * - The 'reals' array is heap-allocated - * - Caller must free with balancer_graph_free() after use - * - Do not modify the array contents directly - */ -struct graph_vs { - /** Virtual service identifier */ - struct vs_identifier identifier; - - /** Number of real servers in the 'reals' array */ - size_t real_count; - - /** - * Array of real server states. - * - * Contains current state for all reals configured for this VS, - * including both enabled and disabled reals. The order matches - * the configuration order. - * - * Ownership: Allocated by balancer_graph(), freed by - * balancer_graph_free() - */ - struct graph_real *reals; -}; - -/** - * Complete balancer topology graph. - * - * Provides a snapshot of the entire balancer configuration showing - * all virtual services and their real servers with current operational - * states (weights, enabled status). - * - * This structure is useful for: - * - Visualizing the complete load balancer topology - * - Monitoring real server states across all virtual services - * - Debugging configuration and WLC behavior - * - Understanding current traffic distribution - * - Detecting configuration inconsistencies - * - * USAGE PATTERN: - * ```c - * struct balancer_graph graph; - * balancer_graph(handle, &graph); - * - * // Use graph data... - * for (size_t i = 0; i < graph.vs_count; i++) { - * struct graph_vs *vs = &graph.vs[i]; - * for (size_t j = 0; j < vs->real_count; j++) { - * struct graph_real *real = &vs->reals[j]; - * // Process real state... - * } - * } - * - * balancer_graph_free(&graph); - * ``` - * - * MEMORY MANAGEMENT: - * - All arrays are heap-allocated by balancer_graph() - * - Must be freed with balancer_graph_free() when done - * - Safe to call balancer_graph_free() on partially-initialized graphs - */ -struct balancer_graph { - /** Number of virtual services in the 'vs' array */ - size_t vs_count; - - /** - * Array of virtual service states. - * - * Contains state for all configured virtual services in the - * balancer. The order matches the configuration order. - * - * Ownership: Allocated by balancer_graph(), freed by - * balancer_graph_free() - */ - struct graph_vs *vs; -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/api/handler.h b/modules/balancer/controlplane/api/handler.h deleted file mode 100644 index e60966f12..000000000 --- a/modules/balancer/controlplane/api/handler.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -#include "common/network.h" - -#include "session.h" - -#include - -/** - * Packet handler configuration. - * - * Defines runtime parameters for session handling and the set of virtual - * services available for scheduling, as well as optional decapsulation - * behavior at the start of the processing pipeline. - * - * COMPONENTS: - * - Session timeouts: Control when idle sessions expire - * - Virtual services: List of load-balanced services - * - Source addresses: Used for generated packets (ICMP, health checks) - * - Decapsulation: Optional tunnel unwrapping before processing - * - * MEMORY MANAGEMENT: - * - Caller allocates and manages all arrays (vs, decap_v4, decap_v6) - * - Arrays must remain valid for the lifetime of the configuration - * - Use balancer_update_packet_handler() to apply changes - */ -struct packet_handler_config { - /** - * Session timeout configuration. - * - * Defines how long sessions remain active based on the last - * observed packet type (TCP SYN, FIN, UDP, etc.). Different - * timeouts allow fine-grained control over session lifecycle. - */ - struct sessions_timeouts sessions_timeouts; - - /** Number of virtual services in the 'vs' array */ - size_t vs_count; - - /** - * Array of virtual service configurations. - * - * Each entry defines a load-balanced service including: - * - Service identifier (IP, port, protocol) - * - List of real servers (backends) - * - Scheduling flags (WLC, OPS, Pure L3, etc.) - * - * Ownership: Caller allocates and manages this array - */ - struct named_vs_config *vs; - - /** - * IPv4 source address for generated packets. - * - * Used when the balancer generates packets such as: - * - ICMP error responses - * - Health check probes (if implemented) - * - Other control plane traffic - */ - struct net4_addr source_v4; - - /** - * IPv6 source address for generated packets. - * - * Used when the balancer generates IPv6 packets such as: - * - ICMPv6 error responses - * - Health check probes (if implemented) - * - Other control plane traffic - */ - struct net6_addr source_v6; - - /** Number of IPv4 decapsulation endpoints in 'decap_v4' array */ - size_t decap_v4_count; - - /** - * Array of IPv4 addresses for tunnel decapsulation. - * - * Packets arriving with these destination addresses will be - * decapsulated (tunnel unwrapped) before load balancing. - * Useful for GRE, IPIP, or other tunnel protocols. - * - * Ownership: Caller allocates and manages this array - */ - struct net4_addr *decap_v4; - - /** Number of IPv6 decapsulation endpoints in 'decap_v6' array */ - size_t decap_v6_count; - - /** - * Array of IPv6 addresses for tunnel decapsulation. - * - * Packets arriving with these destination addresses will be - * decapsulated (tunnel unwrapped) before load balancing. - * Useful for GRE, IPIP, or other tunnel protocols. - * - * Ownership: Caller allocates and manages this array - */ - struct net6_addr *decap_v6; -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/api/inspect.h b/modules/balancer/controlplane/api/inspect.h deleted file mode 100644 index 8411476b3..000000000 --- a/modules/balancer/controlplane/api/inspect.h +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -#include "vs.h" - -#include - -//////////////////////////////////////////////////////////////////////////////// - -// TODO: docs -struct reals_usage { - uint64_t counters_usage; - uint64_t data_usage; - uint64_t total_usage; -}; - -// TODO: docs -struct vs_inspect { - uint64_t acl_usage; - uint64_t ring_usage; - uint64_t counters_usage; - struct reals_usage reals_usage; - uint64_t other_usage; - uint64_t total_usage; -}; - -// TODO: docs -struct named_vs_inspect { - struct vs_identifier identifier; - struct vs_inspect inspect; -}; - -// TODO: docs -struct packet_handler_vs_inspect { - uint64_t matcher_usage; - - uint64_t summary_vs_usage; - - size_t vs_count; - struct named_vs_inspect *vs_inspects; - - uint64_t announce_usage; - - uint64_t index_usage; - - uint64_t total_usage; -}; - -// TODO: docs -struct packet_handler_inspect { - struct packet_handler_vs_inspect vs_ipv4_inspect; - struct packet_handler_vs_inspect vs_ipv6_inspect; - - uint64_t summary_vs_usage; - - uint64_t vs_index_usage; - - uint64_t reals_index_usage; - - uint64_t counters_usage; - - uint64_t decap_usage; - - uint64_t total_usage; -}; - -struct state_inspect { - size_t session_table_usage; - size_t total_usage; -}; - -struct balancer_inspect { - struct packet_handler_inspect packet_handler_inspect; - struct state_inspect state_inspect; - size_t other_usage; - size_t total_usage; -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/api/meson.build b/modules/balancer/controlplane/api/meson.build deleted file mode 100644 index c3edbccff..000000000 --- a/modules/balancer/controlplane/api/meson.build +++ /dev/null @@ -1,28 +0,0 @@ -dependencies = [ - lib_common_dep, - lib_agent_cp_dep, - lib_balancer_state_dep, - lib_balancer_packet_handler_dep, - lib_filter_compiler_dep, -] - -includes = include_directories('.', '../') - -sources = files( - 'balancer.c', -) - -lib_balancer_cp = static_library( - 'balancer_cp', - sources, - c_args: yanet_c_args, - link_args: yanet_link_args, - dependencies: dependencies, - include_directories: includes, - install: false, -) - -lib_balancer_cp_dep = declare_dependency( - link_with: [lib_balancer_cp], -) - diff --git a/modules/balancer/controlplane/api/real.h b/modules/balancer/controlplane/api/real.h deleted file mode 100644 index eb323ac58..000000000 --- a/modules/balancer/controlplane/api/real.h +++ /dev/null @@ -1,264 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "common/network.h" -#include "vs.h" - -/** - * Maximum allowed scheduler weight for a real server. - */ -#define MAX_REAL_WEIGHT ((uint16_t)1024) - -/** - * Real server identifier within a virtual service context. - * - * Identifies a specific real server by its IP address and port, relative - * to its parent virtual service. This is the "relative" identifier because - * it doesn't include the VS information. - * - * PORT SEMANTICS: - * - Currently RESERVED FOR FUTURE USE - * - The actual destination port is determined by: - * * Standard mode (pure_l3=false): Uses the virtual service port - * * Pure L3 mode (pure_l3=true): Uses the client's original destination port - * - This field is reserved for future functionality where real servers - * might listen on different ports than the virtual service - */ -struct relative_real_identifier { - /** Real server IP address (IPv4 or IPv6) */ - struct net_addr addr; - - /** - * IP protocol version indicator. - * - * Values: - * - 0: IPPROTO_IP (IPv4) - * - 41: IPPROTO_IPV6 (IPv6) - * - * This is derived from the address type and used internally - * for protocol-specific processing. - */ - uint8_t ip_proto; - - /** - * Destination port on the real server. - * - * CURRENT STATUS: RESERVED FOR FUTURE USE - * - * The actual port used when forwarding to the real is currently - * determined by the virtual service configuration: - * - Standard mode: VS port is used - * - Pure L3 mode: Client's original destination port is preserved - * - * FUTURE USE: - * This field is reserved for port translation functionality where - * real servers could listen on different ports than the VS. - */ - uint16_t port; -}; - -/** - * Identifier of a real endpoint within a virtual service. - * - * Combines the parent VS identifier with address, transport protocol and port. - */ -struct real_identifier { - /** Parent virtual service identifier */ - struct vs_identifier vs_identifier; - - /** Identifier of real relative to its virtual service */ - struct relative_real_identifier relative; -}; - -/** - * Static configuration of a real server. - */ -struct real_config { - /** Source network/addresses used when sending to this real. */ - struct net src; - - /** Scheduler weight [0..MAX_REAL_WEIGHT] */ - uint16_t weight; -}; - -/** - * Sentinel value meaning "do not change weight" in real_update. - */ -#define DONT_UPDATE_REAL_WEIGHT ((uint16_t)-1) - -/** - * Sentinel value meaning "do not change enabled flag" in real_update. - */ -#define DONT_UPDATE_REAL_ENABLED ((uint8_t)-1) - -/** - * Partial update for a real server configuration. - * - * Use DONT_UPDATE_REAL_WEIGHT or DONT_UPDATE_REAL_ENABLED to skip fields. - */ -struct real_update { - /** Real key to update */ - struct real_identifier identifier; - - /** New weight (ignored if DONT_UPDATE_REAL_WEIGHT) */ - uint16_t weight; - - /** (ignored if DONT_UPDATE_REAL_ENABLED) */ - /** 0 = disabled, non-zero = enabled */ - uint8_t enabled; -}; - -/** - * Per-real-server statistics. - * - * Tracks packet processing and session creation for a specific real - * server within a virtual service. - */ -struct real_stats { - /** - * Packets for sessions assigned to this real when it was disabled. - * - * Incremented when: - * - A session exists for this real - * - The real is currently disabled - * - A packet arrives for that session - * - * This indicates packets that were dropped or rescheduled because - * the real was disabled after the session was created. - */ - uint64_t packets_real_disabled; - - /** - * One-Packet Scheduling packets sent without creating a session. - * - * Incremented when VS_OPS_FLAG is set and packets are forwarded - * to this real without session tracking. - */ - uint64_t ops_packets; - - /** - * ICMP error packets forwarded to this real server. - * - * Includes ICMP errors related to sessions assigned to this real, - * such as destination unreachable or time exceeded messages. - */ - uint64_t error_icmp_packets; - - /** - * Total number of new sessions created with this real as backend. - * - * Incremented each time a new session is created and this real - * is selected by the scheduler. Does not include OPS packets. - */ - uint64_t created_sessions; - - /** - * Total packets forwarded to this real server. - * - * Includes: - * - Regular session packets - * - OPS packets (if VS_OPS_FLAG is set) - * - ICMP error packets - */ - uint64_t packets; - - /** - * Total bytes forwarded to this real server. - * - * Includes all packet types (regular, OPS, ICMP). - * Measured at the IP layer (includes IP header and payload). - */ - uint64_t bytes; -}; - -/** - * Real server statistics with identifier. - * - * Associates statistics with a specific real server within a virtual - * service context. - */ -struct named_real_stats { - /** Real server identifier (relative to its VS) */ - struct relative_real_identifier real; - - /** Statistics for this real server */ - struct real_stats stats; -}; - -/** - * Real server runtime information. - * - * Provides runtime information about a specific real server including - * active session count and last activity timestamp. - */ -struct named_real_info { - /** Real server identifier (relative to its VS) */ - struct relative_real_identifier real; - - /** - * Timestamp of the last packet processed for this real server. - * - * Monotonic timestamp of when any packet - * was forwarded to this real. - * - * Updated in real-time by the dataplane when: - * - Packets are forwarded to the real - * - ICMP errors are forwarded to the real. - */ - uint32_t last_packet_timestamp; - - /** - * Number of active sessions currently assigned to this real server. - * - * This count represents sessions tracked by the balancer where - * this real was selected as the backend. Does not include - * OPS packets (no session tracking). - */ - size_t active_sessions; -}; - -/** - * Real server configuration with identifier. - * - * Associates configuration with a specific real server within a virtual - * service context. - */ -struct named_real_config { - /** Real server identifier (relative to its VS) */ - struct relative_real_identifier real; - - /** Configuration for this real server */ - struct real_config config; -}; - -/** - * Packet handler internal indices for a real server. - * - * Maps a real server to its position in the packet handler's internal - * data structures. Used for low-level operations that need direct access - * to packet handler arrays. - * - * USAGE: - * Primarily used internally by the manager layer to coordinate between - * high-level balancer API and low-level packet handler implementation. - */ -struct real_ph_index { - /** - * Index of the virtual service in the packet handler's VS array. - * - * This is the position of the parent VS in the - * packet_handler_config.vs array. - */ - size_t vs_idx; - - /** - * Index of the real within the virtual service's real array. - * - * This is the position of the real in the vs_config.reals array - * for the parent virtual service. - */ - size_t real_idx; -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/api/session.h b/modules/balancer/controlplane/api/session.h deleted file mode 100644 index 1cad14eb9..000000000 --- a/modules/balancer/controlplane/api/session.h +++ /dev/null @@ -1,126 +0,0 @@ -#pragma once - -#include "common/network.h" -#include "real.h" -#include - -/** - * Session timeout configuration per transport/state. - * - * Time values are expressed in seconds and are used to expire idle sessions - * depending on the last observed packet type. - */ -struct sessions_timeouts { - /** Timeout for sessions created/updated by TCP SYN-ACK */ - uint32_t tcp_syn_ack; - - /** Timeout for sessions created/updated by TCP SYN */ - uint32_t tcp_syn; - - /** Timeout for sessions updated by TCP FIN */ - uint32_t tcp_fin; - - /** Default timeout for TCP packets */ - uint32_t tcp; - - /** Default timeout for UDP packets */ - uint32_t udp; - - /** Fallback timeout for other/non-matching packets */ - uint32_t def; -}; - -/** - * Unique key that identifies a session between a client and a real. - * - * Consists of client address/port and the selected real endpoint. - */ -struct session_identifier { - /** Client source IP (IPv4/IPv6) */ - struct net_addr client_ip; - - /** Client source port (host byte order) */ - uint16_t client_port; - - /** Selected real endpoint */ - struct real_identifier real; -}; - -/** - * Runtime session metadata. - * - * All timestamps are monotonic time values. - */ -struct session_info { - /** Session creation time */ - uint32_t create_timestamp; - - /** Last packet time observed */ - uint32_t last_packet_timestamp; - - /** Current timeout applied to this session */ - uint32_t timeout; -}; - -/** - * Session information paired with its identifier. - * - * Combines the unique session key (client + real) with runtime metadata - * (timestamps, timeout). Used when enumerating active sessions. - */ -struct named_session_info { - /** Unique session identifier (client IP/port + real endpoint) */ - struct session_identifier identifier; - - /** Runtime session metadata (timestamps, timeout) */ - struct session_info info; -}; - -/** - * Container for a collection of active sessions. - * - * Holds a snapshot of all active sessions tracked by the balancer at a - * specific point in time. Used by balancer_sessions() to return session - * enumeration results. - * - * MEMORY MANAGEMENT: - * - The 'sessions' array is heap-allocated by balancer_sessions() - * - Caller must free with balancer_sessions_free() when done - * - Safe to call balancer_sessions_free() on partially-initialized structures - * - * USAGE PATTERN: - * ```c - * struct sessions sessions; - * balancer_sessions(handle, &sessions, now); - * for (size_t i = 0; i < sessions.sessions_count; i++) { - * // Process sessions.sessions[i] - * } - * balancer_sessions_free(&sessions); - * ``` - */ -struct sessions { - /** - * Number of active sessions in the 'sessions' array. - * - * This is the count of sessions that were active at the time - * balancer_sessions() was called. The count may change between - * calls as sessions are created and expire. - */ - size_t sessions_count; - - /** - * Array of active session information. - * - * Contains detailed information for each active session including: - * - Client IP address and port - * - Selected real server endpoint - * - Creation and last activity timestamps - * - Current timeout value - * - * OWNERSHIP: - * - Allocated by balancer_sessions() - * - Must be freed with balancer_sessions_free() - * - Array length is sessions_count - */ - struct named_session_info *sessions; -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/api/state.h b/modules/balancer/controlplane/api/state.h deleted file mode 100644 index 534af216c..000000000 --- a/modules/balancer/controlplane/api/state.h +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include -#include - -/** - * Session table state configuration. - * - * Controls the sizing of the session table used to track active connections - * between clients and real servers. The session table is a hash table that - * stores session state including client address, selected real server, and - * timeout information. - * - * MEMORY USAGE: - * Each session entry consumes approximately 64-128 bytes depending on the - * platform. The actual memory usage is: - * memory ≈ table_capacity * sizeof(session_entry) * (1 + overhead) - * where overhead accounts for hash table load factor and metadata. - * - * PERFORMANCE CONSIDERATIONS: - * - Larger capacity: Lower collision rate, faster lookups, more memory - * - Smaller capacity: Higher collision rate, slower lookups, less memory - * - Recommended load factor: 0.7-0.9 (70-90% full before resizing) - * - * AUTOMATIC RESIZING: - * When refresh_period is enabled and session_table_max_load_factor is set, - * the table automatically doubles in size when: - * (active_sessions / table_capacity) > max_load_factor - * - * SIZING GUIDELINES: - * - Expected sessions: Set capacity to expected_sessions / 0.75 - * - High-traffic: Start with 100K-1M capacity - * - Medium-traffic: Start with 10K-100K capacity - * - Low-traffic: Start with 1K-10K capacity - * - Enable auto-resize to handle traffic spikes - */ -struct state_config { - /** - * Maximum number of concurrent sessions the table can hold. - * - * This is the hash table size, not the maximum number of active - * sessions. Due to hash collisions and load factor considerations, - * the effective capacity is typically 70-90% of this value. - * - * CONSTRAINTS: - * - Must be > 0 - * - Should be a power of 2 for optimal hash distribution - * - Typical range: 1024 to 10,000,000 - * - * RESIZING: - * - Can be changed via balancer_resize_session_table() - * - Automatically doubled when load factor exceeds threshold - * - Resizing migrates existing sessions to new table - */ - size_t table_capacity; -}; diff --git a/modules/balancer/controlplane/api/stats.h b/modules/balancer/controlplane/api/stats.h deleted file mode 100644 index f0aa1a87f..000000000 --- a/modules/balancer/controlplane/api/stats.h +++ /dev/null @@ -1,190 +0,0 @@ -#pragma once - -#include -#include - -/** - * Module counters for L4 packets. - * - * Tracks packet processing through the L4 (TCP/UDP) load balancing path, - * including successful forwards and various failure conditions. - */ -struct balancer_l4_stats { - /** Total L4 packets received for processing */ - uint64_t incoming_packets; - - /** Packets that failed virtual service selection (no matching VS) */ - uint64_t select_vs_failed; - - /** Invalid or malformed packets that couldn't be processed */ - uint64_t invalid_packets; - - /** - * Packets that failed real server selection. - * - * Incremented when: - * - No real servers are configured for the VS - * - All real servers are disabled or have zero weight - */ - uint64_t select_real_failed; - - /** Packets successfully forwarded to a selected real server */ - uint64_t outgoing_packets; -}; - -/** - * Counters for ICMP packets of a specific IP version (IPv4 or IPv6). - * - * Tracks ICMP error handling including forwarding to real servers, - * echo reply generation, and various validation failures. - */ -struct balancer_icmp_stats { - /** Total ICMP packets received for processing */ - uint64_t incoming_packets; - - /** ICMP packets rejected due to source address policy */ - uint64_t src_not_allowed; - - /** ICMP echo replies generated by the balancer itself */ - uint64_t echo_responses; - - /** - * ICMP error packets with payload too short to extract IP header. - * - * ICMP errors should contain the original IP header that triggered - * the error. This counter tracks cases where the payload is truncated. - */ - uint64_t payload_too_short_ip; - - /** - * ICMP errors where original source doesn't match current destination. - * - * For proper ICMP error handling, the original packet's source - * should match the current packet's destination. This validates - * the error is related to traffic we forwarded. - */ - uint64_t unmatching_src_from_original; - - /** - * ICMP error packets with payload too short to extract port numbers. - * - * After extracting the IP header, we need the transport header - * (TCP/UDP) to get port numbers for session lookup. - */ - uint64_t payload_too_short_port; - - /** - * ICMP errors for unexpected transport protocols. - * - * The balancer only handles TCP and UDP. ICMP errors for other - * protocols (SCTP, DCCP, etc.) are counted here. - */ - uint64_t unexpected_transport; - - /** - * ICMP errors for unrecognized virtual services. - * - * The original packet's destination doesn't match any configured - * virtual service, so we can't determine where to forward the error. - */ - uint64_t unrecognized_vs; - - /** ICMP error packets successfully forwarded to real servers */ - uint64_t forwarded_packets; - - /** - * ICMP error packets broadcasted to peer balancers. - * - * In multi-balancer setups, ICMP errors may be cloned and sent - * to other balancer instances for session synchronization. - */ - uint64_t broadcasted_packets; - - /** Number of ICMP packet clones created and sent to peers */ - uint64_t packet_clones_sent; - - /** Number of ICMP packet clones received from peers */ - uint64_t packet_clones_received; - - /** Failures when attempting to create ICMP packet clones */ - uint64_t packet_clone_failures; -}; - -/** - * Total counts of incoming/outgoing packets and decapsulation results. - * - * Tracks overall pipeline statistics including total traffic volume - * and tunnel decapsulation operations. - */ -struct balancer_common_stats { - /** Total packets entering the balancer pipeline */ - uint64_t incoming_packets; - - /** Total bytes of incoming packets */ - uint64_t incoming_bytes; - - /** - * Packets with unsupported network protocol. - * - * The balancer supports IPv4 and IPv6. Other protocols - * (ARP, MPLS, etc.) are counted here and dropped. - */ - uint64_t unexpected_network_proto; - - /** Packets successfully decapsulated from tunnels */ - uint64_t decap_successful; - - /** - * Packets that failed tunnel decapsulation. - * - * Incremented when: - * - Tunnel header is malformed - * - Inner packet is invalid - */ - uint64_t decap_failed; - - /** Total packets exiting the balancer pipeline */ - uint64_t outgoing_packets; - - /** Total bytes of outgoing packets */ - uint64_t outgoing_bytes; -}; - -/** - * Aggregated statistics for the balancer. - * - * Contains per-module counters and optional per-service/per-real snapshots. - * Provides a comprehensive view of balancer performance and traffic patterns. - * - * MEMORY MANAGEMENT: - * - The 'vs' array is allocated by balancer_stats() - * - Caller must free with balancer_stats_free() when done - * - Safe to call balancer_stats_free() on partially-initialized structures - */ -struct balancer_stats { - /** L4 (TCP/UDP) module statistics */ - struct balancer_l4_stats l4; - - /** ICMP (IPv4) module statistics */ - struct balancer_icmp_stats icmp_ipv4; - - /** ICMP (IPv6) module statistics */ - struct balancer_icmp_stats icmp_ipv6; - - /** Common pipeline statistics (all traffic) */ - struct balancer_common_stats common; - - /** Number of virtual services in the 'vs' array */ - size_t vs_count; - - /** - * Array of per-virtual-service statistics. - * - * Contains detailed statistics for each configured virtual service, - * including per-real server counters. - * - * Ownership: Allocated by balancer_stats(), freed by - * balancer_stats_free() - */ - struct named_vs_stats *vs; -}; diff --git a/modules/balancer/controlplane/api/vs.h b/modules/balancer/controlplane/api/vs.h deleted file mode 100644 index 26d3d13c9..000000000 --- a/modules/balancer/controlplane/api/vs.h +++ /dev/null @@ -1,827 +0,0 @@ -#pragma once - -#include "common/network.h" - -#include -#include - -/** - * Virtual service feature flags. - * - * These flags control various aspects of virtual service behavior including - * encapsulation method, packet modifications, and routing mode. - */ - -/** - * Pure Layer 3 routing mode flag. - * - * When set, the virtual service matches ALL traffic with the specified IP - * address and transport protocol, regardless of destination port. - * - * BEHAVIOR: - * - Virtual service port MUST be 0 (configuration rejected otherwise) - * - Matches traffic to ANY port for the specified IP and protocol - * - Packets are forwarded to reals using the client's original destination port - * - No two pure L3 services can have the same (IP, protocol) combination - * - * USE CASES: - * - Load balancing all traffic to an IP regardless of port - * - Transparent proxying scenarios - * - When port-based routing is not needed - * - * STANDARD MODE (flag not set): - * - Virtual service port can be any valid value (1-65535) - * - Matches traffic to the specific (IP, port, protocol) combination - * - Packets are forwarded to reals using the virtual service port - */ -#define VS_PURE_L3_FLAG ((uint8_t)(1ull << 0)) - -/** - * Fix TCP MSS (Maximum Segment Size) option flag. - * - * When set, the balancer adjusts the TCP MSS option in SYN packets to - * account for encapsulation overhead (IPIP or GRE), preventing packet - * fragmentation. - * - * BEHAVIOR: - * - Inspects TCP SYN packets for MSS option - * - Reduces MSS value by encapsulation overhead: - * * IPIP: 20 bytes (IPv4) or 40 bytes (IPv6) - * * GRE: 24 bytes (IPv4) or 44 bytes (IPv6) - * - Ensures end-to-end MTU compatibility - * - * RECOMMENDATION: - * - Enable when using tunneling (IPIP or GRE) - * - Prevents fragmentation issues - * - Improves TCP performance - */ -#define VS_FIX_MSS_FLAG ((uint8_t)(1ull << 1)) - -/** - * Use GRE encapsulation flag. - * - * When set, packets are tunneled to real servers using GRE (Generic - * Routing Encapsulation) instead of IPIP (IP-in-IP). - * - * COMPARISON: - * - GRE: More flexible, can carry additional metadata, 4 extra bytes overhead - * - IPIP: Simpler, lower overhead, less flexible - * - * OVERHEAD: - * - GRE adds 24 bytes (IPv4) or 44 bytes (IPv6) to packet size - * - IPIP adds 20 bytes (IPv4) or 40 bytes (IPv6) to packet size - * - * RECOMMENDATION: - * - Use GRE when you need protocol flexibility - * - Use IPIP (flag not set) for lower overhead - */ -#define VS_GRE_FLAG ((uint8_t)(1ull << 2)) - -/** - * One Packet Scheduling (OPS) mode flag. - * - * When set, each packet is independently scheduled to a real server - * without creating or tracking sessions. This is useful for stateless - * protocols or when session tracking is not needed. - * - * BEHAVIOR WHEN SET: - * - No session table entries created - * - Each packet scheduled independently - * - Scheduler algorithm still applies (source_hash or round_robin) - * - Lower memory usage (no session state) - * - Lower CPU usage (no session lookups) - * - * BEHAVIOR WHEN NOT SET: - * - Sessions are created and tracked - * - All packets of a connection go to the same real server - * - Session table memory required - * - Session lookup overhead per packet - * - * USE CASES: - * - Stateless protocols (e.g., DNS queries) - * - When session affinity is not required - * - High packet rate, short-lived connections - * - Memory-constrained environments - * - * LIMITATIONS: - * - No session affinity (same client may hit different reals) - * - Cannot track connection state - * - May cause issues with stateful protocols - */ -#define VS_OPS_FLAG ((uint8_t)(1ull << 3)) - -/** - * Identifier of a virtual service. - * - * Uniquely identifies a load-balanced service by its network address, - * transport protocol, and destination port. This combination defines - * which traffic will be matched and load-balanced. - * - * PORT SEMANTICS: - * - Standard mode: port specifies the exact service port (1-65535) - * - Pure L3 mode (VS_PURE_L3_FLAG): port MUST be 0, matches all ports - */ -struct vs_identifier { - /** - * Virtual service IP address (IPv4 or IPv6). - * - * This is the address clients connect to. Traffic destined for - * this address will be load-balanced across real servers. - */ - struct net_addr addr; - - /** - * IP protocol version indicator. - * - * Values: - * - 0: IPPROTO_IP (IPv4) - * - 41: IPPROTO_IPV6 (IPv6) - * - * Derived from the address type and used for protocol-specific - * processing. - */ - uint8_t ip_proto; - - /** - * Destination port for the virtual service. - * - * STANDARD MODE (VS_PURE_L3_FLAG not set): - * - Valid range: 1-65535 - * - Matches traffic to this specific port - * - Forwarded packets use this port (unless real has port override) - * - * PURE L3 MODE (VS_PURE_L3_FLAG set): - * - MUST be 0 (configuration rejected otherwise) - * - Matches traffic to ANY port - * - Forwarded packets preserve client's original destination port - */ - uint16_t port; - - /** - * Transport layer protocol. - * - * Values: - * - 6: IPPROTO_TCP - * - 17: IPPROTO_UDP - * - * Determines which transport protocol traffic will be matched - * and how sessions are tracked (TCP state machine vs UDP timeout). - */ - uint8_t transport_proto; -}; - -/** - * Virtual service scheduler algorithm. - * - * Determines how new connections/flows are distributed across real servers. - * The scheduler runs when a new session is created or when OPS mode is used. - * - * WEIGHT CONSIDERATION: - * Both algorithms respect real server weights when making selections. - * Higher weight reals receive proportionally more traffic. - */ -enum vs_scheduler { - /** - * Source hash scheduling. - * - * Selects real server based on a hash of the client's source - * address and port. Provides stable, consistent routing where - * the same client always hits the same real server. - * - * CHARACTERISTICS: - * - Deterministic: Same client → same real - * - Session affinity across connections - * - Good for caching scenarios - * - Distribution depends on client diversity - * - * ALGORITHM: - * hash = hash(client_ip, client_port) - * real = weighted_selection(hash, reals, weights) - */ - source_hash = 0, - - /** - * Round-robin scheduling. - * - * Rotates through real servers for successive new flows, - * distributing load evenly regardless of client identity. - * - * CHARACTERISTICS: - * - Non-deterministic: Same client may hit different reals - * - Even distribution across reals - * - No session affinity across connections - * - Good for stateless services - * - * ALGORITHM: - * counter = atomic_increment(vs_counter) - * real = weighted_selection(counter, reals, weights) - */ - round_robin = 1, -}; - -/** - * Source port range for allowed_src filtering. - * - * Defines an inclusive range of source ports that are permitted for - * traffic matching a specific network prefix. Used in conjunction with - * allowed_src to provide fine-grained access control based on both - * source IP address and source port. - */ -struct ports_range { - /** - * Starting port of the range (inclusive). - * - * Valid range: 0-65535 - * Must be less than or equal to 'to' field. - */ - uint16_t from; - - /** - * Ending port of the range (inclusive). - * - * Valid range: 0-65535 - * Must be greater than or equal to 'from' field. - */ - uint16_t to; -}; - -/** - * Allowed source address and port configuration. - * - * Defines a network prefix and optional port ranges that are permitted - * to access a virtual service. When configured, only traffic from matching - * source addresses and ports will be accepted; all other traffic is dropped - * and counted in the packet_src_not_allowed counter. - * - * FILTERING BEHAVIOR: - * - If allowed_src array is empty (allowed_src_count = 0): All sources - * denied - * - If allowed_src contains entries: Only matching sources are permitted - * - Multiple allowed_src entries are evaluated with OR logic (any match allows) - * - * PORT FILTERING: - * - If port_ranges is NULL or port_ranges_count = 0: All source ports permitted - * - If port_ranges contains ranges: Only source ports within ranges permitted - * - Multiple port ranges are evaluated with OR logic (any match allows) - * - * EXAMPLES: - * 1. Allow all traffic from 10.0.0.0/8: - * net = {10.0.0.0, 255.0.0.0}, port_ranges = NULL, port_ranges_count = 0 - * - * 2. Allow only high ports from 192.168.0.0/16: - * net = {192.168.0.0, 255.255.0.0}, port_ranges = [{1024, 65535}], count = 1 - * - * 3. Allow specific ports from 172.16.0.0/12: - * net = {172.16.0.0, 255.240.0.0}, port_ranges = [{80, 80}, {443, 443}], - * count = 2 - */ -struct allowed_sources { - /** Number of networks in the nets array */ - size_t nets_count; - - /** - * Network prefixes (address and mask) for source filtering. - * - * Packets are matched against these networks using: - * (packet_src_ip & mask) == (net.addr & mask) - * - * Special cases: - * - 0.0.0.0/0.0.0.0 (IPv4) or ::/:: (IPv6): Matches all addresses - * - Single host: Use full mask (255.255.255.255 or all-ones for IPv6) - */ - struct net *nets; - - /** Number of port ranges in the port_ranges array */ - size_t port_ranges_count; - - /** - * Array of source port ranges for additional filtering. - * - * When NULL or port_ranges_count = 0: All source ports are permitted - * When specified: Only source ports within these ranges are permitted - * - * Common use cases: - * - Restrict to high ports: [{1024, 65535}] - * - Allow specific services: [{80, 80}, {443, 443}] - * - Custom application ranges: [{8000, 9000}] - * - * Ownership: Caller allocates and manages this array - */ - struct ports_range *port_ranges; - - /** - * Tag identifier for tracking allowed source statistics. - * - * When non-NULL, enables per-tag statistics tracking for packets - * matching this allowed source entry. Multiple allowed_sources entries - * can share the same tag to aggregate statistics across different - * network prefixes or port ranges. - * - * BEHAVIOR: - * - tag = NULL: No statistics tracking for this entry (default) - * - tag = "name": Track packets matching this entry under the specified - * tag - * - * STATISTICS: - * - Tracked in allowed_sources_stats array in named_vs_stats - * - Each unique tag gets its own statistics entry - * - Counts total packets that passed allowed source filtering - * - * USE CASES: - * - Track traffic from different customer networks separately - * - Monitor access patterns by source category - * - Aggregate statistics across multiple network ranges - * - Identify which allowed sources are actively used - * - * EXAMPLES: - * 1. Track internal vs external traffic: - * - Internal networks (10.0.0.0/8, 172.16.0.0/12): tag = "internal" - * - External networks (0.0.0.0/0): tag = "external" - * - * 2. Track per-customer traffic: - * - Customer A networks: tag = "customer_a" - * - Customer B networks: tag = "customer_b" - * - Customer C networks: tag = "customer_c" - * - * CONSTRAINTS: - * - Maximum tag length: 240 characters - * - Tags exceeding this limit will be rejected during configuration - * - * MEMORY MANAGEMENT: - * - Caller owns the string memory - * - String must remain valid for the lifetime of the configuration - * - Balancer does not free this pointer - */ - const char *tag; -}; - -struct named_real_config; - -/** - * Static configuration of a virtual service. - * - * Defines all parameters for a load-balanced service including behavior - * flags, scheduling algorithm, real server backends, and access control. - * - * MEMORY MANAGEMENT: - * - Caller allocates and manages all arrays (reals, allowed_src, peers) - * - Arrays must remain valid for the lifetime of the configuration - * - Use balancer_update_packet_handler() to apply changes - */ -struct vs_config { - /** - * Feature flags bitmask. - * - * Combination of VS_* flags controlling virtual service behavior: - * - VS_PURE_L3_FLAG: Match all ports, preserve client port - * - VS_FIX_MSS_FLAG: Adjust TCP MSS for tunnel overhead - * - VS_GRE_FLAG: Use GRE encapsulation instead of IPIP - * - VS_OPS_FLAG: One-packet scheduling, no session tracking - * - * Multiple flags can be combined with bitwise OR. - */ - uint8_t flags; - - /** - * Scheduling algorithm for new connections. - * - * Determines how new sessions/flows are distributed across - * real servers. See vs_scheduler enum for details. - */ - enum vs_scheduler scheduler; - - /** Number of real servers in the 'reals' array */ - size_t real_count; - - /** - * Array of real server configurations. - * - * Each entry defines a backend server including: - * - Server address and port - * - Weight for load distribution - * - Source address for forwarded packets - * - * REQUIREMENTS: - * - At least one real server must be configured - * - Array length must match real_count - * - * Ownership: Caller allocates and manages this array - */ - struct named_real_config *reals; - - /** Number of allowed source entries in the 'allowed_src' array */ - size_t allowed_src_count; - - /** - * Array of allowed source configurations for access control. - * - * When specified, only traffic from matching source addresses and ports - * will be accepted by this virtual service. Traffic from non-matching - * sources is dropped and counted in the packet_src_not_allowed counter. - * - * BEHAVIOR: - * - NULL or allowed_src_count = 0: All sources are denied (no traffic - * allowed) - * - Non-NULL with allowed_src_count > 0: Only matching sources - * permitted - * - * MATCHING LOGIC: - * For each incoming packet: - * 1. If allowed_src is NULL or count = 0 → DROP - * 2. For each allowed_src entry: - * a. Check if packet source IP matches the network prefix - * b. If port_ranges is NULL or count = 0 → ACCEPT (IP match - * sufficient) c. If port_ranges specified, check if source port matches - * any range d. If both IP and port match → ACCEPT - * 3. If no entry matches → DROP (increment packet_src_not_allowed) - * - * USE CASES: - * - Restrict access to trusted networks - * - Implement IP-based access control lists - * - Prevent unauthorized access to services - * - Combine with port filtering for fine-grained control - * - * Ownership: Caller allocates and manages this array - */ - struct allowed_sources *allowed_src; - - /** Number of IPv4 peer balancers in 'peers_v4' array */ - size_t peers_v4_count; - - /** - * IPv4 peer balancer addresses for ICMP coordination. - * - * In multi-balancer deployments, ICMP error packets may be - * broadcasted to peer balancers for proper error handling - * and session synchronization. - * - * BEHAVIOR: - * - ICMP errors are cloned and sent to all peers - * - Peers can forward errors to appropriate real servers - * - Enables distributed ICMP error handling - * - * Ownership: Caller allocates and manages this array - */ - struct net4_addr *peers_v4; - - /** Number of IPv6 peer balancers in 'peers_v6' array */ - size_t peers_v6_count; - - /** - * IPv6 peer balancer addresses for ICMP coordination. - * - * Same as peers_v4 but for IPv6 deployments. See peers_v4 - * documentation for behavior details. - * - * Ownership: Caller allocates and manages this array - */ - struct net6_addr *peers_v6; -}; - -/** - * Virtual service configuration paired with its identifier. - * - * Combines the unique identifier (address, port, protocol) with the - * complete configuration (flags, reals, scheduling, etc.) for a - * virtual service. - */ -struct named_vs_config { - /** Virtual service identifier (address, port, protocol) */ - struct vs_identifier identifier; - - /** Virtual service configuration (flags, reals, scheduling) */ - struct vs_config config; -}; - -/** - * Per-virtual-service runtime counters. - * - * Tracks packet processing statistics for a specific virtual service, - * including successful forwards, various failure conditions, and - * session management metrics. - */ -struct vs_stats { - /** Total packets received matching this virtual service */ - uint64_t incoming_packets; - - /** Total bytes received matching this virtual service (IP layer) */ - uint64_t incoming_bytes; - - /** - * Packets dropped due to source address not in allowlist. - * - * Incremented when: - * - vs_config.allowed_src is configured (not NULL) - * - Client source address doesn't match any allowed range - * - Packet is dropped before scheduling - */ - uint64_t packet_src_not_allowed; - - /** - * Packets that failed real server selection. - * - * Incremented when: - * - No real servers are configured - * - All real servers are disabled - * - All real servers have zero weight - * - Scheduler cannot select a valid real - */ - uint64_t no_reals; - - /** - * One-Packet Scheduling packets sent without session creation. - * - * Incremented when: - * - VS_OPS_FLAG is set - * - Packet is forwarded to a real - * - No session table entry is created - * - * This counter tracks stateless packet forwarding. - */ - uint64_t ops_packets; - - /** - * Session creation failures due to table capacity. - * - * Incremented when: - * - Session table is full (at capacity) - * - New session cannot be allocated - * - Packet is dropped - * - * MITIGATION: - * - Increase session table capacity - * - Enable auto-resize with appropriate max_load_factor - * - Review session timeout configuration - */ - uint64_t session_table_overflow; - - /** - * ICMP echo request/reply packets processed. - * - * Tracks ICMP echo (ping) packets that matched this virtual - * service and were handled by the balancer. - */ - uint64_t echo_icmp_packets; - - /** - * ICMP error packets forwarded to real servers. - * - * Tracks ICMP errors (destination unreachable, time exceeded, - * etc.) that were matched to sessions and forwarded to the - * appropriate real server. - */ - uint64_t error_icmp_packets; - - /** - * Packets for sessions where the real server is disabled. - * - * Incremented when: - * - Session exists for a specific real - * - That real is currently disabled - * - Packet arrives for the session - * - * These packets are typically dropped or rescheduled depending - * on configuration. - */ - uint64_t real_is_disabled; - - /** - * Packets for sessions where the real server was removed. - * - * Incremented when: - * - Session exists for a specific real - * - That real is no longer in the configuration - * - Packet arrives for the session - * - * This can occur after configuration updates that remove reals. - * Sessions are eventually cleaned up by timeout. - */ - uint64_t real_is_removed; - - /** - * Packets that couldn't be rescheduled. - * - * Incremented when: - * - No existing session found - * - Packet doesn't start a new session (e.g., TCP non-SYN) - * - Packet is dropped - * - * Common for: - * - TCP packets without SYN flag when no session exists - * - Packets arriving after session timeout - */ - uint64_t not_rescheduled_packets; - - /** - * ICMP packets broadcasted to peer balancers. - * - * Incremented when: - * - ICMP error has this VS as source - * - Packet is cloned and sent to configured peers - * - Used for distributed ICMP error handling - * - * Requires vs_config.peers_v4 or peers_v6 to be configured. - */ - uint64_t broadcasted_icmp_packets; - - /** - * Total sessions created for this virtual service. - * - * Tracks the cumulative number of sessions created since - * the balancer started or statistics were reset. Does not - * include OPS packets (which don't create sessions). - */ - uint64_t created_sessions; - - /** Packets successfully forwarded to real servers */ - uint64_t outgoing_packets; - - /** Bytes successfully forwarded to real servers (IP layer) */ - uint64_t outgoing_bytes; -}; - -/** - * Statistics for packets matching allowed source entries with a specific tag. - * - * Tracks the number of packets that passed allowed source filtering for - * entries with a specific tag value. Multiple allowed_sources entries can - * share the same tag, and their statistics are aggregated together. - * - * AGGREGATION: - * - All allowed_sources entries with the same non-NULL tag share one stats - * entry - * - Statistics are cumulative across all matching entries - * - Only non-NULL tags generate statistics entries - * - * LIFECYCLE: - * - Created when first packet matches an allowed source with this tag - * - Persists until virtual service is reconfigured or removed - * - Reset when statistics are cleared - */ -struct allowed_sources_stats { - /** - * Tag identifier matching allowed_sources.tag. - * - * This corresponds to the tag field in allowed_sources entries. - * All entries with this tag contribute to these statistics. - * - * MEMORY MANAGEMENT: - * - This is a heap-allocated copy of the original tag string - * - Must be freed by caller (typically via balancer_stats_free()) - */ - const char *tag; - - /** - * Total packets that passed allowed source filtering for this tag. - * - * Incremented when: - * - Packet source IP matches an allowed_sources network prefix - * - Packet source port matches allowed port ranges (if specified) - * - The matching allowed_sources entry has this tag value - * - Packet proceeds to scheduling (not dropped by other checks) - * - * This counter helps identify: - * - Which allowed source categories are actively used - * - Traffic volume from different source groups - * - Effectiveness of access control policies - */ - uint64_t passes; -}; - -/** - * Virtual service statistics with identifier. - * - * Associates statistics with a specific virtual service and includes - * per-real statistics for all reals backing this VS. - * - * MEMORY MANAGEMENT: - * - The 'reals' array is heap-allocated - * - Must be freed by caller (typically via balancer_stats_free()) - */ -struct named_vs_stats { - /** Virtual service identifier */ - struct vs_identifier identifier; - - /** Statistics for this virtual service */ - struct vs_stats stats; - - /** Number of real servers in the 'reals' array */ - size_t reals_count; - - /** - * Per-real statistics for all reals backing this virtual service. - * - * Array length matches reals_count. Order corresponds to the - * configuration order of reals in the virtual service. - */ - struct named_real_stats *reals; - - /** - * Number of allowed source statistics entries. - * - * This is the count of unique non-NULL tags across all allowed_sources - * entries in the virtual service configuration. Each unique tag gets - * one statistics entry. - * - * RELATIONSHIP TO CONFIG: - * - allowed_sources_count <= vs_config.allowed_src_count - * - Only non-NULL tags are counted - * - Duplicate tags share one statistics entry - * - * EXAMPLES: - * - Config has 3 allowed_src entries with tags ["a", "b", "a"] → count - * = 2 - * - Config has 2 allowed_src entries with tags [NULL, NULL] → count = 0 - * - Config has 4 allowed_src entries with tags ["a", "b", "c", "d"] → - * count = 4 - */ - size_t allowed_sources_count; - - /** - * Per-tag statistics for allowed source filtering. - * - * Array of statistics entries, one per unique non-zero tag in the - * virtual service's allowed_sources configuration. Tracks how many - * packets passed filtering for each tag category. - * - * ARRAY PROPERTIES: - * - Length matches allowed_sources_count - * - Heap-allocated, must be freed by caller - * - NULL if allowed_sources_count = 0 (no tagged entries) - * - Entries are not guaranteed to be in any particular order - * - * USE CASES: - * - Monitor traffic from different source categories - * - Validate access control effectiveness - * - Identify unused allowed source entries - * - Track customer or network-specific traffic volumes - * - * MEMORY MANAGEMENT: - * - Allocated by balancer_show_stats() - * - Must be freed by caller (typically via balancer_stats_free()) - */ - struct allowed_sources_stats *allowed_sources; -}; - -/** - * Virtual service runtime information with identifier. - * - * Provides runtime information about a specific virtual service including - * active session count, last activity, and per-real information. - * - * MEMORY MANAGEMENT: - * - The 'reals' array is heap-allocated - * - Must be freed by caller (typically via balancer_info_free()) - * - * DATA FRESHNESS: - * - Session counts updated during periodic refresh (if enabled) - * - May lag behind actual current state by up to refresh_period - * - last_packet_timestamp updated in real-time by dataplane - */ -struct named_vs_info { - /** Virtual service identifier */ - struct vs_identifier identifier; - - /** - * Timestamp of the last packet processed for this virtual service. - * - * Monotonic timestamp (seconds since boot) of when any packet - * matched this virtual service. Updated in real-time by the - * dataplane. - * - * Useful for: - * - Detecting inactive services - * - Monitoring traffic patterns - * - Identifying stale configurations - */ - uint32_t last_packet_timestamp; - - /** - * Number of active sessions for this virtual service. - * - * This is the sum of active sessions across all real servers - * backing this virtual service. - * - * UPDATE FREQUENCY: - * - Updated asynchronously during periodic refresh - * - Controlled by StateConfig.refresh_period - * - May lag behind actual state by up to refresh_period - * - * NOTE: Represents sessions tracked by the balancer, not - * necessarily all active connections to real servers (which may - * have additional direct connections). - */ - size_t active_sessions; - - /** Number of real servers in the 'reals' array */ - size_t reals_count; - - /** - * Runtime information for each real server backing this VS. - * - * Provides per-real session counts and activity timestamps. - * Array length matches reals_count. Order corresponds to the - * configuration order of reals in the virtual service. - */ - struct named_real_info *reals; -}; diff --git a/modules/balancer/controlplane/balancer.go b/modules/balancer/controlplane/balancer.go new file mode 100644 index 000000000..723c7dc3a --- /dev/null +++ b/modules/balancer/controlplane/balancer.go @@ -0,0 +1,676 @@ +// Package balancer implements the balancer control plane. +package balancer + +import ( + "context" + "fmt" + "slices" + "sync" + "time" + + "github.com/yanet-platform/yanet2/common/commonpb" + "github.com/yanet-platform/yanet2/common/go/relptr" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" +) + +// Protocol constants matching C netinet/in.h values. +const ( + ipprotoIP = 0 // IPv4 + ipprotoIPv6 = 41 // IPv6 + ipprotoTCP = 6 + ipprotoUDP = 17 +) + +var errNoAgentMemory = CodedErrorf(codes.ResourceExhausted, "no agent memory") + +// Balancer manages a single balancer instance including its shared-memory +// packet handler, session table, and lookup indices. +type Balancer struct { + // Pointer to the packet handler instance in the shared memory. + // It uses relative pointers, so one can use relptr package to access it. + handler *PacketHandler + agent *Agent + + realUpdateBuffer []*balancerpb.RealUpdate + + // vsIndex maps virtual service identities to their positions in the + // shared-memory arrays. Rebuilt after every config update via buildIndexes. + vsIndex map[vsKey]vsSlot + + // Last applied config for diffing on Update. + config *balancerpb.BalancerConfig + + log *zap.SugaredLogger + + refresher *Refresher +} + +func (b *Balancer) startRefreshing(mu *sync.Mutex) { + if b.refresher != nil { + b.refresher.Stop() + } + b.refresher = NewRefresher(b, mu) + b.refresher.Run(context.Background()) +} + +// nullifyReusedFields clears pointers to resources that were reused (via relptr.Equate) +// by the new handler. Both old and new handlers share these resources; nullifying them +// on the old handler prevents Free from double-freeing the shared resources. +// Must be called before freeing the old handler. +func (b *Balancer) nullifyReusedFields( + newHandler *PacketHandler, + reuseReport *balancerpb.ReuseReport, +) { + handler := b.handler + services := relptr.Slice(&handler.Vs, handler.Vs_count) + for _, vsReuse := range reuseReport.VsReuseReports { + key := makeVsKey(vsReuse.VsIdentifier) + slot, ok := b.vsIndex[key] + if !ok { + continue + } + if vsReuse.AclReused { + relptr.Set(&services[slot.index].Acl, nil) + } + if vsReuse.SelectorReused { + relptr.Set(&services[slot.index].Selector, nil) + } + } + if reuseReport.Ipv4VsMatcherReused { + relptr.Set(&handler.Ipv4_vs_matcher, nil) + } + if reuseReport.Ipv6VsMatcherReused { + relptr.Set(&handler.Ipv6_vs_matcher, nil) + } + if reuseReport.Ipv4DecapFilterReused { + relptr.Set(&handler.Decap_ipv4_filter, nil) + } + if reuseReport.Ipv6DecapFilterReused { + relptr.Set(&handler.Decap_ipv6_filter, nil) + } + + // Nullify tracker_shards on old reals that were inherited by new reals. + // real.populate copies tracker_shards via relptr.Equate, creating shared + // ownership. Without nullifying the old side, vs.free would free shared memory. + nullifySharedTrackerShards(handler, newHandler) +} + +// nullifySharedTrackerShards clears tracker_shards on old reals whose shards +// were inherited by the corresponding new real. +func nullifySharedTrackerShards(oldHandler, newHandler *PacketHandler) { + oldVsList := relptr.Slice(&oldHandler.Vs, oldHandler.Vs_count) + newVsList := relptr.Slice(&newHandler.Vs, newHandler.Vs_count) + for i := range min(len(oldVsList), len(newVsList)) { + oldReals := relptr.Slice(&oldVsList[i].Reals, oldVsList[i].Reals_count) + newReals := relptr.Slice(&newVsList[i].Reals, newVsList[i].Reals_count) + for j := range min(len(oldReals), len(newReals)) { + oldTracker := relptr.Deref(&oldReals[j].Tracker_shards) + newTracker := relptr.Deref(&newReals[j].Tracker_shards) + if oldTracker != nil && oldTracker == newTracker { + relptr.Set(&oldReals[j].Tracker_shards, nil) + } + } + } +} + +func (b *Balancer) buildIndexes() { + services := relptr.Slice(&b.handler.Vs, b.handler.Vs_count) + b.vsIndex = make(map[vsKey]vsSlot, b.handler.Vs_count) + + for idx := range services { + vs := &services[idx] + if vs.isRemoved() { + continue + } + slot := vsSlot{ + index: idx, + realSlots: make(map[realKey]int, int(vs.Reals_count)), + } + + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + for realIdx := range reals { + rl := &reals[realIdx] + if rl.isRemoved() { + continue + } + slot.realSlots[rl.key()] = realIdx + } + + b.vsIndex[vs.key()] = slot + } +} + +func (b *Balancer) Config() *balancerpb.BalancerConfig { + return b.config +} + +// SessionTableCapacity returns the current session table capacity. +func (b *Balancer) SessionTableCapacity() uint64 { + st := relptr.Deref(&b.handler.Session_table) + if st == nil { + return 0 + } + return uint64(st.capacity()) +} + +// BufferedRealUpdates returns the currently buffered real updates. +func (b *Balancer) BufferedRealUpdates() []*balancerpb.RealUpdate { + return b.realUpdateBuffer +} + +// Update applies a new configuration to the balancer. If config.PacketHandler is nil, +// only the state (session table, WLC params) is updated in-place. Otherwise a new +// packet handler is built, installed, and the old one is freed. +func (b *Balancer) Update( + config *balancerpb.BalancerConfig, + now *time.Time, +) (*balancerpb.ReuseReport, error) { + st := relptr.Deref(&b.handler.Session_table) + + if now != nil && config.State.SessionTableCapacity != nil { + newStSize := int(*config.State.SessionTableCapacity) + if err := b.handler.resizeSessionTable(st, newStSize, *now); err != nil { + return nil, AsStatus(fmt.Errorf("resize session table: %w", err), codes.Internal) + } + } + + mergedStateConfig := mergeStateConfig(b.config.State, config.State) + + b.handler.setState(mergedStateConfig, st) + + if b.refresher != nil { + b.refresher.UpdateRefreshPeriod(mergedStateConfig.RefreshPeriod.AsDuration()) + } + + if config.PacketHandler == nil { + return nil, nil + } + + if err := validatePacketHandlerConfig(config.PacketHandler); err != nil { + return nil, AsStatus( + fmt.Errorf("invalid packet handler config: %w", err), + codes.InvalidArgument, + ) + } + + handler, reuseReport, err := NewPacketHandler( + config, + b.handler.name(), + relptr.Deref(&b.handler.Session_table), + b.agent, + b.handler, + ) + if err != nil { + return nil, AsStatus(fmt.Errorf("create packet handler: %w", err), codes.Internal) + } + + if err := b.agent.install(handler); err != nil { + handler.free(b.agent) + yanet.Free(b.agent.AsYanetAgent(), handler) + return nil, AsStatus(fmt.Errorf("install handler: %w", err), codes.Internal) + } + + // The ordering below is critical: + // 1. Nullify reused fields on the OLD handler so Free won't double-free shared resources. + // 2. Forget the old handler (frees its slot in the agent's handler table). + // 3. Register the new handler (takes the freed slot — cannot fail after Forget). + // 4. Free the old handler's remaining (non-reused) resources. + b.nullifyReusedFields(handler, reuseReport) + + b.agent.forget(b.handler) + if err := b.agent.register(handler); err != nil { + panic("register after forget should never fail: agent slot was just freed") + } + + b.handler.free(b.agent) + yanet.Free(b.agent.AsYanetAgent(), b.handler) + + b.handler = handler + b.config = config + + b.buildIndexes() + + return reuseReport, nil +} + +// Stable index encoding: +// A stable index is a uint64 that uniquely identifies a VS or real across config updates. +// High 32 bits = epoch (incremented when a slot is reused by a different entity). +// Low 32 bits = config index (position in the allocated array). +// When an entity keeps its position across an update, it inherits the same stable index. +// When a new entity occupies a previously-used slot, the epoch is bumped. +func makeStableIdx(epoch uint32, configIndex uint32) uint64 { + return uint64(epoch)<<32 | uint64(configIndex) +} + +func epochOf(stableIdx uint64) uint32 { + return uint32(stableIdx >> 32) +} + +func configIndexOf(stableIdx uint64) uint32 { + return uint32(stableIdx & 0xFFFFFFFF) +} + +// NewBalancer creates a new balancer instance from the given config. +// +// It allocates a packet handler and session table in shared memory, +// populates all fields, compiles filters, registers counters, and +// installs the handler into the dataplane. +// +// On any failure, all allocated resources are freed via Destroy. +func NewBalancer( + agent *Agent, + name string, + config *balancerpb.BalancerConfig, + log *zap.SugaredLogger, +) (*Balancer, error) { + if err := validateBalancerConfig(config); err != nil { + return nil, AsStatus(fmt.Errorf("invalid config: %w", err), codes.InvalidArgument) + } + + stateConfig := config.State + + st := agent.createSessionTable(int(*stateConfig.SessionTableCapacity)) + if st == nil { + return nil, errNoAgentMemory + } + + handler, _, err := NewPacketHandler(config, name, st, agent, nil) + if err != nil { + agent.destroySessionTable(st) + return nil, AsStatus(fmt.Errorf("create handler: %w", err), codes.Internal) + } + + // From this point on, handler is properly initialized and Destroy + // can be called safely for cleanup on any subsequent failure. + + b := &Balancer{ + handler: handler, + agent: agent, + config: config, + log: log, + } + + // Register handler in agent storage, then install into dataplane. + if err := agent.register(handler); err != nil { + b.Destroy() + return nil, AsStatus(fmt.Errorf("register handler: %w", err), codes.Internal) + } + + if err := agent.install(handler); err != nil { + agent.forget(handler) + b.Destroy() + return nil, AsStatus(fmt.Errorf("install handler: %w", err), codes.Internal) + } + + b.buildIndexes() + + return b, nil +} + +func (b *Balancer) Destroy() { + if b.refresher != nil { + b.refresher.Stop() + } + handler := b.handler + agent := b.agent + if handler.Session_table != nil { + agent.destroySessionTable(relptr.Deref(&handler.Session_table)) + } + handler.free(agent) + yanet.Free(agent.AsYanetAgent(), handler) +} + +func (b *Balancer) UpdateVS( + vsList []*balancerpb.VirtualService, +) (*balancerpb.ReuseReport, error) { + currentVs := b.config.PacketHandler.Vs + + vsMap := make(map[vsKey]int, len(currentVs)) + for idx, vs := range currentVs { + vsMap[makeVsKey(vs.Id)] = idx + } + + newVsList := slices.Clone(currentVs) + for _, vs := range vsList { + key := makeVsKey(vs.Id) + if idx, ok := vsMap[key]; ok { + newVsList[idx] = vs + } else { + newVsList = append(newVsList, vs) + } + } + + config := proto.Clone(b.config).(*balancerpb.BalancerConfig) + config.PacketHandler.Vs = newVsList + + return b.Update(config, nil) +} + +func (b *Balancer) DeleteVS( + vsList []*balancerpb.VirtualService, +) (*balancerpb.ReuseReport, error) { + for _, vs := range vsList { + k := makeVsKey(vs.Id) + if _, ok := b.vsIndex[k]; !ok { + return nil, CodedErrorf( + codes.NotFound, + "virtual service %s not found", vsIDToString(vs.Id), + ) + } + } + + deletedVs := make(map[vsKey]struct{}) + for _, vs := range vsList { + k := makeVsKey(vs.Id) + deletedVs[k] = struct{}{} + } + + newVsList := make([]*balancerpb.VirtualService, 0, len(b.config.PacketHandler.Vs)) + for _, vs := range b.config.PacketHandler.Vs { + k := makeVsKey(vs.Id) + if _, ok := deletedVs[k]; !ok { + newVsList = append(newVsList, vs) + } + } + + config := proto.Clone(b.config).(*balancerpb.BalancerConfig) + config.PacketHandler.Vs = newVsList + + return b.Update(config, nil) +} + +func (b *Balancer) GetState( + handlerRef *balancerpb.PacketHandlerRef, + filter *balancerpb.Filter, + includeCounters bool, + now time.Time, +) ([]*balancerpb.BalancerState, error) { + if err := validateFilter(filter); err != nil { + return nil, err + } + + matcher := newFilterMatcher(filter) + dpConfig := b.agent.AsYanetAgent().DPConfig() + workers := dpConfig.WorkerCount() + balancerName := b.handler.name() + + if !includeCounters { + // No counters means no need in module positions. + state := b.buildState(workers, &matcher, nil, now) + matcher.filterReals(state) + compactBalancerState(state) + return []*balancerpb.BalancerState{state}, nil + } + + var results []*balancerpb.BalancerState + + for position := range dpConfig.AllModulePositions("balancer") { + if position.ModuleName != balancerName { + continue + } + if !matchesHandlerRef(handlerRef, &position) { + continue + } + + state := b.buildState(workers, &matcher, &position, now) + b.applyCounters(state, dpConfig, &position) + matcher.filterReals(state) + compactBalancerState(state) + results = append(results, state) + } + + return results, nil +} + +func (b *Balancer) buildState( + workers uint32, + matcher *filterMatcher, + position *yanet.ModuleReference, + now time.Time, +) *balancerpb.BalancerState { + services := relptr.Slice(&b.handler.Vs, b.handler.Vs_count) + + state := &balancerpb.BalancerState{ + BalancerName: b.handler.name(), + L4Stats: &balancerpb.L4Stats{}, + CommonStats: &balancerpb.CommonStats{}, + IcmpIpv4Stats: &balancerpb.IcmpStats{}, + IcmpIpv6Stats: &balancerpb.IcmpStats{}, + SourceIpv4: append([]byte(nil), b.handler.Source_v4.Bytes[:]...), + SourceIpv6: append([]byte(nil), b.handler.Source_v6.Bytes[:]...), + DecapAddresses: restoreDecapAddrs(b.handler), + } + if position != nil { + state.Ref = &balancerpb.PacketHandlerRef{ + Device: &position.Device, + Pipeline: &position.Pipeline, + Function: &position.Function, + Chain: &position.Chain, + } + } + + state.VirtualServices = make([]*balancerpb.VsState, len(services)) + for vsIdx := range services { + vs := &services[vsIdx] + if vs.isRemoved() { + continue + } + if matcher.hasVsFilter && !matcher.matchVsID(vs.id()) { + continue + } + state.VirtualServices[vsIdx] = vs.state(workers, now) + vsState := state.VirtualServices[vsIdx] + state.ActiveSessions += vsState.ActiveSessions + if vsState.LastPacketTimestamp != nil { + if state.LastPacketTimestamp == nil || + vsState.LastPacketTimestamp.Seconds > state.LastPacketTimestamp.Seconds { + state.LastPacketTimestamp = vsState.LastPacketTimestamp + } + } + } + + return state +} + +func (b *Balancer) applyCounters( + state *balancerpb.BalancerState, + dpConfig *yanet.DPConfig, + position *yanet.ModuleReference, +) { + counters := dpConfig.ModuleCounters( + position.Device, + position.Pipeline, + position.Function, + position.Chain, + "balancer", + b.handler.name(), + []string{}, + ) + for _, counter := range counters { + applyCounter(b.handler, state, counter) + } +} + +func matchesHandlerRef(ref *balancerpb.PacketHandlerRef, pos *yanet.ModuleReference) bool { + if ref == nil { + return true + } + if ref.Device != nil && *ref.Device != pos.Device { + return false + } + if ref.Pipeline != nil && *ref.Pipeline != pos.Pipeline { + return false + } + if ref.Function != nil && *ref.Function != pos.Function { + return false + } + if ref.Chain != nil && *ref.Chain != pos.Chain { + return false + } + return true +} + +func (m *filterMatcher) filterReals(state *balancerpb.BalancerState) { + if !m.hasRealFilter { + return + } + for _, vsState := range state.VirtualServices { + if vsState == nil { + continue + } + for realIdx, realState := range vsState.Reals { + if realState != nil && !m.matchRealID(realState.Id) { + vsState.Reals[realIdx] = nil + } + } + } +} + +func (b *Balancer) FlushRealUpdates() (int, error) { + if len(b.realUpdateBuffer) == 0 { + return 0, nil + } + + updates := b.realUpdateBuffer + + updatesApplied, err := b.UpdateReals(updates, false) + if err != nil { + return 0, fmt.Errorf("failed to update reals: %w", err) + } + + b.realUpdateBuffer = nil + + return updatesApplied, nil +} + +func (b *Balancer) UpdateReals(updates []*balancerpb.RealUpdate, buffer bool) (int, error) { + if buffer { + b.realUpdateBuffer = append(b.realUpdateBuffer, updates...) + return len(updates), nil + } + + services := relptr.Slice(&b.handler.Vs, b.handler.Vs_count) + affectedVs := make(map[vsKey]int) + + for updateIdx, update := range updates { + serviceKey := makeVsKey(update.RealId.Vs) + serviceSlot, ok := b.vsIndex[serviceKey] + if !ok { + return 0, CodedErrorf( + codes.NotFound, + "real update at index %d: virtual service %s not found", + updateIdx, vsIDToString(update.RealId.Vs), + ) + } + realSlot, ok := serviceSlot.realSlots[makeRealKey(update.RealId.Real)] + if !ok { + return 0, CodedErrorf( + codes.NotFound, + "real update at index %d: real %s not found", + updateIdx, realIDToString(update.RealId.Real), + ) + } + vs := &services[serviceSlot.index] + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + r := &reals[realSlot] + if update.Weight != nil { + r.Weight = *update.Weight + } + if update.Enable != nil { + if *update.Enable { + r.Flags |= RealFlagEnabled + } else { + r.Flags &^= uint8(RealFlagEnabled) + } + } + affectedVs[serviceKey] = serviceSlot.index + } + + for _, idx := range affectedVs { + vs := &services[idx] + if err := vs.updateRealSelector(&b.handler.Rcu, b.agent); err != nil { + return 0, AsStatus( + fmt.Errorf("failed to update ring for virtual service %s: %w", vs, err), + codes.Internal, + ) + } + } + + return len(updates), nil +} + +func mergeStateConfig(old, update *balancerpb.StateConfig) *balancerpb.StateConfig { + result := &balancerpb.StateConfig{ + SessionTableCapacity: old.SessionTableCapacity, + SessionTableMaxLoadFactor: old.SessionTableMaxLoadFactor, + Wlc: old.Wlc, + RefreshPeriod: old.RefreshPeriod, + } + if update == nil { + return result + } + if update.SessionTableCapacity != nil { + result.SessionTableCapacity = update.SessionTableCapacity + } + if update.SessionTableMaxLoadFactor != nil { + result.SessionTableMaxLoadFactor = update.SessionTableMaxLoadFactor + } + if update.Wlc != nil { + mergedWlc := &balancerpb.WlcConfig{ + Power: old.Wlc.Power, + MaxWeight: old.Wlc.MaxWeight, + } + if update.Wlc.Power != nil { + mergedWlc.Power = update.Wlc.Power + } + if update.Wlc.MaxWeight != nil { + mergedWlc.MaxWeight = update.Wlc.MaxWeight + } + result.Wlc = mergedWlc + } + if update.RefreshPeriod != nil { + result.RefreshPeriod = update.RefreshPeriod + } + return result +} + +// Metrics reads dataplane counters for all positions where this balancer is +// installed and returns a flat slice of commonpb.Metric. +func (b *Balancer) Metrics(now time.Time) ([]*commonpb.Metric, error) { + dpConfig := b.agent.AsYanetAgent().DPConfig() + balancerName := b.handler.name() + workers := dpConfig.WorkerCount() + services := relptr.Slice(&b.handler.Vs, b.handler.Vs_count) + + var result []*commonpb.Metric + + for position := range dpConfig.AllModulePositions("balancer") { + if position.ModuleName != balancerName { + continue + } + + refLabels := []*commonpb.Label{ + {Name: "device", Value: position.Device}, + {Name: "pipeline", Value: position.Pipeline}, + {Name: "function", Value: position.Function}, + {Name: "chain", Value: position.Chain}, + {Name: "config", Value: balancerName}, + } + + counters := dpConfig.ModuleCounters( + position.Device, position.Pipeline, + position.Function, position.Chain, + "balancer", balancerName, []string{}, + ) + + result = append(result, collectCounterMetrics(services, counters, refLabels)...) + result = append(result, collectSessionMetrics(services, workers, now, refLabels)...) + } + + return result, nil +} diff --git a/modules/balancer/agent/balancerpb/balancer.proto b/modules/balancer/controlplane/balancerpb/balancer.proto similarity index 58% rename from modules/balancer/agent/balancerpb/balancer.proto rename to modules/balancer/controlplane/balancerpb/balancer.proto index c4f103c81..6ab7af5e9 100644 --- a/modules/balancer/agent/balancerpb/balancer.proto +++ b/modules/balancer/controlplane/balancerpb/balancer.proto @@ -2,21 +2,20 @@ syntax = "proto3"; package balancerpb; -import "modules/balancer/agent/balancerpb/graph.proto"; -import "modules/balancer/agent/balancerpb/info.proto"; -import "modules/balancer/agent/balancerpb/inspect.proto"; -import "modules/balancer/agent/balancerpb/module.proto"; -import "modules/balancer/agent/balancerpb/stats.proto"; +import "modules/balancer/controlplane/balancerpb/common.proto"; +import "modules/balancer/controlplane/balancerpb/filter.proto"; +import "modules/balancer/controlplane/balancerpb/memory.proto"; +import "modules/balancer/controlplane/balancerpb/state.proto"; import "common/commonpb/metric.proto"; -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; +option go_package = "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb;balancerpb"; -// Balancer agent service. +// Balancer service. // // Provides management operations for load balancer instances including // configuration, real server updates, statistics, and runtime information. -service BalancerService { +service Balancer { // Create or update a balancer configuration. // // This RPC operates in two distinct modes: @@ -31,16 +30,16 @@ service BalancerService { // UPDATE MODE (balancer exists): // - Updates existing balancer configuration // - Non-optional fields in PacketHandlerConfig MUST be provided: - // * vs (virtual services list) - // * source_address_v4 - // * source_address_v6 - // * decap_addresses - // * sessions_timeouts + // * vs (virtual services list) + // * source_address_v4 + // * source_address_v6 + // * decap_addresses + // * sessions_timeouts // - Optional fields in StateConfig (only updated if specified): - // * session_table_capacity - // * session_table_max_load_factor - // * adjust_weights_config - // * refresh_period (if set to 0, disables periodic refresh) + // * session_table_capacity + // * session_table_max_load_factor + // * adjust_weights_config + // * refresh_period (if set to 0, disables periodic refresh) // - Session table may be resized if capacity changes // - Background tasks are restarted with new configuration // @@ -50,7 +49,28 @@ service BalancerService { // - No conflicting virtual services allowed // // Returns error if validation fails or operation cannot be completed. - rpc UpdateConfig(UpdateConfigRequest) returns (UpdateConfigResponse); + rpc SetConfig(SetConfigRequest) returns (SetConfigResponse); + + // List all balancers. + // + // Returns a list of all balancer instances currently managed + // by the controlplane + rpc ListBalancers(ListBalancersRequest) returns (ListBalancersResponse); + + rpc GetConfig(GetConfigRequest) returns (GetConfigResponse); + + rpc GetState(GetStateRequest) returns (GetStateResponse); + + rpc GetMetrics(GetMetricsRequest) returns (GetMetricsResponse); + + rpc GetMemoryUsage(GetMemoryUsageRequest) returns (GetMemoryUsageResponse); + + // List active sessions. + // + // Returns detailed information about all active sessions tracked + // by the balancer, including client addresses, virtual service + // mappings, real server assignments, and timeout information. + rpc ListSessions(ListSessionsRequest) returns (stream Session); // Update real server properties. // @@ -60,7 +80,7 @@ service BalancerService { // Supports two modes: // - Immediate: Updates are applied immediately (buffer=false) // - Buffered: Updates queued and applied atomically on flush - // (buffer=true) + // (buffer=true) // // Buffered mode is useful for applying multiple updates atomically // to avoid intermediate inconsistent states. @@ -73,62 +93,7 @@ service BalancerService { // intermediate states where only some updates are applied. // // Returns the number of updates that were flushed. - rpc FlushRealUpdates(FlushRealUpdatesRequest) returns (FlushRealUpdatesResponse); - - // Show balancer configuration. - // - // Returns the current configuration for the specified balancer, - // including any buffered real server updates that haven't been - // flushed yet. - rpc ShowConfig(ShowConfigRequest) returns (ShowConfigResponse); - - // List all balancer configurations. - // - // Returns a list of all balancer instances currently managed - // by the control plane, along with their configurations. - rpc ListConfigs(ListConfigsRequest) returns (ListConfigsResponse); - - // Show balancer statistics. - // - // Returns packet processing statistics for the balancer. - // Statistics depend on the packet handler's topological position - // in the processing pipeline. - // - // Optional filtering by device, pipeline, function, and chain - // allows retrieving statistics for specific processing contexts. - rpc ShowStats(ShowStatsRequest) returns (ShowStatsResponse); - - // Show balancer runtime information. - // - // Returns runtime information including active session counts, - // last packet timestamps, and per-virtual-service/per-real statistics. - // - // Session counts are updated asynchronously during periodic refresh, - // so they may not reflect the exact current state. - rpc ShowInfo(ShowInfoRequest) returns (ShowInfoResponse); - - // Show balancer topology graph. - // - // Returns a snapshot of the complete balancer topology showing all - // virtual services and their real servers with current operational - // states (effective weights, enabled status). - rpc ShowGraph(ShowGraphRequest) returns (ShowGraphResponse); - - // Show memory usage inspection. - // - // Returns agent-level memory usage information including memory limit, - // current usage, and detailed memory inspection for all balancer - // instances. - rpc ShowInspect(ShowInspectRequest) returns (ShowInspectResponse); - - // Show active sessions. - // - // Returns detailed information about all active sessions tracked - // by the balancer, including client addresses, virtual service - // mappings, real server assignments, and timeout information. - rpc ShowSessions(ShowSessionsRequest) returns (ShowSessionsResponse); - - rpc GetMetrics(GetMetricsRequest) returns (GetMetricsResponse); + rpc FlushReals(FlushRealsRequest) returns (FlushRealsResponse); // Update specific virtual services. // @@ -181,7 +146,7 @@ message GetMetricsResponse { /////////////////////////////////////////////////////////////////////////////// // Request to create or update a balancer configuration. -message UpdateConfigRequest { +message SetConfigRequest { // Balancer instance name (required). // // Unique identifier for this balancer instance. @@ -196,39 +161,13 @@ message UpdateConfigRequest { BalancerConfig config = 2; } -// Response to UpdateConfig request. -// -// Contains the balancer name and metadata about filter reuse during the update. -message UpdateConfigResponse { - string name = 1; - - // Metadata about filter reuse during configuration update. - // - // Provides visibility into which filters were reused (not recompiled) - // vs which were recompiled. This helps understand the impact of - // configuration changes and identify optimization opportunities. - UpdateInfo update_info = 2; +message VsReuseReport { + VsIdentifier vs_identifier = 1; + bool acl_reused = 2; + bool selector_reused = 3; } -// Metadata about filter reuse during configuration update. -// -// Tracks which packet processing filters were reused from the previous -// configuration vs which needed to be recompiled. Filter compilation -// can be expensive, so reusing filters when possible improves update -// performance. -message UpdateInfo { - // Whether the configuration was newly created (true) or updated - // (false). - // - // When true: This is a new balancer configuration that was just - // created. - // Filter reuse fields will be empty/false since there was no - // previous configuration to reuse from. - // When false: This is an update to an existing configuration. - // Filter reuse fields indicate what was reused from the - // previous configuration. - bool created = 1; - +message ReuseReport { // Whether the IPv4 virtual service matcher was reused. // // The VS matcher is the filter that maps incoming packets to @@ -238,7 +177,7 @@ message UpdateInfo { // false: IPv4 VS matcher was recompiled // // Note: Always false when created=true (no previous config to reuse). - bool vs_ipv4_matcher_reused = 2; + bool ipv4_vs_matcher_reused = 1; // Whether the IPv6 virtual service matcher was reused. // @@ -246,19 +185,24 @@ message UpdateInfo { // false: IPv6 VS matcher was recompiled // // Note: Always false when created=true (no previous config to reuse). - bool vs_ipv6_matcher_reused = 3; + bool ipv6_vs_matcher_reused = 2; - // List of virtual services that reused their ACL filters. - // - // ACL (Access Control List) filters enforce allowed_src rules - // for each virtual service. When a VS's allowed_src configuration - // hasn't changed, its ACL filter can be reused. - // - // This list contains the identifiers of VSs where ACL was reused. - // VSs not in this list had their ACL filters recompiled. - // - // Note: Always empty when created=true (no previous config to reuse). - repeated VsIdentifier vs_acl_reuses = 4; + bool ipv4_decap_filter_reused = 3; + + bool ipv6_decap_filter_reused = 4; + + repeated VsReuseReport vs_reuse_reports = 5; +} + +// Response to UpdateConfig request. +// +// Contains the balancer name and metadata about filter reuse during the update. +message SetConfigResponse { + string name = 1; + optional ReuseReport reuse = 2; + + // As resize can be done to the bigger size, it returns the new capacity + uint64 session_table_capacity = 3; } /////////////////////////////////////////////////////////////////////////////// @@ -311,29 +255,29 @@ message UpdateRealsResponse { /////////////////////////////////////////////////////////////////////////////// -// Request to flush buffered real server updates. -message FlushRealUpdatesRequest { +// Request to flush buffered real updates. +message FlushRealsRequest { // Balancer instance name (optional, auto-selects if only one exists) optional string name = 1; } // Response to FlushRealUpdates request. -message FlushRealUpdatesResponse { +message FlushRealsResponse { string name = 1; // Number of buffered updates that were flushed and applied - uint32 updates_flushed = 2; + uint64 updates_flushed = 2; } /////////////////////////////////////////////////////////////////////////////// // Request to show balancer configuration. -message ShowConfigRequest { +message GetConfigRequest { // Balancer instance name (optional, auto-selects if only one exists) optional string name = 1; } // Response containing balancer configuration. -message ShowConfigResponse { +message GetConfigResponse { // Balancer instance name string name = 1; @@ -348,80 +292,24 @@ message ShowConfigResponse { } // Request to list all balancer configurations. -message ListConfigsRequest {} +message ListBalancersRequest {} // Response containing list of balancer names. -message ListConfigsResponse { +message ListBalancersResponse { // Names of all balancer instances - repeated string configs = 1; -} - -/////////////////////////////////////////////////////////////////////////////// - -// Request to show balancer runtime information. -message ShowInfoRequest { - // Balancer instance name (optional, auto-selects if only one exists) - optional string name = 1; + repeated string names = 1; } -// Response containing balancer runtime information. -message ShowInfoResponse { - // Balancer instance name - string name = 1; - - // Runtime information including session counts and timestamps. - // - // Session counts are updated asynchronously during periodic refresh, - // so they may lag behind the actual current state. - BalancerInfo info = 2; -} - -// Request to show balancer statistics. -// -// If name is specified, returns stats only for that balancer instance. -// If name is not specified, returns stats for all balancer instances. -// -// PacketHandlerRef fields are treated as filters: -// - if device is specified: only positions for that device are included -// - same for pipeline, function, chain. -message ShowStatsRequest { - // Balancer instance name (optional). - optional string name = 1; - - // Packet handler reference filter - // (optional fields inside PacketHandlerRef). - PacketHandlerRef ref = 2; -} - -// A single stats entry bound to a specific balancer instance name and packet -// handler position. -message StatsEntry { - // Balancer instance name (manager/module config name). - string name = 1; - - // Packet handler reference identifying the dataplane position. - PacketHandlerRef ref = 2; - - // Packet processing statistics for this position. - BalancerStats stats = 3; -} - -// Response containing balancer statistics for one or more positions. -message ShowStatsResponse { - // List of stats entries. Each entry includes (name, ref, stats). - repeated StatsEntry entries = 1; -} - -/////////////////////////////////////////////////////////////////////////////// - // Request to show active sessions. -message ShowSessionsRequest { +message ListSessionsRequest { // Balancer instance name (optional, auto-selects if only one exists) optional string name = 1; + + optional Filter filter = 2; } // Response containing active session information. -message ShowSessionsResponse { +message ListSessionsResponse { // Balancer instance name string name = 1; @@ -429,32 +317,19 @@ message ShowSessionsResponse { // // Each session includes client information, virtual service mapping, // real server assignment, creation time, last packet time, and timeout. - repeated SessionInfo sessions = 2; + repeated Session sessions = 2; } -message ShowGraphRequest { - // Balancer instance name (optional, auto-selects if only one exists) - optional string name = 1; -} - -message ShowGraphResponse { - // Balancer instance name - string name = 1; - - Graph graph = 2; -} - -/////////////////////////////////////////////////////////////////////////////// - // Request to show memory usage inspection. -message ShowInspectRequest { - // Empty - returns agent-level inspect with all balancers +message GetMemoryUsageRequest { + optional string name = 1; + optional Filter filter = 2; } // Response containing memory usage inspection. -message ShowInspectResponse { +message GetMemoryUsageResponse { // Agent-level memory inspection including all balancers - AgentInspect inspect = 1; + AgentMemoryUsage memory_usage = 1; } // Request to update specific virtual services. @@ -474,7 +349,7 @@ message UpdateVSRequest { // Each VS must include a complete configuration. VSs with matching // identifiers in the current config will be replaced; new identifiers // will be added. VSs not in this list remain unchanged. - repeated VirtualService vs = 2; + repeated VirtualService services = 2; } // Response to UpdateVS request. @@ -483,15 +358,7 @@ message UpdateVSRequest { message UpdateVSResponse { // Balancer instance name that was updated. string name = 1; - - // Metadata about filter reuse during the update. - // - // The vs_acl_reuses list only contains virtual services that were - // included in the update request. This allows tracking which ACL - // filters were reused vs recompiled specifically for the updated VSs. - // - // Note: created field will always be false for UpdateVS operations. - UpdateInfo info = 2; + ReuseReport reuse = 2; } // Request to delete specific virtual services. @@ -511,7 +378,7 @@ message DeleteVSRequest { // Only the VS identifiers (id field) are used for matching; other // fields in the VirtualService message are ignored. VSs with matching // identifiers will be removed from the configuration. - repeated VirtualService vs = 2; + repeated VirtualService services = 2; } // Response to DeleteVS request. @@ -520,12 +387,17 @@ message DeleteVSRequest { message DeleteVSResponse { // Balancer instance name that was updated. string name = 1; + ReuseReport reuse = 2; +} - // Metadata about filter reuse during the update. - // - // The vs_acl_reuses list will always be empty for delete operations, - // since deleted VSs don't have ACL filters to reuse. - // - // Note: created field will always be false for DeleteVS operations. - UpdateInfo info = 2; -} \ No newline at end of file +message GetStateRequest { + optional string name = 1; + + optional PacketHandlerRef packet_handler_ref = 2; + + optional Filter filter = 3; + + bool include_counters = 4; +} + +message GetStateResponse { repeated BalancerState state = 1; } \ No newline at end of file diff --git a/modules/balancer/agent/balancerpb/module.proto b/modules/balancer/controlplane/balancerpb/common.proto similarity index 78% rename from modules/balancer/agent/balancerpb/module.proto rename to modules/balancer/controlplane/balancerpb/common.proto index 960d00c20..978932f5f 100644 --- a/modules/balancer/agent/balancerpb/module.proto +++ b/modules/balancer/controlplane/balancerpb/common.proto @@ -3,34 +3,11 @@ syntax = "proto3"; package balancerpb; import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; +import "common/filterpb/filter.proto"; -// IP address representation supporting both IPv4 and IPv6. -// -// The bytes field contains the raw IP address bytes: -// - IPv4: 4 bytes -// - IPv6: 16 bytes -message Addr { bytes bytes = 1; } - -// Network prefix representation. -// -// Represents a network address with a prefix length (CIDR notation). -// For example: 192.168.1.0/24 or 2001:db8::/32 -message Net { - // Network address (IPv4 or IPv6) - Addr addr = 1; - - // Network mask (IPv4 or IPv6). - // - // Defines which bits of the address are significant for matching. - // For example: - // - 255.255.255.0 (IPv4 /24): Match first 24 bits - // - ffff:ffff:ffff:: (IPv6 /48): Match first 48 bits - // - 0.0.0.0 (IPv4 /0): Match all addresses (any) - // - :: (IPv6 /0): Match all addresses (any) - Addr mask = 2; -} +option go_package = "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb;balancerpb"; // Transport layer protocol. // @@ -61,14 +38,14 @@ enum VsScheduler { // - Real server weights affect the hash space distribution // - Higher weight reals receive proportionally more hash buckets // - A real with weight 2 gets approximately twice the traffic - // of weight 1 + // of weight 1 // // Use cases: // - Stateful applications requiring client stickiness // - Applications with client-side caching // - When consistent routing is more important than perfect load // distribution - SOURCE_HASH = 0; + SH = 0; // Weighted Round Robin Scheduler. // @@ -86,7 +63,9 @@ enum VsScheduler { // - Stateless applications // - When even load distribution is critical // - Applications that don't require session affinity - ROUND_ROBIN = 1; + WRR = 1; + + WLC = 2; } // Virtual service unique identifier. @@ -101,7 +80,7 @@ enum VsScheduler { // and transport protocol, regardless of destination port. message VsIdentifier { // Virtual service IP address (IPv4 or IPv6) - Addr addr = 1; + bytes addr = 1; // Virtual service port number. // @@ -119,26 +98,6 @@ message VsIdentifier { TransportProto proto = 3; } -// Source port range for allowed_src filtering. -// -// Defines an inclusive range of source ports that are permitted for -// traffic matching a specific network prefix. Used in conjunction with -// AllowedSrc to provide fine-grained access control based on both -// source IP address and source port. -message PortsRange { - // Starting port of the range (inclusive). - // - // Valid range: 0-65535 - // Must be less than or equal to 'to' field. - uint32 from = 1; - - // Ending port of the range (inclusive). - // - // Valid range: 0-65535 - // Must be greater than or equal to 'from' field. - uint32 to = 2; -} - // Allowed source address and port configuration. // // Defines a network prefix and optional port ranges that are permitted @@ -158,25 +117,25 @@ message PortsRange { // // EXAMPLES: // 1. Allow all traffic from 10.0.0.0/8: -// net = 10.0.0.0/255.0.0.0, ports = [] +// net = 10.0.0.0/255.0.0.0, ports = [] // // 2. Allow only high ports from 192.168.0.0/16: -// net = 192.168.0.0/255.255.0.0, ports = [{from: 1024, to: 65535}] +// net = 192.168.0.0/255.255.0.0, ports = [{from: 1024, to: 65535}] // // 3. Allow specific ports from 172.16.0.0/12: -// net = 172.16.0.0/255.240.0.0, ports = [{from: 80, to: 80}, {from: 443, to: -// 443}] +// net = 172.16.0.0/255.240.0.0, ports = [{from: 80, to: 80}, {from: 443, to: +// 443}] message AllowedSources { // Network prefixes (address and mask) for source filtering. // // Packets are matched against these networks using: - // (packet_src_ip & mask) == (net.addr & mask) + // (packet_src_ip & mask) == (net.addr & mask) // // Special cases: // - 0.0.0.0/0.0.0.0 (IPv4) or ::/:: (IPv6): Matches all addresses // - Single host: Use full mask (255.255.255.255 or // ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff) - repeated Net nets = 1; + repeated filterpb.IPNet nets = 1; // Optional source port ranges for additional filtering. // @@ -187,7 +146,7 @@ message AllowedSources { // - Restrict to high ports: [{from: 1024, to: 65535}] // - Allow specific services: [{from: 80, to: 80}, {from: 443, to: 443}] // - Custom application ranges: [{from: 8000, to: 9000}] - repeated PortsRange ports = 2; + repeated filterpb.PortRange ports = 2; // Tag identifier for tracking allowed source statistics. // @@ -219,18 +178,18 @@ message AllowedSources { // // EXAMPLES: // 1. Track internal vs external traffic: - // - Internal networks (10.0.0.0/8, 172.16.0.0/12): tag = "internal" - // - External networks (0.0.0.0/0): tag = "external" + // - Internal networks (10.0.0.0/8, 172.16.0.0/12): tag = "internal" + // - External networks (0.0.0.0/0): tag = "external" // // 2. Track per-customer traffic: - // - Customer A networks: tag = "customer_a" - // - Customer B networks: tag = "customer_b" - // - Customer C networks: tag = "customer_c" + // - Customer A networks: tag = "customer_a" + // - Customer B networks: tag = "customer_b" + // - Customer C networks: tag = "customer_c" // // 3. Aggregate multiple ranges under one tag: - // - Entry 1: nets=[10.0.0.0/8], tag="internal" - // - Entry 2: nets=[172.16.0.0/12], tag="internal" - // - Both entries contribute to the same statistics counter + // - Entry 1: nets=[10.0.0.0/8], tag="internal" + // - Entry 2: nets=[172.16.0.0/12], tag="internal" + // - Both entries contribute to the same statistics counter optional string tag = 3; } @@ -256,16 +215,16 @@ message VirtualService { // BEHAVIOR: // - Empty list: All sources are denied (no traffic allowed) // - Non-empty list: Only sources matching at least one AllowedSrc entry - // are permitted + // are permitted // // MATCHING LOGIC: // For each incoming packet: // 1. If allowed_srcs is empty → DROP // 2. For each AllowedSrc entry: - // a. Check if packet source IP matches the network prefix - // b. If ports list is empty → ACCEPT (IP match is sufficient) - // c. If ports list is non-empty, check if source port matches any - // range d. If both IP and port match → ACCEPT + // a. Check if packet source IP matches the network prefix + // b. If ports list is empty → ACCEPT (IP match is sufficient) + // c. If ports list is non-empty, check if source port matches any + // range d. If both IP and port match → ACCEPT // 3. If no AllowedSrc entry matches → DROP (increment // packet_src_not_allowed) // @@ -290,7 +249,7 @@ message VirtualService { // Lists the IP addresses of other balancer instances serving the same // virtual service. Used for ICMP error message broadcasting and // coordinated load balancing. Does not include this balancer's address. - repeated Addr peers = 6; + repeated bytes peers = 6; } // Real server identifier within a virtual service. @@ -299,7 +258,7 @@ message VirtualService { // Note: The port field is currently reserved for future use. message RelativeRealIdentifier { // Real server IP address (IPv4 or IPv6) - Addr ip = 1; + bytes ip = 1; // Real server port (RESERVED FOR FUTURE USE). // @@ -343,18 +302,15 @@ message Real { // adjusted based on active session counts. uint32 weight = 2; - // Source address for encapsulated packets. + // Source address and mask for encapsulated packets. // // When packets are tunneled to the real server (via IPIP or GRE), // the source address is calculated as: - // src = (src_addr & src_mask) | (original_src & (~src_mask)) + // src = (src_addr & src_mask) | (original_src & (~src_mask)) // // This allows preserving parts of the original source address while // setting specific bits for routing or identification purposes. - Addr src_addr = 3; - - // Source address mask for encapsulation - Addr src_mask = 4; + filterpb.IPNet src = 3; } // Virtual service feature flags. @@ -365,7 +321,7 @@ message VsFlags { // Use GRE encapsulation instead of IPIP. // // When true: Packets are tunneled to real servers using GRE - // (Generic Routing Encapsulation) + // (Generic Routing Encapsulation) // When false: Packets are tunneled using IPIP (IP-in-IP) // // GRE provides more flexibility and can carry additional metadata, @@ -411,44 +367,6 @@ message VsFlags { // - Transparent proxying scenarios // - When port-based routing is not needed bool pure_l3 = 4; - - // Enable Weighted Least Connection (WLC) dynamic weight adjustment. - // - // When true, the balancer dynamically adjusts real server weights - // based on active session counts to balance load more evenly. - // - // REQUIREMENTS: - // To enable WLC, ALL of the following must be configured: - // 1. Set this flag to true - // 2. Set StateConfig.refresh_period to a non-zero value (e.g., 1s-60s) - // 3. Set StateConfig.session_table_max_load_factor (e.g., 0.7-0.9) - // 4. Configure StateConfig.wlc with power and max_weight values - // - // If any of these requirements are missing, configuration will be - // rejected. - // - // BEHAVIOR: - // - During each refresh cycle (every refresh_period): - // * Scans session table to count active sessions per real - // * Calculates new effective weights using WLC algorithm - // * Updates real weights in the dataplane (state only, not config) - // - Formula: effective_weight = weight * max(1.0, power * (1.0 - - // ratio)) - // where ratio = - // (real_sessions * total_weight) / (total_sessions * - // real_weight) - // - Effective weights are capped at wlc.max_weight - // - Original config weights are preserved for future calculations - // - // COMPATIBILITY: - // - Works with both SOURCE_HASH and ROUND_ROBIN schedulers - // - The adjustment is independent of the scheduling algorithm - // - Helps prevent overloading when session durations vary significantly - // - // PERFORMANCE IMPACT: - // - Adds overhead during refresh cycle (scan sessions + weight updates) - // - Impact depends on session count and refresh_period - bool wlc = 5; } // Real server update operation. @@ -464,7 +382,7 @@ message RealUpdate { // // When true: Real server is enabled and receives traffic // When false: Real server is disabled and receives no new sessions - // (existing sessions may continue) + // (existing sessions may continue) // When not specified: Enabled state is not changed optional bool enable = 2; @@ -503,9 +421,6 @@ message SessionsTimeouts { // Timeout for UDP sessions (seconds) // UDP is connectionless, so this defines inactivity timeout uint32 udp = 5; - - // Default timeout for other cases (seconds) - uint32 default = 6; } // Top-level balancer configuration. @@ -551,7 +466,7 @@ message PacketHandlerConfig { // // Used as the outer source address when tunneling packets to // real servers via IPIP or GRE (for IPv4 destinations). - Addr source_address_v4 = 2; + bytes source_address_v4 = 2; // Source IPv6 address for encapsulated packets. // @@ -560,7 +475,7 @@ message PacketHandlerConfig { // // Used as the outer source address when tunneling packets to // real servers via IPIP or GRE (for IPv6 destinations). - Addr source_address_v6 = 3; + bytes source_address_v6 = 3; // Decapsulation addresses. // @@ -571,7 +486,7 @@ message PacketHandlerConfig { // from this list, it attempts to decapsulate it (remove the // outer IP header). Used for return traffic in DSR setups or // multi-tier load balancing. - repeated Addr decap_addresses = 4; + repeated bytes decap_addresses = 4; // Session timeout configuration. // @@ -606,23 +521,23 @@ message PacketHandlerRef { // ALGORITHM OVERVIEW: // For each real server in a WLC-enabled virtual service: // 1. Calculate connections ratio: -// ratio = (real_sessions * total_weight) / (total_sessions * real_weight) -// where: -// - real_sessions: Active sessions on this real -// - total_sessions: Active sessions across all enabled reals in the VS -// - real_weight: Original configured weight of this real -// - total_weight: Sum of original weights of all enabled reals +// ratio = (real_sessions * total_weight) / (total_sessions * real_weight) +// where: +// - real_sessions: Active sessions on this real +// - total_sessions: Active sessions across all enabled reals in the VS +// - real_weight: Original configured weight of this real +// - total_weight: Sum of original weights of all enabled reals // // 2. Calculate WLC adjustment factor: -// wlc_factor = max(1.0, power * (1.0 - ratio)) -// - If ratio < 1.0: Real has fewer sessions than expected → increase weight -// - If ratio > 1.0: -// Real has more sessions than expected → keep weight -// at 1.0x -// - power parameter controls adjustment aggressiveness +// wlc_factor = max(1.0, power * (1.0 - ratio)) +// - If ratio < 1.0: Real has fewer sessions than expected → increase weight +// - If ratio > 1.0: +// Real has more sessions than expected → keep weight +// at 1.0x +// - power parameter controls adjustment aggressiveness // // 3. Calculate effective weight: -// effective_weight = min(real_weight * wlc_factor, max_weight) +// effective_weight = min(real_weight * wlc_factor, max_weight) // // CONFIGURATION REQUIREMENTS: // Both power and max_weight must be specified if any VS has wlc flag enabled. @@ -640,13 +555,13 @@ message WlcConfig { // - Value of 0: Disables WLC (effective_weight = weight) // // FORMULA IMPACT: - // effective_weight = weight * max(1.0, power * (1.0 - - // connections_ratio)) + // effective_weight = weight * max(1.0, power * (1.0 - + // connections_ratio)) // // EXAMPLES: // - power=2: If a real has 50% fewer sessions than expected // (ratio=0.5), - // its weight increases by 2 * (1.0 - 0.5) = 1.0x → doubles + // its weight increases by 2 * (1.0 - 0.5) = 1.0x → doubles // - power=4: Same scenario → weight increases by 2.0x → triples // - power=1: Same scenario → weight increases by 0.5x → 1.5x original // @@ -743,23 +658,23 @@ message StateConfig { // BEHAVIOR WHEN NON-ZERO: // The balancer performs these operations every refresh_period: // 1. Session Table Scan: - // - Counts active sessions per virtual service and per real server - // - Updates BalancerInfo statistics + // - Counts active sessions per virtual service and per real server + // - Updates BalancerInfo statistics // (active_sessions, last_packet_timestamp) // // 2. Automatic Session Table Resizing: - // - Calculates load factor: active_sessions / table_capacity - // - If load_factor > session_table_max_load_factor: - // * Doubles the session table capacity - // * Migrates existing sessions to new table - // * Prevents session table overflow + // - Calculates load factor: active_sessions / table_capacity + // - If load_factor > session_table_max_load_factor: + // * Doubles the session table capacity + // * Migrates existing sessions to new table + // * Prevents session table overflow // // 3. WLC Weight Adjustment (if any VS has wlc flag enabled): - // - For each VS with wlc=true: - // * Calculates new effective weights based on session distribution - // * Updates real server weights in dataplane (state only) - // * Preserves original config weights for future calculations - // - Uses wlc.power and wlc.max_weight parameters + // - For each VS with wlc=true: + // * Calculates new effective weights based on session distribution + // * Updates real server weights in dataplane (state only) + // * Preserves original config weights for future calculations + // - Uses wlc.power and wlc.max_weight parameters // // BEHAVIOR WHEN ZERO: // - Disables all periodic refresh operations @@ -783,4 +698,47 @@ message StateConfig { // - Stable/predictable traffic: 30s-60s // - Static configuration (no WLC, no auto-resize): 0 (disabled) optional google.protobuf.Duration refresh_period = 4; +} + +message Session { + // Client source IP address (IPv4 or IPv6) + bytes client_addr = 1; + + // Client source port number + uint32 client_port = 2; + + // Virtual service this session is associated with. + // + // Identifies which virtual service the client connected to. + VsIdentifier vs_id = 3; + + // Real server this session is assigned to. + // Relative to the virtual service. + RelativeRealIdentifier real_id = 4; + + // Session creation timestamp. + // + // When the first packet of this session was processed and + // the session was created in the session table. + google.protobuf.Timestamp create_timestamp = 5; + + // Last packet timestamp. + // + // When the most recent packet for this session was processed. + // Used to calculate session age and determine if timeout has expired. + google.protobuf.Timestamp last_packet_timestamp = 6; + + // Session timeout duration. + // + // How long the session remains active without receiving packets. + // The timeout value depends on the protocol and TCP state: + // - TCP SYN: sessions_timeouts.tcp_syn + // - TCP SYN-ACK: sessions_timeouts.tcp_syn_ack + // - TCP FIN: sessions_timeouts.tcp_fin + // - TCP established: sessions_timeouts.tcp + // - UDP: sessions_timeouts.udp + // + // If (current_time - last_packet_timestamp) > timeout, the session + // is removed from the session table during the next cleanup cycle. + google.protobuf.Duration timeout = 7; } \ No newline at end of file diff --git a/modules/balancer/controlplane/balancerpb/filter.proto b/modules/balancer/controlplane/balancerpb/filter.proto new file mode 100644 index 000000000..27b4131f4 --- /dev/null +++ b/modules/balancer/controlplane/balancerpb/filter.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package balancerpb; + +import "modules/balancer/controlplane/balancerpb/common.proto"; + +option go_package = "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb;balancerpb"; + +message Filter { + optional bytes vip = 1; + optional uint32 vs_port = 2; + optional TransportProto proto = 3; + optional bytes real_ip = 4; + optional uint32 real_port = 5; +} diff --git a/modules/balancer/controlplane/balancerpb/memory.proto b/modules/balancer/controlplane/balancerpb/memory.proto new file mode 100644 index 000000000..d2a0d6d29 --- /dev/null +++ b/modules/balancer/controlplane/balancerpb/memory.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package balancerpb; + +import "modules/balancer/controlplane/balancerpb/common.proto"; + +option go_package = "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb;balancerpb"; + +// Memory usage for real servers within a VS +message RealsUsage { + uint64 counters = 1; + uint64 interval_counters = 2; + uint64 data = 3; + uint64 total = 4; +} + +// Memory usage for a single virtual service +message VsUsage { + uint64 acl = 1; + uint64 ring = 2; + uint64 counters = 3; + RealsUsage reals = 4; + uint64 other = 5; + uint64 total = 6; +} + +// Named VS inspect with identifier +message NamedVsUsage { + VsIdentifier identifier = 1; + VsUsage inspect = 2; +} + +message PacketHandlerVsUsage { + uint64 matcher = 1; + uint64 summary_vs = 2; + repeated NamedVsUsage vs = 3; + uint64 index = 4; + uint64 total = 5; +} + +message PacketHandlerUsage { + repeated NamedVsUsage vs = 1; + uint64 summary_vs = 3; + uint64 vs_index = 4; + uint64 reals_index = 5; + uint64 counters = 6; + uint64 interval_counters = 7; + uint64 decap_usage = 8; + uint64 total_usage = 9; +} + +message BalancerMemoryUsage { + string name = 1; + PacketHandlerUsage packet_handler = 2; + uint64 session_table = 3; + uint64 other = 4; + uint64 total = 5; +} + +message AgentMemoryUsage { + uint64 memory_limit = 1; + uint64 memory_usage = 2; + repeated BalancerMemoryUsage balancers = 3; +} \ No newline at end of file diff --git a/modules/balancer/controlplane/balancerpb/meson.build b/modules/balancer/controlplane/balancerpb/meson.build new file mode 100644 index 000000000..1be2cba64 --- /dev/null +++ b/modules/balancer/controlplane/balancerpb/meson.build @@ -0,0 +1,31 @@ +root_dir = meson.project_source_root() +proto_files = [ + join_paths(root_dir, 'modules/balancer/controlplane/balancerpb/balancer.proto'), + join_paths(root_dir, 'modules/balancer/controlplane/balancerpb/common.proto'), + join_paths(root_dir, 'modules/balancer/controlplane/balancerpb/filter.proto'), + join_paths(root_dir, 'modules/balancer/controlplane/balancerpb/memory.proto'), + join_paths(root_dir, 'modules/balancer/controlplane/balancerpb/state.proto'), +] + +protoc_gen = custom_target( + 'balancer-protoc', + output: [ + 'balancer.pb.go', + 'balancer_grpc.pb.go', + 'common.pb.go', + 'filter.pb.go', + 'memory.pb.go', + 'state.pb.go', + ], + input: proto_files, + command: [ + protoc, + '-I', root_dir, + '--experimental_allow_proto3_optional', + '--go_out=paths=source_relative:' + root_dir, + '--go-grpc_out=paths=source_relative:' + root_dir, + '@INPUT@', + ], + build_by_default: true, +) +balancer_protoc_gen = protoc_gen diff --git a/modules/balancer/agent/balancerpb/stats.proto b/modules/balancer/controlplane/balancerpb/state.proto similarity index 84% rename from modules/balancer/agent/balancerpb/stats.proto rename to modules/balancer/controlplane/balancerpb/state.proto index 7076dc436..4ebf495fc 100644 --- a/modules/balancer/agent/balancerpb/stats.proto +++ b/modules/balancer/controlplane/balancerpb/state.proto @@ -2,9 +2,49 @@ syntax = "proto3"; package balancerpb; -import "modules/balancer/agent/balancerpb/module.proto"; +import "google/protobuf/timestamp.proto"; + +import "modules/balancer/controlplane/balancerpb/common.proto"; + +option go_package = "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb;balancerpb"; + +message BalancerState { + string balancer_name = 1; + PacketHandlerRef ref = 2; + L4Stats l4_stats = 3; + CommonStats common_stats = 4; + IcmpStats icmp_ipv4_stats = 5; + IcmpStats icmp_ipv6_stats = 6; + uint64 active_sessions = 7; + google.protobuf.Timestamp last_packet_timestamp = 8; + repeated VsState virtual_services = 9; + bytes source_ipv4 = 10; + bytes source_ipv6 = 11; + repeated bytes decap_addresses = 12; +} -option go_package = "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb;balancerpb"; +message VsState { + VsIdentifier id = 1; + VsFlags flags = 2; + VsScheduler scheduler = 3; + VsStats stats = 4; + uint64 active_sessions = 5; + google.protobuf.Timestamp last_packet_timestamp = 6; + repeated AllowedSourcesStats allowed_sources = 7; + repeated RealState reals = 8; + repeated AllowedSources allowed_srcs_config = 9; + repeated bytes peers = 10; +} + +message RealState { + RelativeRealIdentifier id = 1; + RealStats real_stats = 2; + uint64 active_sessions = 3; + google.protobuf.Timestamp last_packet_timestamp = 4; + uint64 weight = 5; + uint64 effective_weight = 6; + bool enabled = 7; +} // Layer 4 (TCP/UDP) packet processing statistics. // @@ -137,44 +177,40 @@ message VsStats { // All reals are either disabled, removed, or have zero weight. uint64 no_reals = 4; - // Packets processed in One Packet Scheduler (OPS) mode. - // Each packet is independently scheduled without session tracking. - uint64 ops_packets = 5; - // Packets dropped due to session table overflow. // Session table is full and cannot create new sessions. - uint64 session_table_overflow = 6; + uint64 session_table_overflow = 5; // ICMP echo (ping) packets processed for this virtual service - uint64 echo_icmp_packets = 7; + uint64 echo_icmp_packets = 6; // ICMP error messages processed for this virtual service - uint64 error_icmp_packets = 8; + uint64 error_icmp_packets = 7; // Packets for sessions assigned to disabled real servers. // Real was disabled after the session was created. - uint64 real_is_disabled = 9; + uint64 real_is_disabled = 8; // Packets for sessions assigned to removed real servers. // Real was removed from configuration after the session was created. - uint64 real_is_removed = 10; + uint64 real_is_removed = 9; // Packets that could not be rescheduled to a different real. // Happens when the assigned real is unavailable and no alternative // exists. - uint64 not_rescheduled_packets = 11; + uint64 not_rescheduled_packets = 10; // ICMP error messages broadcasted to peer balancers - uint64 broadcasted_icmp_packets = 12; + uint64 broadcasted_icmp_packets = 11; // Total number of new sessions created for this virtual service - uint64 created_sessions = 13; + uint64 created_sessions = 12; // Total packets successfully forwarded to real servers - uint64 outgoing_packets = 14; + uint64 outgoing_packets = 13; // Total bytes successfully forwarded to real servers - uint64 outgoing_bytes = 15; + uint64 outgoing_bytes = 14; } // Per-real-server statistics. @@ -273,14 +309,14 @@ message BalancerStats { L4Stats l4 = 1; // ICMPv4 processing statistics - IcmpStats icmpv4 = 2; + IcmpStats icmp_ipv4 = 2; // ICMPv6 processing statistics - IcmpStats icmpv6 = 3; + IcmpStats icmp_ipv6 = 3; // Common statistics (all protocols) CommonStats common = 4; // Per-virtual-service statistics with per-real breakdowns - repeated NamedVsStats vs = 5; + repeated NamedVsStats virtual_services = 5; } diff --git a/modules/balancer/controlplane/cdefs.go b/modules/balancer/controlplane/cdefs.go new file mode 100644 index 000000000..0f526beb9 --- /dev/null +++ b/modules/balancer/controlplane/cdefs.go @@ -0,0 +1,96 @@ +//go:generate sh -c "go tool cgo -godefs -- -I../../../ -I../../../filter -I../../../lib -I../../../modules/balancer/dataplane -I../../../modules/balancer/dataplane/types cdefs.go > ctypes.go" + +//go:build ignore + +package balancer + +/* +#include "filter/rule.h" +#include "common/rcu.h" + +#include "modules/balancer/dataplane/types/vs.h" +#include "modules/balancer/dataplane/types/stats.h" +#include "modules/balancer/dataplane/types/selector.h" +#include "modules/balancer/dataplane/types/real.h" +#include "modules/balancer/dataplane/types/sessions_tracker.h" +#include "modules/balancer/dataplane/types/session.h" +#include "modules/balancer/dataplane/dataplane.h" +#include "modules/balancer/controlplane/helpers/sessions.h" +*/ +import "C" + +// Network types. +type ( + Net4Addr C.struct_net4_addr + Net6Addr C.struct_net6_addr + NetAddr C.struct_net_addr + Net4 C.struct_net4 + Net6 C.struct_net6 + Net C.struct_net +) + +// Filter. +type ( + Filter C.struct_filter +) + +// RCU. +type ( + RCU C.rcu_t +) + +// Balancer core types. +type ( + VS C.struct_balancer_vs + Real C.struct_balancer_real + AllowedSource C.struct_balancer_vs_allowed_source + SessionTimeouts C.struct_balancer_session_timeouts + PacketHandler C.struct_balancer_packet_handler + SessionTrackerShard C.struct_balancer_sessions_tracker_shard + IntervalCounter C.struct_balancer_interval_counter + SessionTable C.struct_balancer_session_table + RealSelector C.struct_balancer_real_selector +) + +// Stats +type ( + L4Stats C.struct_balancer_l4_stats + IcmpStats C.struct_balancer_icmp_stats + CommonStats C.struct_balancer_common_stats + VsStats C.struct_balancer_vs_stats + RealStats C.struct_balancer_real_stats +) + +// VS flags +const ( + VSFlagPureL3 = C.balancer_vs_pure_l3 + VSFlagFixMSS = C.balancer_vs_fix_mss + VSFlagGRE = C.balancer_vs_gre + VSFlagOPS = C.balancer_vs_ops + VSFlagWLC = C.balancer_vs_wlc + VSFlagRemoved = C.balancer_vs_removed + VSFlagRoundRobin = C.balancer_vs_round_robin +) + +// Real flags +const ( + RealFlagEnabled = C.balancer_real_enabled + RealFlagRemoved = C.balancer_real_removed + RealFlagIPv6 = C.balancer_real_ipv6 +) + +const ( + MaxSessionTimeout = uint32(C.balancer_max_session_timeout) + AllowedSourceMaxTagLength = uint32(C.balancer_vs_acl_max_tag_len) +) + +// Session iteration types. +type ( + SessionID C.struct_balancer_session_id + SessionState C.struct_balancer_session_state + SessionEntry C.struct_balancer_session_entry + SessionTableIter C.struct_balancer_session_table_iter +) + +// Filter types. +type PortRange C.struct_filter_port_range diff --git a/modules/balancer/agent/go/config.go b/modules/balancer/controlplane/cfg.go similarity index 88% rename from modules/balancer/agent/go/config.go rename to modules/balancer/controlplane/cfg.go index db5a3fd84..fbdc26f9d 100644 --- a/modules/balancer/agent/go/config.go +++ b/modules/balancer/controlplane/cfg.go @@ -5,18 +5,17 @@ import ( "github.com/yanet-platform/yanet2/common/go/xcfg" ) -//////////////////////////////////////////////////////////////////////////////// - -// Config for the balancer service +// Config for the balancer controlplane module. type Config struct { // InstanceID specifies which dataplane instance this module serves. InstanceID uint32 `yaml:"instance_id"` + // MemoryPath is the path to the shared-memory file that is used to // communicate with dataplane. MemoryPath xcfg.NonEmptyString `yaml:"memory_path"` // MemoryRequirements is the amount of memory that is required for a single - // agent + // agent transaction. MemoryRequirements xcfg.NonZero[datasize.ByteSize] `yaml:"memory_requirements"` Endpoint xcfg.NonEmptyString `yaml:"endpoint"` diff --git a/modules/balancer/controlplane/chelpers.go b/modules/balancer/controlplane/chelpers.go new file mode 100644 index 000000000..9938bd34e --- /dev/null +++ b/modules/balancer/controlplane/chelpers.go @@ -0,0 +1,241 @@ +package balancer + +/* +#cgo CFLAGS: -I../../../ -I../../../filter -I../../../lib -I../../../modules/balancer/dataplane -I../../../modules/balancer/dataplane/types +#cgo LDFLAGS: -L../../../build/modules/balancer/controlplane/helpers -lbalancer_helpers +#cgo LDFLAGS: -L../../../build/filter -lfilter_compiler + +#include "filter/rule.h" + +#include "modules/balancer/controlplane/helpers/agent.h" +#include "modules/balancer/controlplane/helpers/balancer.h" +#include "modules/balancer/controlplane/helpers/vs.h" +#include "modules/balancer/controlplane/helpers/real.h" +#include "modules/balancer/controlplane/helpers/sessions.h" + +#include "modules/balancer/dataplane/types/vs.h" +#include "modules/balancer/dataplane/types/stats.h" +#include "modules/balancer/dataplane/types/selector.h" +#include "modules/balancer/dataplane/types/real.h" +#include "modules/balancer/dataplane/types/sessions_tracker.h" +#include "modules/balancer/dataplane/types/session.h" +#include "modules/balancer/dataplane/dataplane.h" +*/ +import "C" + +import ( + "time" + "unsafe" + + "github.com/yanet-platform/yanet2/common/go/relptr" + "google.golang.org/grpc/codes" +) + +func errFromCode(res C.int) error { + switch res { + case 0: + return nil + case -1: + return errNoAgentMemory + case -2: + return CodedErrorf(codes.ResourceExhausted, "no heap memory") + default: + return CodedErrorf(codes.Unknown, "unknown error code=%d", res) + } +} + +func (a *Agent) asCPtr() *C.struct_agent { + return (*C.struct_agent)(unsafe.Pointer(a.AsYanetAgent().AsRawPtr())) +} + +func (a *Agent) install(handler *PacketHandler) error { + a.AsYanetAgent().CleanError() + + res := C.balancer_agent_install( + a.asCPtr(), + handler.asCPtr(), + ) + if res != 0 { + return a.AsYanetAgent().TakeError() + } + + return nil +} + +func (a *Agent) register(handler *PacketHandler) error { + return errFromCode(C.balancer_agent_register( + a.asCPtr(), + handler.asCPtr(), + )) +} + +func (a *Agent) forget(handler *PacketHandler) { + C.balancer_agent_forget( + a.asCPtr(), + handler.asCPtr(), + ) +} + +func (a *Agent) list() []*PacketHandler { + count := C.size_t(0) + handlersRaw := C.balancer_agent_list(a.asCPtr(), &count) + if handlersRaw == nil { + return nil + } + handlers := unsafe.Slice((**PacketHandler)(unsafe.Pointer(handlersRaw)), count) + for i := range handlers { + handlers[i] = relptr.Deref(&handlers[i]) + } + return handlers +} + +func (a *Agent) createSessionTable(capacity int) *SessionTable { + return (*SessionTable)(unsafe.Pointer(C.balancer_agent_create_st( + a.asCPtr(), + C.size_t(capacity), + ))) +} + +func (a *Agent) destroySessionTable(st *SessionTable) { + C.balancer_agent_destroy_st(a.asCPtr(), st.asCPtr()) +} + +func (ph *PacketHandler) asCPtr() *C.struct_balancer_packet_handler { + return (*C.struct_balancer_packet_handler)(unsafe.Pointer(ph)) +} + +func (ph *PacketHandler) initialSetup(agent *Agent, name string, st *SessionTable) error { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + + return errFromCode(C.balancer_initial_setup( + agent.asCPtr(), + ph.asCPtr(), + cName, + st.asCPtr(), + )) +} + +func (ph *PacketHandler) registerCounters() error { + return errFromCode(C.balancer_register_counters(ph.asCPtr())) +} + +func (ph *PacketHandler) name() string { + return C.GoString(C.balancer_name(ph.asCPtr())) +} + +func (ph *PacketHandler) setIpv4DecapFilter() error { + return errFromCode(C.balancer_set_ipv4_decap_filter(ph.asCPtr())) +} + +func (ph *PacketHandler) setIpv6DecapFilter() error { + return errFromCode(C.balancer_set_ipv6_decap_filter(ph.asCPtr())) +} + +func (ph *PacketHandler) freeDecapFilters() { + C.balancer_free_decap_filters(ph.asCPtr()) +} + +func (ph *PacketHandler) setIpv4VsMatcher() error { + return errFromCode(C.balancer_set_ipv4_vs_matcher(ph.asCPtr())) +} + +func (ph *PacketHandler) setIpv6VsMatcher() error { + return errFromCode(C.balancer_set_ipv6_vs_matcher(ph.asCPtr())) +} + +func (ph *PacketHandler) freeVsMatchers() { + C.balancer_free_vs_matchers(ph.asCPtr()) +} + +func (vs *VS) asCPtr() *C.struct_balancer_vs { + return (*C.struct_balancer_vs)(unsafe.Pointer(vs)) +} + +func (vs *VS) setACL(agent *Agent) error { + return errFromCode(C.balancer_vs_set_acl(vs.asCPtr(), agent.asCPtr())) +} + +func (vs *VS) freeACL(agent *Agent) { + C.balancer_vs_free_acl(vs.asCPtr(), agent.asCPtr()) +} + +func (vs *VS) updateRealSelector(rcu *RCU, agent *Agent) error { + return errFromCode( + C.balancer_vs_update_real_selector(vs.asCPtr(), rcu.asCPtr(), agent.asCPtr()), + ) +} + +func (vs *VS) freeRealSelector(agent *Agent) { + C.balancer_vs_free_real_selector(vs.asCPtr(), agent.asCPtr()) +} + +func (vs *VS) setSessionsTrackers(agent *Agent) error { + return errFromCode(C.balancer_vs_set_session_trackers(vs.asCPtr(), agent.asCPtr())) +} + +func (vs *VS) freeSessionTrackers(agent *Agent) { + C.balancer_vs_free_session_trackers(vs.asCPtr(), agent.asCPtr()) +} + +func (r *Real) asCPtr() *C.struct_balancer_real { + return (*C.struct_balancer_real)(unsafe.Pointer(r)) +} + +func (r *Real) sessions(workers uint32, now time.Time) (uint64, time.Time) { + activeSessions := C.uint64_t(0) + lastPacketTimestamp := C.uint32_t(0) + C.balancer_real_sessions( + r.asCPtr(), + C.size_t(workers), + &activeSessions, + &lastPacketTimestamp, + C.uint32_t(now.Unix()), + ) + return uint64(activeSessions), time.Unix(int64(lastPacketTimestamp), 0) +} + +func (st *SessionTable) asCPtr() *C.struct_balancer_session_table { + return (*C.struct_balancer_session_table)(unsafe.Pointer(st)) +} + +func (st *SessionTable) capacity() int { + return int(C.balancer_st_capacity(st.asCPtr())) +} + +func (st *SessionTable) resize(newSize int, now time.Time) error { + return errFromCode(C.balancer_st_resize(st.asCPtr(), C.size_t(newSize), C.uint32_t(now.Unix()))) +} + +const bucketMaxEntries = 16 + +func (it *SessionTableIter) asCPtr() *C.struct_balancer_session_table_iter { + return (*C.struct_balancer_session_table_iter)(unsafe.Pointer(it)) +} + +func (st *SessionTable) newSessionIter() SessionTableIter { + var iter SessionTableIter + C.balancer_st_iter_init( + iter.asCPtr(), + st.asCPtr(), + ) + return iter +} + +func (it *SessionTableIter) nextBucket(now uint32, buf []SessionEntry) int { + var count C.int + ret := C.balancer_st_iter_next_bucket_buf( + it.asCPtr(), + C.uint32_t(now), + (*C.struct_balancer_session_entry)(unsafe.Pointer(&buf[0])), + &count, + ) + if ret == 0 { + return -1 + } + return int(count) +} + +func (rcu *RCU) asCPtr() *C.struct_rcu { + return (*C.struct_rcu)(unsafe.Pointer(rcu)) +} diff --git a/modules/balancer/controlplane/ctypes.go b/modules/balancer/controlplane/ctypes.go new file mode 100644 index 000000000..65c16bfd6 --- /dev/null +++ b/modules/balancer/controlplane/ctypes.go @@ -0,0 +1,255 @@ +// Code generated by cmd/cgo -godefs; DO NOT EDIT. +// cgo -godefs -- -I../../../ -I../../../filter -I../../../lib -I../../../modules/balancer/dataplane -I../../../modules/balancer/dataplane/types cdefs.go + +package balancer + +type ( + Net4Addr struct { + Bytes [4]uint8 + } + Net6Addr struct { + Bytes [16]uint8 + } + NetAddr struct { + V4 Net4Addr + Pad_cgo_0 [12]byte + } + Net4 struct { + Addr [4]uint8 + Mask [4]uint8 + } + Net6 struct { + Addr [16]uint8 + Mask [16]uint8 + } + Net struct { + V4 Net4 + Pad_cgo_0 [24]byte + } +) + +type ( + Filter struct { + V [20]_Ctype_struct_filter_vertex + Context _Ctype_struct_memory_context + } +) + +type ( + RCU struct { + Workers [8]_Ctype_struct___1 + Epoch uint32 + Pad_cgo_0 [60]byte + } +) + +type ( + VS struct { + Reals *Real + Reals_count uint32 + Stable_idx uint64 + Counter_id uint64 + Selector *RealSelector + Acl *Filter + Rule_counter_ids *uint64 + Flags uint16 + X__padding [64]uint8 + Addr NetAddr + Ip_proto uint8 + Port uint16 + Transport_proto uint8 + Peers_v4 *Net4Addr + Peers_v4_count uint32 + Peers_v6 *Net6Addr + Peers_v6_count uint32 + Allowed_sources *AllowedSource + Allowed_sources_count uint32 + Pad_cgo_0 [4]byte + } + Real struct { + Counter_id uint64 + Stable_idx uint64 + Tracker_shards *SessionTrackerShard + Addr NetAddr + Src Net + Flags uint8 + Weight uint32 + Effective_weight uint32 + Pad_cgo_0 [4]byte + } + AllowedSource struct { + Nets *Net + Nets_count uint32 + Port_ranges *PortRange + Port_ranges_count uint32 + Tag [21]int8 + Pad_cgo_0 [7]byte + } + SessionTimeouts struct { + Syn_ack uint8 + Syn uint8 + Fin uint8 + Tcp uint8 + Udp uint8 + } + PacketHandler struct { + Cp_module _Ctype_struct_cp_module + Common_counter_id uint64 + Icmp_v4_counter_id uint64 + Icmp_v6_counter_id uint64 + L4_counter_id uint64 + Decap_ipv4_filter *Filter + Decap_ipv6_filter *Filter + Session_table *SessionTable + Ipv4_vs_matcher *Filter + Ipv6_vs_matcher *Filter + Vs *VS + Vs_count uint32 + Session_timeouts SessionTimeouts + Source_v4 Net4Addr + Source_v6 Net6Addr + Pad_cgo_0 [35]byte + Rcu RCU + Decap_v4 *Net4Addr + Decap_v4_count uint32 + Decap_v6 *Net6Addr + Decap_v6_count uint32 + Wlc_power uint32 + Wlc_max_weight uint32 + Refresh_period_ms uint32 + Session_table_max_load_factor float32 + Pad_cgo_1 [20]byte + } + SessionTrackerShard struct { + Counter IntervalCounter + Count uint32 + Pad_cgo_0 [24]byte + } + IntervalCounter struct { + Diff [8]int32 + Timestamp uint32 + } + SessionTable struct { + Maps [2]_Ctype_struct_ttlmap + Rcu RCU + Gen uint64 + Mctx _Ctype_struct_memory_context + Workers uint32 + Pad_cgo_0 [12]byte + } + RealSelector struct { + Rings [2]_Ctype_struct_balancer_ring + Ring_id uint64 + Use_rr int32 + Pad_cgo_0 [36]byte + Workers [8]_Ctype_struct_balancer_rr_counter + } +) + +type ( + L4Stats struct { + Incoming_packets uint64 + Select_vs_failed uint64 + Invalid_packets uint64 + Select_real_failed uint64 + Outgoing_packets uint64 + } + IcmpStats struct { + Incoming_packets uint64 + Src_not_allowed uint64 + Echo_responses uint64 + Payload_too_short_ip uint64 + Unmatching_src_from_original uint64 + Payload_too_short_port uint64 + Unexpected_transport uint64 + Unrecognized_vs uint64 + Forwarded_packets uint64 + Broadcasted_packets uint64 + Packet_clones_sent uint64 + Packet_clones_received uint64 + Packet_clone_failures uint64 + } + CommonStats struct { + Incoming_packets uint64 + Incoming_bytes uint64 + Unexpected_network_proto uint64 + Decap_successful uint64 + Decap_failed uint64 + Outgoing_packets uint64 + Outgoing_bytes uint64 + } + VsStats struct { + Incoming_packets uint64 + Incoming_bytes uint64 + Packet_src_not_allowed uint64 + No_reals uint64 + Session_table_overflow uint64 + Echo_icmp_packets uint64 + Error_icmp_packets uint64 + Real_is_disabled uint64 + Real_is_removed uint64 + Not_rescheduled_packets uint64 + Broadcasted_icmp_packets uint64 + Created_sessions uint64 + Outgoing_packets uint64 + Outgoing_bytes uint64 + } + RealStats struct { + Packets_real_disabled uint64 + Error_icmp_packets uint64 + Created_sessions uint64 + Packets uint64 + Bytes uint64 + } +) + +const ( + VSFlagPureL3 = 0x1 + VSFlagFixMSS = 0x2 + VSFlagGRE = 0x4 + VSFlagOPS = 0x8 + VSFlagWLC = 0x10 + VSFlagRemoved = 0x20 + VSFlagRoundRobin = 0x40 +) + +const ( + RealFlagEnabled = 0x1 + RealFlagRemoved = 0x2 + RealFlagIPv6 = 0x4 +) + +const ( + MaxSessionTimeout = uint32(97.000000) + AllowedSourceMaxTagLength = uint32(0x14) +) + +type ( + SessionID struct { + Vs_stable_idx uint64 + Client_port uint16 + Client_ip [16]uint8 + Padding [6]uint8 + } + SessionState struct { + Real_stable_idx uint64 + Last_packet_timestamp uint32 + Create_timestamp uint32 + Timeout uint8 + Pad_cgo_0 [7]byte + } + SessionEntry struct { + Id SessionID + State SessionState + } + SessionTableIter struct { + Iter _Ctype_struct_ttlmap_bucket_iter + Gen uint32 + Pad_cgo_0 [4]byte + } +) + +type PortRange struct { + From uint16 + To uint16 +} diff --git a/modules/balancer/controlplane/error.go b/modules/balancer/controlplane/error.go new file mode 100644 index 000000000..2e5a3bc80 --- /dev/null +++ b/modules/balancer/controlplane/error.go @@ -0,0 +1,43 @@ +package balancer + +import ( + "errors" + "fmt" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// codedError attaches a gRPC code to an error while remaining transparent +// to errors.Is/As and %w chains. +type codedError struct { + code codes.Code + err error +} + +func (e *codedError) Error() string { return e.err.Error() } +func (e *codedError) Unwrap() error { return e.err } +func (e *codedError) GRPCStatus() *status.Status { + return status.New(e.code, e.err.Error()) +} + +// CodedErrorf builds a new error tagged with a gRPC code. +// The format string follows fmt.Errorf conventions (including %w). +func CodedErrorf(code codes.Code, format string, args ...any) error { + return &codedError{code: code, err: fmt.Errorf(format, args...)} +} + +// AsStatus converts err to a gRPC status error at the RPC boundary. +// It walks the error chain for an attached code; if none is found, +// it falls back to the supplied code. +func AsStatus(err error, fallback codes.Code) error { + if err == nil { + return nil + } + code := fallback + var ce *codedError + if errors.As(err, &ce) { + code = ce.code + } + return status.Error(code, err.Error()) +} diff --git a/modules/balancer/controlplane/filter.go b/modules/balancer/controlplane/filter.go new file mode 100644 index 000000000..02da1e354 --- /dev/null +++ b/modules/balancer/controlplane/filter.go @@ -0,0 +1,84 @@ +package balancer + +import ( + "bytes" + "fmt" + + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +func validateFilter(filter *balancerpb.Filter) error { + if filter == nil { + return nil + } + if filter.Proto != nil { + switch *filter.Proto { + case balancerpb.TransportProto_TCP, balancerpb.TransportProto_UDP: + default: + return fmt.Errorf("invalid filter proto: %v", *filter.Proto) + } + } + return nil +} + +type filterMatcher struct { + vip []byte + vsPort *uint32 + proto *balancerpb.TransportProto + realIP []byte + realPort *uint32 + + hasVsFilter bool + hasRealFilter bool +} + +func newFilterMatcher(filter *balancerpb.Filter) filterMatcher { + if filter == nil { + return filterMatcher{} + } + m := filterMatcher{} + if filter.Vip != nil { + m.vip = filter.Vip + m.hasVsFilter = true + } + if filter.VsPort != nil { + m.vsPort = filter.VsPort + m.hasVsFilter = true + } + if filter.Proto != nil { + m.proto = filter.Proto + m.hasVsFilter = true + } + if filter.RealIp != nil { + m.realIP = filter.RealIp + m.hasRealFilter = true + } + if filter.RealPort != nil { + m.realPort = filter.RealPort + m.hasRealFilter = true + } + return m +} + +func (m *filterMatcher) matchVsID(id *balancerpb.VsIdentifier) bool { + if m.vip != nil && !bytes.Equal(m.vip, id.Addr) { + return false + } + if m.vsPort != nil && *m.vsPort != id.Port { + return false + } + if m.proto != nil && *m.proto != id.Proto { + return false + } + return true +} + +func (m *filterMatcher) matchRealID(id *balancerpb.RelativeRealIdentifier) bool { + if m.realIP != nil && !bytes.Equal(m.realIP, id.Ip) { + return false + } + if m.realPort != nil && *m.realPort != id.Port { + return false + } + return true +} diff --git a/modules/balancer/controlplane/handler.go b/modules/balancer/controlplane/handler.go new file mode 100644 index 000000000..3ef3a3ca4 --- /dev/null +++ b/modules/balancer/controlplane/handler.go @@ -0,0 +1,361 @@ +package balancer + +import ( + "bytes" + "fmt" + "time" + + "github.com/yanet-platform/yanet2/common/go/relptr" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +// populateSourceAddrs copies source IPv4 and IPv6 addresses to the handler. +func (ph *PacketHandler) populateSourceAddrs(config *balancerpb.PacketHandlerConfig) { + writeNet4Addr(&ph.Source_v4, config.SourceAddressV4) + writeNet6Addr(&ph.Source_v6, config.SourceAddressV6) +} + +// populateSessionTimeouts writes timeout values to the handler. +func (ph *PacketHandler) populateSessionTimeouts(t *balancerpb.SessionsTimeouts) { + ph.Session_timeouts = SessionTimeouts{ + Syn_ack: uint8(t.TcpSynAck), + Syn: uint8(t.TcpSyn), + Fin: uint8(t.TcpFin), + Tcp: uint8(t.Tcp), + Udp: uint8(t.Udp), + } +} + +// populateDecapAddresses separates decap addresses into IPv4/IPv6, allocates +// shared-memory arrays, and writes them to the handler. +// +// Each address family is handled in two passes: first to count, then to fill. +// AllocSlice requires the total element count upfront and cannot grow +// incrementally, so we must count before allocating. +func (ph *PacketHandler) populateDecapAddresses(agent *Agent, addrs [][]byte) error { + v4Count := 0 + v6Count := 0 + for _, addr := range addrs { + if len(addr) == 4 { + v4Count++ + } else { + v6Count++ + } + } + + if v4Count > 0 { + slice := yanet.AllocSlice[Net4Addr](agent.AsYanetAgent(), v4Count) + if slice == nil { + return errNoAgentMemory + } + j := 0 + for _, addr := range addrs { + if len(addr) == 4 { + writeNet4Addr(&slice[j], addr) + j++ + } + } + relptr.SetSlice(&ph.Decap_v4, slice) + ph.Decap_v4_count = uint32(v4Count) + } + + if v6Count > 0 { + slice := yanet.AllocSlice[Net6Addr](agent.AsYanetAgent(), v6Count) + if slice == nil { + return errNoAgentMemory + } + j := 0 + for _, addr := range addrs { + if len(addr) == 16 { + writeNet6Addr(&slice[j], addr) + j++ + } + } + relptr.SetSlice(&ph.Decap_v6, slice) + ph.Decap_v6_count = uint32(v6Count) + } + + return nil +} + +// setupFilters compiles or reuses VS matchers and decap filters based on the reuse report. +// Precondition: when any *Reused flag is true, prev must be non-nil. This is guaranteed +// because reuse flags are only set when prev exists (decapFiltersReusable returns false +// for nil, and populateVS only sets matcher flags when previous VSes overlap with new ones). +func (ph *PacketHandler) setupFilters( + prev *PacketHandler, + reuseReport *balancerpb.ReuseReport, +) error { + if reuseReport.Ipv4VsMatcherReused { + relptr.Equate(&ph.Ipv4_vs_matcher, &prev.Ipv4_vs_matcher) + } else { + if err := ph.setIpv4VsMatcher(); err != nil { + return fmt.Errorf("set ipv4 vs matcher: %w", err) + } + } + + if reuseReport.Ipv6VsMatcherReused { + relptr.Equate(&ph.Ipv6_vs_matcher, &prev.Ipv6_vs_matcher) + } else { + if err := ph.setIpv6VsMatcher(); err != nil { + return fmt.Errorf("set ipv6 vs matcher: %w", err) + } + } + + if reuseReport.Ipv4DecapFilterReused { + relptr.Equate(&ph.Decap_ipv4_filter, &prev.Decap_ipv4_filter) + } else { + if err := ph.setIpv4DecapFilter(); err != nil { + return fmt.Errorf("set ipv4 decap filter: %w", err) + } + } + + if reuseReport.Ipv6DecapFilterReused { + relptr.Equate(&ph.Decap_ipv6_filter, &prev.Decap_ipv6_filter) + } else { + if err := ph.setIpv6DecapFilter(); err != nil { + return fmt.Errorf("set ipv6 decap filter: %w", err) + } + } + + return nil +} + +// NewPacketHandler allocates and fully populates a new PacketHandler in shared memory. +// If prev is non-nil, filters and selectors may be reused from it when the underlying +// data hasn't changed (reported via ReuseReport). +// On error, all partially-allocated resources are cleaned up automatically. +func NewPacketHandler( + config *balancerpb.BalancerConfig, + name string, + sessionTable *SessionTable, + agent *Agent, + prev *PacketHandler, +) (*PacketHandler, *balancerpb.ReuseReport, error) { + phConfig, stateConfig := config.PacketHandler, config.State + + handler := yanet.Alloc[PacketHandler](agent.AsYanetAgent()) + if handler == nil { + return nil, nil, errNoAgentMemory + } + + // Free all resources on error. handler.Free is safe on partially-initialized handlers + // because sub-slices start as nil and FreeSlice/C helpers are no-ops on nil values. + success := false + defer func() { + if !success { + handler.free(agent) + yanet.Free(agent.AsYanetAgent(), handler) + } + }() + + if err := handler.initialSetup(agent, name, sessionTable); err != nil { + return nil, nil, fmt.Errorf("initial setup: %w", err) + } + + handler.populateSourceAddrs(phConfig) + handler.populateSessionTimeouts(phConfig.SessionsTimeouts) + + if err := handler.populateDecapAddresses(agent, phConfig.DecapAddresses); err != nil { + return nil, nil, fmt.Errorf("populate decap addrs: %w", err) + } + + reuseReport := &balancerpb.ReuseReport{} + + if err := handler.populateVS(agent, phConfig.Vs, prev, reuseReport); err != nil { + return nil, nil, fmt.Errorf("populate virtual services: %w", err) + } + + reuseReport.Ipv4DecapFilterReused, reuseReport.Ipv6DecapFilterReused = prev.decapFiltersReusable( + phConfig.DecapAddresses, + ) + + if err := handler.setupFilters(prev, reuseReport); err != nil { + return nil, nil, err + } + + if err := handler.registerCounters(); err != nil { + return nil, nil, fmt.Errorf("register counters: %w", err) + } + + handler.setState(stateConfig, sessionTable) + + success = true + + return handler, reuseReport, nil +} + +// decapFiltersReusable checks whether the previous handler's decap addresses match +// the new config, allowing compiled decap filters to be reused. +// Precondition: addrs must be sorted by address family (all IPv4 first, then all IPv6). +// This ordering is guaranteed by validatePacketHandlerConfig which sorts decap addresses. +// Returns false, false when ph is nil (initial creation, no previous handler to reuse from). +func (ph *PacketHandler) decapFiltersReusable(addrs [][]byte) (ipv4Reused, ipv6Reused bool) { + if ph == nil { + return + } + + // split is the index where IPv6 addresses begin in the sorted addrs slice. + split := len(addrs) + for i := range addrs { + if len(addrs[i]) == 16 { + split = i + break + } + } + + ipv4Addrs := relptr.Slice(&ph.Decap_v4, ph.Decap_v4_count) + if split == len(ipv4Addrs) { + ipv4Reused = true + for i := range ipv4Addrs { + if !bytes.Equal(ipv4Addrs[i].Bytes[:], addrs[i]) { + ipv4Reused = false + break + } + } + } + + ipv6Addrs := relptr.Slice(&ph.Decap_v6, ph.Decap_v6_count) + if len(addrs)-split == len(ipv6Addrs) { + ipv6Reused = true + for i := range ipv6Addrs { + if !bytes.Equal(ipv6Addrs[i].Bytes[:], addrs[split+i]) { + ipv6Reused = false + break + } + } + } + + return ipv4Reused, ipv6Reused +} + +func (ph *PacketHandler) populateVS( + agent *Agent, + pbVS []*balancerpb.VirtualService, + prevPh *PacketHandler, + reuseReport *balancerpb.ReuseReport, +) error { + vsMap := make(map[vsKey]int, len(pbVS)) + for i, vs := range pbVS { + k := makeVsKey(vs.Id) + vsMap[k] = i + } + + var prevVS []VS + if prevPh != nil { + prevVS = relptr.Slice(&prevPh.Vs, prevPh.Vs_count) + } + + slotCount := max(len(pbVS), len(prevVS)) + services := yanet.AllocSlice[VS](agent.AsYanetAgent(), slotCount) + if services == nil { + return errNoAgentMemory + } + for idx := range services { + stableIdx := uint64(0) + if idx < len(prevVS) { + stableIdx = prevVS[idx].Stable_idx + } + services[idx] = VS{ + Flags: VSFlagRemoved, + Stable_idx: stableIdx, + } + } + + // freeServices cleans up the local services slice on error. This slice hasn't been + // written to ph yet (relptr.SetSlice happens at the end), so there's no conflict + // with the deferred handler.Free cleanup in NewPacketHandler. + freeServices := func() { + for idx := range services { + services[idx].free(agent) + } + yanet.FreeSlice(agent.AsYanetAgent(), services) + } + + reuseReport.VsReuseReports = make([]*balancerpb.VsReuseReport, 0, len(pbVS)) + + // First, write virtual services which are present in the previous config + oldIPv4VsMatches, oldIPv6VsMatches, err := placeExistingVS( + ph, + agent, + pbVS, + services, + prevVS, + vsMap, + reuseReport, + ) + if err != nil { + freeServices() + return err + } + + // Then, write virtual services which are new in the new config + noNewIPv4Vs, noNewIPv6Vs, err := placeNewVS( + ph, + agent, + pbVS, + services, + prevVS, + vsMap, + reuseReport, + ) + if err != nil { + freeServices() + return err + } + + reuseReport.Ipv4VsMatcherReused = prevPh != nil && oldIPv4VsMatches && noNewIPv4Vs + reuseReport.Ipv6VsMatcherReused = prevPh != nil && oldIPv6VsMatches && noNewIPv6Vs + + ph.Vs_count = uint32(len(services)) + relptr.SetSlice(&ph.Vs, services) + + return nil +} + +func (ph *PacketHandler) setState(stateConfig *balancerpb.StateConfig, sessionTable *SessionTable) { + ph.Wlc_power = uint32(*stateConfig.Wlc.Power) + ph.Wlc_max_weight = uint32(*stateConfig.Wlc.MaxWeight) + ph.Refresh_period_ms = uint32(stateConfig.RefreshPeriod.AsDuration().Milliseconds()) + ph.Session_table_max_load_factor = float32(*stateConfig.SessionTableMaxLoadFactor) + relptr.Set(&ph.Session_table, sessionTable) +} + +func (ph *PacketHandler) resizeSessionTable( + sessionTable *SessionTable, + newSize int, + now time.Time, +) error { + if newSize <= sessionTable.capacity() { + return nil + } + return sessionTable.resize(newSize, now) +} + +// free releases all fields owned by the packet handler: per-VS resources, +// the VS array, decap address arrays, and compiled handler-level filters. +// +// Safe to call on a partially-initialized handler because every sub-slice +// starts as nil (zero-initialized) and FreeSlice / the C helpers are no-ops +// on zero/nil values. +func (ph *PacketHandler) free(agent *Agent) { + yanetAgent := agent.AsYanetAgent() + + // Free per-VS resources and the VS array itself. + vsSlice := relptr.Slice(&ph.Vs, ph.Vs_count) + for i := range vsSlice { + vsSlice[i].free(agent) + } + yanet.FreeSlice(yanetAgent, vsSlice) + + // Free decap address arrays. + decapV4 := relptr.Slice(&ph.Decap_v4, ph.Decap_v4_count) + yanet.FreeSlice(yanetAgent, decapV4) + decapV6 := relptr.Slice(&ph.Decap_v6, ph.Decap_v6_count) + yanet.FreeSlice(yanetAgent, decapV6) + + // Free compiled handler-level filters. + ph.freeDecapFilters() + ph.freeVsMatchers() +} diff --git a/modules/balancer/controlplane/handler/handler.c b/modules/balancer/controlplane/handler/handler.c deleted file mode 100644 index 1f70fd9e3..000000000 --- a/modules/balancer/controlplane/handler/handler.c +++ /dev/null @@ -1,572 +0,0 @@ -#include "handler.h" -#include "api/balancer.h" -#include "api/session.h" -#include "api/vs.h" -#include "common/memory.h" -#include "common/memory_address.h" - -#include "lib/controlplane/agent/agent.h" -#include "lib/controlplane/config/cp_module.h" -#include "lib/controlplane/diag/diag.h" -#include "lib/dataplane/config/zone.h" - -#include "modules/balancer/dataplane/active_sessions.h" - -#include -#include -#include -#include - -#include "api/handler.h" -#include "counters/counters.h" -#include "init.h" -#include "real.h" -#include "rules.h" -#include "services.h" -#include "state/state.h" -#include "vs.h" - -//////////////////////////////////////////////////////////////////////////////// - -static int -prepare_vs_configs( - size_t **initial_vs_idx, - size_t *ipv4_count, - size_t *ipv6_count, - struct packet_handler_config *config -) { - *initial_vs_idx = malloc(config->vs_count * sizeof(size_t)); - for (size_t idx = 0; idx < config->vs_count; ++idx) { - (*initial_vs_idx)[idx] = idx; - } - - if (validate_and_reorder_vs_configs( - *initial_vs_idx, - config->vs_count, - config->vs, - ipv4_count, - ipv6_count - ) != 0) { - PUSH_ERROR("invalid service config"); - free(*initial_vs_idx); - return -1; - } - - return 0; -} - -static int -register_and_prepare_all_vs( - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct packet_handler_config *config, - struct vs *virtual_services, - size_t *initial_vs_idx, - size_t ipv4_count, - size_t ipv6_count, - struct balancer_update_info *update_info, - int *reuse_ipv4_filter, - int *reuse_ipv6_filter -) { - // Register and prepare IPv4 services - if (register_and_prepare_vs( - handler, - prev_handler, - IPPROTO_IP, - ipv4_count, - config->vs, - initial_vs_idx, - virtual_services, - update_info, - reuse_ipv4_filter - ) != 0) { - PUSH_ERROR("prepare IPv4 services"); - return -1; - } - - // Register and prepare IPv6 services - if (register_and_prepare_vs( - handler, - prev_handler, - IPPROTO_IPV6, - ipv6_count, - config->vs + ipv4_count, - initial_vs_idx + ipv4_count, - virtual_services + ipv4_count, - update_info, - reuse_ipv6_filter - ) != 0) { - PUSH_ERROR("prepare IPv6 services"); - return -1; - } - - return 0; -} - -static int -init_all_packet_handler_vs( - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct memory_context *mctx, - struct packet_handler_config *config, - struct counter_registry *registry, - struct real *reals, - size_t *initial_vs_idx, - size_t ipv4_count, - struct balancer_update_info *update_info -) { - size_t reals_counter = 0; - - // Initialize IPv4 packet handler VS - if (init_packet_handler_vs( - handler, - IPPROTO_IP, - mctx, - config->vs, - registry, - prev_handler, - reals, - &reals_counter, - update_info, - initial_vs_idx - ) != 0) { - PUSH_ERROR("initialize IPv4 services"); - return -1; - } - - // Initialize IPv6 packet handler VS - if (init_packet_handler_vs( - handler, - IPPROTO_IPV6, - mctx, - config->vs + ipv4_count, - registry, - prev_handler, - reals, - &reals_counter, - update_info, - initial_vs_idx + ipv4_count - ) != 0) { - PUSH_ERROR("initialize IPv6 services"); - return -1; - } - - return 0; -} - -static int -init_all_vs_filters_and_announce( - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct memory_context *mctx, - struct packet_handler_config *config, - size_t *initial_vs_idx, - size_t ipv4_count, - int reuse_ipv4_filter, - int reuse_ipv6_filter -) { - // Initialize IPv4 VS filter - if (init_vs_filter( - &handler->vs_ipv4, - get_packet_handler_vs(prev_handler, IPPROTO_IP), - config->vs, - reuse_ipv4_filter, - mctx, - initial_vs_idx, - IPPROTO_IP - ) != 0) { - PUSH_ERROR("initialize IPv4 VS matcher"); - return -1; - } - - // Initialize IPv6 VS filter - if (init_vs_filter( - &handler->vs_ipv6, - get_packet_handler_vs(prev_handler, IPPROTO_IPV6), - config->vs + ipv4_count, - reuse_ipv6_filter, - mctx, - initial_vs_idx + ipv4_count, - IPPROTO_IPV6 - ) != 0) { - PUSH_ERROR("initialize IPv6 VS matcher"); - return -1; - } - - // Initialize IPv4 announce - if (init_announce(&handler->vs_ipv4, mctx, config->vs, IPPROTO_IP) != - 0) { - PUSH_ERROR("initialize IPv4 announce"); - return -1; - } - - // Initialize IPv6 announce - if (init_announce( - &handler->vs_ipv6, - mctx, - config->vs + ipv4_count, - IPPROTO_IPV6 - ) != 0) { - PUSH_ERROR("initialize IPv6 announce"); - return -1; - } - - return 0; -} - -static int -init_vs_and_reals( - struct packet_handler *handler, - struct memory_context *mctx, - struct packet_handler_config *config, - struct counter_registry *registry, - struct packet_handler *prev_handler, - struct balancer_update_info *update_info, - size_t workers -) { - size_t *initial_vs_idx = NULL; - size_t ipv4_count = 0; - size_t ipv6_count = 0; - - // Prepare and validate VS configs - if (prepare_vs_configs( - &initial_vs_idx, &ipv4_count, &ipv6_count, config - ) != 0) { - return -1; - } - - // Collect VS identifiers for registry initialization - struct vs_identifier *vs_identifiers = - malloc(sizeof(struct vs_identifier) * config->vs_count); - if (vs_identifiers == NULL && config->vs_count > 0) { - NEW_ERROR("failed to allocate memory for VS identifiers"); - goto free_initial_vs_idx_on_error; - } - for (size_t i = 0; i < config->vs_count; ++i) { - vs_identifiers[i] = config->vs[i].identifier; - } - - // Initialize VS registry - if (vs_registry_init( - &handler->vs_registry, - mctx, - vs_identifiers, - config->vs_count, - prev_handler ? &prev_handler->vs_registry : NULL - ) != 0) { - NEW_ERROR("failed to initialize VS registry"); - free(vs_identifiers); - goto free_initial_vs_idx_on_error; - } - free(vs_identifiers); - - // Initialize reals - if (init_reals( - handler, - prev_handler, - mctx, - config, - registry, - initial_vs_idx, - workers - ) != 0) { - PUSH_ERROR("init reals"); - goto free_vs_registry_on_error; - } - - struct real *reals = ADDR_OF(&handler->reals); - - // Allocate virtual services array - handler->vs_count = config->vs_count; - struct vs *virtual_services = - memory_balloc(mctx, sizeof(struct vs) * config->vs_count); - if (virtual_services == NULL && config->vs_count > 0) { - NEW_ERROR("no memory"); - goto free_vs_registry_on_error; - } - SET_OFFSET_OF(&handler->vs, virtual_services); - - // Register and prepare all VS (both IPv4 and IPv6) - int reuse_ipv4_filter = 0; - int reuse_ipv6_filter = 0; - if (register_and_prepare_all_vs( - handler, - prev_handler, - config, - virtual_services, - initial_vs_idx, - ipv4_count, - ipv6_count, - update_info, - &reuse_ipv4_filter, - &reuse_ipv6_filter - ) != 0) { - goto free_virtual_services_on_error; - } - - // Initialize all packet handler VS - if (init_all_packet_handler_vs( - handler, - prev_handler, - mctx, - config, - registry, - reals, - initial_vs_idx, - ipv4_count, - update_info - ) != 0) { - goto free_virtual_services_on_error; - } - - // Initialize all VS filters and announce - if (init_all_vs_filters_and_announce( - handler, - prev_handler, - mctx, - config, - initial_vs_idx, - ipv4_count, - reuse_ipv4_filter, - reuse_ipv6_filter - ) != 0) { - goto free_virtual_services_on_error; - } - - // Setup VS index mapping - if (setup_vs_index(handler, virtual_services, initial_vs_idx, mctx) != - 0) { - PUSH_ERROR("failed to setup VS index"); - goto free_virtual_services_on_error; - } - - free(initial_vs_idx); - return 0; - -free_virtual_services_on_error: - memory_bfree( - mctx, virtual_services, sizeof(struct vs) * config->vs_count - ); - -free_vs_registry_on_error: - vs_registry_free(&handler->vs_registry); - -free_initial_vs_idx_on_error: - free(initial_vs_idx); - return -1; -} - -#define MAX_TIMEOUT ACTIVE_SESSIONS_TRACKER_MAX_TIMEOUT - -static bool -validate_sessions_timeouts(struct sessions_timeouts *timeouts) { - return (timeouts->tcp <= MAX_TIMEOUT && timeouts->udp <= MAX_TIMEOUT && - timeouts->def <= MAX_TIMEOUT && - timeouts->tcp_fin <= MAX_TIMEOUT && - timeouts->tcp_syn <= MAX_TIMEOUT && - timeouts->tcp_syn_ack <= MAX_TIMEOUT); -} - -struct packet_handler * -packet_handler_setup( - struct agent *agent, - const char *name, - struct packet_handler_config *config, - struct balancer_state *state, - struct packet_handler *prev_handler, - struct balancer_update_info *update_info -) { - if (!validate_sessions_timeouts(&config->sessions_timeouts)) { - NEW_ERROR( - "sessions timeouts are too large (max is %d)", - MAX_TIMEOUT - ); - return NULL; - } - - if (update_info != NULL && config->vs_count > 0) { - update_info->vs_acl_reused = - calloc(config->vs_count, sizeof(struct vs_identifier)); - } - - struct memory_context *mctx = &agent->memory_context; - struct packet_handler *handler = - memory_balloc(mctx, sizeof(struct packet_handler)); - if (handler == NULL) { - NEW_ERROR("failed to allocate packet handler"); - return NULL; - } - memset(handler, 0, sizeof(struct packet_handler)); - SET_OFFSET_OF(&handler->state, state); - - memcpy(&handler->sessions_timeouts, - &config->sessions_timeouts, - sizeof(struct sessions_timeouts)); - - if (cp_module_init(&handler->cp_module, agent, "balancer", name) != 0) { - PUSH_ERROR("failed to initialize controlplane module"); - goto free_handler; - } - - struct counter_registry *counter_registry = - &handler->cp_module.counter_registry; - - if (init_counters(handler, counter_registry) != 0) { - PUSH_ERROR("failed to setup balancer counters"); - goto free_handler; - } - - if (init_sources(handler, mctx, config) != 0) { - PUSH_ERROR("failed to setup source addresses"); - goto free_handler; - } - - if (init_decaps(handler, mctx, config) != 0) { - PUSH_ERROR("failed to setup decap addresses"); - goto free_handler; - } - - size_t workers = ADDR_OF(&agent->dp_config)->worker_count; - if (init_vs_and_reals( - handler, - mctx, - config, - counter_registry, - prev_handler, - update_info, - workers - ) != 0) { - PUSH_ERROR("virtual services"); - goto free_decap; - } - - struct cp_module *cp_module = &handler->cp_module; - if (agent_update_modules(agent, 1, &cp_module) != 0) { - PUSH_ERROR("failed to update controlplane modules"); - goto free_vs; - } - - return handler; - -free_vs: - memory_bfree( - mctx, - ADDR_OF(&handler->vs), - sizeof(struct vs) * handler->vs_count - ); - map_free(&handler->vs_index); - -free_decap: - lpm_free(&handler->decap_ipv4); - lpm_free(&handler->decap_ipv6); - -free_handler: - memory_bfree(mctx, handler, sizeof(struct packet_handler)); - - return NULL; -} - -int -packet_handler_real_idx( - struct packet_handler *handler, - struct real_identifier *real, - struct real_ph_index *real_ph_index -) { - // Look up the real's stable index in the registry - ssize_t stable_idx; - if ((stable_idx = reals_registry_lookup(&handler->reals_registry, real) - ) == -1) { - return -1; - } - - // Look up the config index from the stable index - size_t config_idx; - if (map_find(&handler->reals_index, stable_idx, &config_idx) != 0) { - return -1; - } - - // Get the real and find its VS - struct real *reals = ADDR_OF(&handler->reals); - struct real *r = &reals[config_idx]; - - // Look up VS stable index - ssize_t vs_stable_idx; - if ((vs_stable_idx = vs_registry_lookup( - &handler->vs_registry, &r->identifier.vs_identifier - )) == -1) { - return -1; - } - - // Look up VS config index - size_t vs_config_idx; - if (map_find(&handler->vs_index, vs_stable_idx, &vs_config_idx) != 0) { - return -1; - } - - real_ph_index->vs_idx = vs_config_idx; - - struct vs *vss = ADDR_OF(&handler->vs); - struct vs *vs = &vss[vs_config_idx]; - - real_ph_index->real_idx = config_idx - vs->first_real_idx; - - return 0; -} - -void -packet_handler_free(struct packet_handler *handler) { - if (handler == NULL) { - return; - } - - struct agent *agent = ADDR_OF(&handler->cp_module.agent); - struct memory_context *mctx = &agent->memory_context; - - // Free VS filters (if not reused) - free_filter_ipv4(&handler->vs_ipv4, mctx); - free_filter_ipv6(&handler->vs_ipv6, mctx); - - // Free announce LPMs - lpm_free(&handler->vs_ipv4.announce); - lpm_free(&handler->vs_ipv6.announce); - - // Free VS index maps - map_free(&handler->vs_ipv4.index); - map_free(&handler->vs_ipv6.index); - - // Free each VS's resources - struct vs *vss = ADDR_OF(&handler->vs); - for (size_t i = 0; i < handler->vs_count; ++i) { - vs_free(&vss[i], mctx); - } - - // Free VS array - memory_bfree(mctx, vss, sizeof(struct vs) * handler->vs_count); - - // Free VS index map - map_free(&handler->vs_index); - - // Free VS registry - vs_registry_free(&handler->vs_registry); - - // Free reals array - size_t workers = ADDR_OF(&handler->state)->workers; - struct real *reals = ADDR_OF(&handler->reals); - for (size_t i = 0; i < handler->reals_count; ++i) { - real_free(&reals[i], workers, mctx); - } - memory_bfree(mctx, reals, sizeof(struct real) * handler->reals_count); - - // Free reals index map - map_free(&handler->reals_index); - - // Free reals registry - reals_registry_free(&handler->reals_registry); - - // Free decap LPMs - lpm_free(&handler->decap_ipv4); - lpm_free(&handler->decap_ipv6); - - // Free the handler itself - memory_bfree(mctx, handler, sizeof(struct packet_handler)); -} diff --git a/modules/balancer/controlplane/handler/handler.h b/modules/balancer/controlplane/handler/handler.h deleted file mode 100644 index 7ce0d7a36..000000000 --- a/modules/balancer/controlplane/handler/handler.h +++ /dev/null @@ -1,236 +0,0 @@ -#pragma once - -#include "api/handler.h" -#include "handler/real.h" -#include "lib/controlplane/config/cp_module.h" - -#include "api/real.h" -#include "api/session.h" - -#include "filter/filter.h" - -#include "common/lpm.h" -#include "map.h" -#include "registry.h" - -//////////////////////////////////////////////////////////////////////////////// - -struct balancer_state; -struct packet_handler_config; -struct balancer_update_info; - -/** - * Sentinel value indicating an invalid or non-existent index. - * Used in vs_index and reals_index arrays to mark unmapped entries. - */ -#define INDEX_INVALID ((uint32_t)-1) - -/** - * Per-protocol (IPv4/IPv6) virtual service container. - * - * This structure holds the fast-path lookup structures for a specific IP - * protocol version. It uses relative pointers (see common/memory_address.h) - * for all pointer fields. - * - * Memory Layout - Connection to packet_handler: - * The `vs` field points into the parent packet_handler's vs array, specifically - * to the subset of virtual services for this protocol (IPv4 or IPv6). - * - * Filter Reuse Optimization: - * When updating configuration, if all VS identifiers match between old and new - * configs (same count, same services), the filter can be reused via - * EQUATE_OFFSET. The filter_reused flag prevents double-free during cleanup. - */ -struct packet_handler_vs { - // Fast-path filter for matching packets to virtual services - // Uses destination IP, port, and protocol to find VS index - struct filter *filter; - - uint8_t __padding[56]; - - // Set to 1 when filter is reused from previous packet_handler_vs - // Prevents double-free during configuration updates - // Set to 0 when a new filter is built - int filter_reused; - - // LPM tree for VS address announcement/routing - struct lpm announce; - - // Number of virtual services for this protocol - size_t vs_count; - - // Points to the subset of packet_handler->vs array for this protocol - // (relative pointer to IPv4 or IPv6 virtual services) - struct vs *vs; - - // maps stable vs index to the config index - struct map index; -}; - -/** - * Packet handler instance. - * - * Owns fast-path lookup structures (filters/LPM), VS/real views and counters - * bound to a balancer_state. Used by the control-plane to program dataplane. - */ -struct packet_handler { - // Control-plane module interface for dataplane communication - // Manages counter registry and module lifecycle - struct cp_module cp_module; - - // relative pointer to the balancer state, - // corresponding to this handler - struct balancer_state *state; - - // timeouts of sessions with different types - struct sessions_timeouts sessions_timeouts; - - // Total number of virtual services (IPv4 + IPv6) - size_t vs_count; - - // Array of all virtual services (relative pointer) - // First ipv4_count entries are IPv4, remaining are IPv6 - struct vs *vs; - - // maps stable vs index to the config index - struct map vs_index; - - // registry of all virtual services - vs_registry_t vs_registry; - - // registry of all reals - reals_registry_t reals_registry; - - // Per-protocol virtual service containers - // vs_ipv4.vs points to IPv4 subset of the vs array - // vs_ipv6.vs points to IPv6 subset of the vs array - struct packet_handler_vs vs_ipv4; - struct packet_handler_vs vs_ipv6; - - // reals - size_t reals_count; - struct real *reals; - - // maps stable real index to the config index - struct map reals_index; - - // counter indices - struct { - // common counter - uint64_t common; - - // icmp v4 counter - uint64_t icmp_v4; - - // icmp v6 counter - uint64_t icmp_v6; - - // l4 (tcp and udp) counter - uint64_t l4; - } counter; - - // if packet destination id is from decap list, - // then we make decap - struct lpm decap_ipv4; - struct lpm decap_ipv6; - - // source address of the balancer - struct net4_addr source_ipv4; - struct net6_addr source_ipv6; -}; - -/** - * Setup packet handler and update control-plane modules. - * - * Creates/configures a handler bound to the provided balancer_state. - * Returns pointer on success, or NULL on error. - * - * Diagnostics: errors are recorded and retrievable via - * balancer_take_error_msg() on the balancer associated with this handler. - * - * @param agent Agent that will own the handler. - * @param name Handler name. - * @param config Packet handler configuration. - * @param state Balancer state to bind to. - * @param prev_handler Previous handler for filter reuse (may be NULL). - * @param update_info Output structure filled with update information. - * May be NULL if caller doesn't need this information. - * @return Pointer to new handler on success, NULL on error. - */ -struct packet_handler * -packet_handler_setup( - struct agent *agent, - const char *name, - struct packet_handler_config *config, - struct balancer_state *state, - struct packet_handler *prev_handler, - struct balancer_update_info *update_info -); - -/** - * Apply updates to reals visible by this handler. - * - * @param handler Handler instance. - * @param count Number of updates. - * @param updates Array of real_update entries. - * @return 0 on success, -1 on error. - */ -int -packet_handler_update_reals( - struct packet_handler *handler, - size_t count, - struct real_update *updates -); - -struct packet_handler_ref; -struct balancer_stats; - -/** - * Fill statistics for the packet handler. - * - * Collects counter data from the handler and populates the stats structure. - * - * @param handler Packet handler instance - * @param stats Output structure to fill with statistics - * @param ref Reference to packet handler for counter access - * @return 0 on success, -1 on error - */ -int -packet_handler_fill_stats( - struct packet_handler *handler, - struct balancer_stats *stats, - struct packet_handler_ref *ref -); - -/** - * Get packet handler indices for a real server. - * - * Resolves a real identifier to its indices within the packet handler's - * internal arrays (VS index and real index within that VS). - * - * @param handler Packet handler instance - * @param real Real server identifier to look up - * @param idx Output structure filled with VS and real indices - * @return 0 on success, -1 if real not found - */ -int -packet_handler_real_idx( - struct packet_handler *handler, - struct real_identifier *real, - struct real_ph_index *idx -); - -/** - * Free resources held by a packet handler. - * - * Releases all memory and structures owned by the handler, including: - * - VS and real registries - * - Filters (unless reused by another handler) - * - LPM trees for announce and decap - * - Index maps - * - The handler structure itself - * - * @param handler Packet handler instance to free - */ -void -packet_handler_free(struct packet_handler *handler); diff --git a/modules/balancer/controlplane/handler/info.c b/modules/balancer/controlplane/handler/info.c deleted file mode 100644 index f6c20acc5..000000000 --- a/modules/balancer/controlplane/handler/info.c +++ /dev/null @@ -1,300 +0,0 @@ -#include "info.h" - -#include "api/balancer.h" -#include "api/real.h" -#include "api/vs.h" -#include "common/memory_address.h" -#include "handler/handler.h" -#include "modules/balancer/dataplane/active_sessions.h" -#include "real.h" -#include "state/session.h" -#include "state/state.h" -#include "vs.h" - -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -struct fill_sessions_info_ctx { - struct named_session_info *sessions; - struct packet_handler *handler; - struct balancer_state *state; - size_t size; - size_t capacity; -}; - -static int -fill_sessions_callback( - struct session_id *id, struct session_state *state, void *userdata -) { - struct fill_sessions_info_ctx *ctx = userdata; - if (ctx->size == ctx->capacity) { - ctx->capacity = - ctx->capacity == 0 ? (1 << 16) : ctx->capacity * 2; - struct named_session_info *new_info = - realloc(ctx->sessions, - ctx->capacity * - sizeof(struct named_session_info)); - if (new_info == NULL) { - return -1; - } - ctx->sessions = new_info; - } - - // Check if real is present in current packet handler config - size_t real_config_idx; - if (map_find( - &ctx->handler->reals_index, state->real_id, &real_config_idx - ) != 0) { - // real not present in current packet handler config - return 0; - } - - struct named_session_info *session_info = &ctx->sessions[ctx->size++]; - - // Get real from handler's reals array - struct real *reals = ADDR_OF(&ctx->handler->reals); - struct real *real = &reals[real_config_idx]; - - // fill identifier - session_info->identifier.real = real->identifier; - session_info->identifier.client_ip = id->client_ip; - session_info->identifier.client_port = ntohs(id->client_port); - - // fill info - session_info->info = (struct session_info - ){.create_timestamp = state->create_timestamp, - .last_packet_timestamp = state->last_packet_timestamp, - .timeout = state->timeout}; - - return 0; -} - -size_t -packet_handler_sessions_info( - struct packet_handler *handler, - struct named_session_info **sessions, - uint32_t now -) { - struct balancer_state *state = ADDR_OF(&handler->state); - struct fill_sessions_info_ctx ctx = { - .state = state, - .sessions = NULL, - .size = 0, - .capacity = 0, - .handler = handler, - }; - int res = session_table_iter( - &state->session_table, now, fill_sessions_callback, &ctx - ); - assert(res == 0); - *sessions = ctx.sessions; - return ctx.size; -} - -//////////////////////////////////////////////////////////////////////////////// - -static void -init_real_info(struct named_real_info *info, struct real *real) { - info->real = real->identifier.relative; - info->active_sessions = 0; - info->last_packet_timestamp = 0; -} - -static void -init_real_infos( - struct named_real_info *real_infos, struct packet_handler *handler -) { - struct real *reals = ADDR_OF(&handler->reals); - for (size_t i = 0; i < handler->reals_count; i++) { - init_real_info(&real_infos[i], &reals[i]); - } -} - -static void -init_vs_info( - struct named_vs_info *info, - struct vs *vs, - struct named_real_info *real_infos -) { - info->identifier = vs->identifier; - info->reals_count = vs->reals_count; - info->reals = real_infos; - info->active_sessions = 0; - info->last_packet_timestamp = 0; -} - -static void -init_vs_infos( - struct named_vs_info *vs_infos, - struct named_real_info *real_infos, - struct packet_handler *handler -) { - struct vs *vss = ADDR_OF(&handler->vs); - size_t reals_counter = 0; - for (size_t i = 0; i < handler->vs_count; i++) { - struct vs *vs = &vss[i]; - struct named_vs_info *info = &vs_infos[i]; - init_vs_info(info, vs, real_infos + reals_counter); - reals_counter += vs->reals_count; - } -} - -struct fill_balancer_info_ctx { - struct balancer_info *info; - struct named_real_info *reals; - struct packet_handler *handler; - struct balancer_state *state; - uint32_t now; -}; - -static void -check_max(uint32_t *value, uint32_t c) { - if (*value < c) { - *value = c; - } -} - -int -fill_balancer_info_callback( - struct session_id *id, struct session_state *state, void *userdata -) { - struct fill_balancer_info_ctx *ctx = userdata; - - // Check if real is present in current packet handler config - size_t real_config_idx; - if (map_find( - &ctx->handler->reals_index, state->real_id, &real_config_idx - ) != 0) { - // real not present in packet handler config - return 0; - } - - const int is_session_active = - state->last_packet_timestamp + state->timeout > ctx->now; - - // Check if VS is present in current packet handler config - size_t vs_config_idx; - int res = map_find(&ctx->handler->vs_index, id->vs_id, &vs_config_idx); - assert(res == 0); - - ctx->info->active_sessions += is_session_active; - check_max( - &ctx->info->last_packet_timestamp, state->last_packet_timestamp - ); - - struct named_real_info *real_info = &ctx->reals[real_config_idx]; - real_info->active_sessions += is_session_active; - check_max( - &real_info->last_packet_timestamp, state->last_packet_timestamp - ); - - struct named_vs_info *vs_info = &ctx->info->vs[vs_config_idx]; - vs_info->active_sessions += is_session_active; - check_max( - &vs_info->last_packet_timestamp, state->last_packet_timestamp - ); - - return 0; -} - -void -packet_handler_balancer_info( - struct packet_handler *handler, struct balancer_info *info, uint32_t now -) { - struct balancer_state *state = ADDR_OF(&handler->state); - - struct named_real_info *reals = - malloc(sizeof(struct named_real_info) * handler->reals_count); - init_real_infos(reals, handler); - - struct named_vs_info *vs = - malloc(sizeof(struct named_vs_info) * handler->vs_count); - init_vs_infos(vs, reals, handler); - - // Initialize info structure - info->vs_count = handler->vs_count; - info->vs = vs; - info->active_sessions = 0; - info->last_packet_timestamp = 0; - - struct fill_balancer_info_ctx ctx = { - .handler = handler, - .state = state, - .info = info, - .reals = reals, - .now = now - }; - - int res = session_table_iter( - &state->session_table, 0, fill_balancer_info_callback, &ctx - ); - assert(res == 0); -} - -void -packet_handler_active_sessions( - struct packet_handler *handler, struct balancer_info *info -) { - struct balancer_state *state = ADDR_OF(&handler->state); - const size_t workers = state->workers; - - struct named_real_info *reals_info = - malloc(sizeof(struct named_real_info) * handler->reals_count); - init_real_infos(reals_info, handler); - - struct named_vs_info *vs_info = - malloc(sizeof(struct named_vs_info) * handler->vs_count); - init_vs_infos(vs_info, reals_info, handler); - - // Initialize info structure - info->vs_count = handler->vs_count; - info->vs = vs_info; - info->active_sessions = 0; - info->last_packet_timestamp = 0; - - // fill reals - struct real *real = ADDR_OF(&handler->reals); - for (size_t real_idx = 0; real_idx < handler->reals_count; ++real_idx) { - struct active_sessions_tracker_shard *tracker_shards = - ADDR_OF(&real[real_idx].tracker_shards); - for (size_t worker_idx = 0; worker_idx < workers; - ++worker_idx) { - struct active_sessions_tracker_shard *shard = - &tracker_shards[worker_idx]; - struct named_real_info *cur_real_info = - &reals_info[real_idx]; - cur_real_info->active_sessions += shard->count; - if (cur_real_info->last_packet_timestamp < - shard->last_packet_timestamp) { - cur_real_info->last_packet_timestamp = - shard->last_packet_timestamp; - } - } - } - - // fill virtual services - for (size_t vs_idx = 0; vs_idx < handler->vs_count; ++vs_idx) { - struct named_vs_info *cur_vs_info = &vs_info[vs_idx]; - for (size_t real_idx = 0; real_idx < cur_vs_info->reals_count; - ++real_idx) { - struct named_real_info *cur_real_info = - &cur_vs_info->reals[real_idx]; - cur_vs_info->active_sessions += - cur_real_info->active_sessions; - if (cur_vs_info->last_packet_timestamp < - cur_real_info->last_packet_timestamp) { - cur_vs_info->last_packet_timestamp = - cur_real_info->last_packet_timestamp; - } - } - info->active_sessions += cur_vs_info->active_sessions; - if (info->last_packet_timestamp < - cur_vs_info->last_packet_timestamp) { - info->last_packet_timestamp = - cur_vs_info->last_packet_timestamp; - } - } -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/info.h b/modules/balancer/controlplane/handler/info.h deleted file mode 100644 index fdb23fb0c..000000000 --- a/modules/balancer/controlplane/handler/info.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "api/session.h" -#include - -struct packet_handler; -struct balancer_info; - -/** - * Retrieve information about all active sessions in the packet handler. - * - * Iterates through the session table and collects information about each - * session that belongs to a real server present in the current packet handler - * configuration. Allocates memory for the sessions array which must be freed - * by the caller. - * - * @param handler Packet handler instance - * @param sessions Output pointer to allocated array of session info - * (caller must free this memory) - * @param now Current timestamp for session timeout calculations - * @return Number of sessions in the output array - */ -size_t -packet_handler_sessions_info( - struct packet_handler *handler, - struct named_session_info **sessions, - uint32_t now -); - -/** - * Collect comprehensive balancer information including VS and real statistics. - * - * Aggregates session counts and timestamps for all virtual services and their - * real servers. Allocates memory for the info structure's internal arrays - * which must be freed by the caller. - * - * @param handler Packet handler instance - * @param info Output structure to populate with balancer information - * (caller must free info->vs and nested arrays) - * @param now Current timestamp for active session calculations - */ -void -packet_handler_balancer_info( - struct packet_handler *handler, struct balancer_info *info, uint32_t now -); - -void -packet_handler_active_sessions( - struct packet_handler *handler, struct balancer_info *info -); diff --git a/modules/balancer/controlplane/handler/init.c b/modules/balancer/controlplane/handler/init.c deleted file mode 100644 index de3d0292c..000000000 --- a/modules/balancer/controlplane/handler/init.c +++ /dev/null @@ -1,266 +0,0 @@ -#include "init.h" - -#include "common/lpm.h" -#include "common/memory.h" -#include "common/memory_address.h" -#include "handler.h" -#include "lib/controlplane/diag/diag.h" -#include "map.h" -#include "real.h" -#include "registry.h" - -#include -#include - -extern uint64_t -register_common_counter(struct counter_registry *registry); - -extern uint64_t -register_icmp_v4_counter(struct counter_registry *registry); - -extern uint64_t -register_icmp_v6_counter(struct counter_registry *registry); - -extern uint64_t -register_l4_counter(struct counter_registry *registry); - -int -init_counters( - struct packet_handler *handler, struct counter_registry *registry -) { - if ((handler->counter.common = register_common_counter(registry)) == - (uint64_t)-1) { - PUSH_ERROR("failed to register common counter"); - return -1; - } - if ((handler->counter.icmp_v4 = register_icmp_v4_counter(registry)) == - (uint64_t)-1) { - PUSH_ERROR("failed to register ICMPv4 counter"); - return -1; - } - if ((handler->counter.icmp_v6 = register_icmp_v6_counter(registry)) == - (uint64_t)-1) { - PUSH_ERROR("failed to register ICMPv6 counter"); - return -1; - } - if ((handler->counter.l4 = register_l4_counter(registry)) == - (uint64_t)-1) { - PUSH_ERROR("failed to register L4 counter"); - return -1; - } - - return 0; -} - -int -init_sources( - struct packet_handler *handler, - struct memory_context *mctx, - struct packet_handler_config *config -) { - (void)mctx; - memcpy(&handler->source_ipv4, - &config->source_v4, - sizeof(struct net4_addr)); - memcpy(&handler->source_ipv6, - &config->source_v6, - sizeof(struct net6_addr)); - return 0; -} - -int -init_decaps( - struct packet_handler *handler, - struct memory_context *mctx, - struct packet_handler_config *config -) { - // init ipv4 decap addresses - if (lpm_init(&handler->decap_ipv4, mctx) != 0) { - NEW_ERROR( - "failed to allocate container for decap IPv4 addresses" - ); - return -1; - } - for (size_t i = 0; i < config->decap_v4_count; i++) { - struct net4_addr *addr = &config->decap_v4[i]; - if (lpm4_insert( - &handler->decap_ipv4, addr->bytes, addr->bytes, 1 - ) != 0) { - lpm_free(&handler->decap_ipv4); - NEW_ERROR( - "failed to insert decap IPv4 address at index " - "%zu", - i - ); - return -1; - } - } - - // init ipv6 decap addresses - if (lpm_init(&handler->decap_ipv6, mctx) != 0) { - NEW_ERROR( - "failed to allocate container for decap IPv6 addresses" - ); - return -1; - } - for (size_t i = 0; i < config->decap_v6_count; i++) { - struct net6_addr *addr = &config->decap_v6[i]; - if (lpm8_insert( - &handler->decap_ipv6, addr->bytes, addr->bytes, 1 - ) != 0) { - lpm_free(&handler->decap_ipv4); - lpm_free(&handler->decap_ipv6); - NEW_ERROR( - "failed to insert decap IPv6 address at index " - "%zu", - i - ); - return -1; - } - } - - return 0; -} - -int -setup_reals_index( - struct packet_handler *handler, - struct memory_context *mctx, - struct real *reals, - size_t reals_count -) { - // Build key-value pairs for the map (stable_idx -> config_idx) - struct key_value *entries = - malloc(sizeof(struct key_value) * reals_count); - if (entries == NULL && reals_count > 0) { - NEW_ERROR("failed to allocate memory for reals index entries"); - return -1; - } - - for (size_t i = 0; i < reals_count; ++i) { - entries[i].key = reals[i].stable_idx; - entries[i].value = i; - } - - // Initialize the map - if (map_init(&handler->reals_index, mctx, entries, reals_count) != 0) { - NEW_ERROR("failed to initialize reals index map"); - free(entries); - return -1; - } - - free(entries); - return 0; -} - -int -init_reals( - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct memory_context *mctx, - struct packet_handler_config *config, - struct counter_registry *registry, - size_t *initial_vs_idx, - size_t workers -) { - // Count total reals - size_t real_count = 0; - for (size_t i = 0; i < config->vs_count; ++i) { - real_count += config->vs[i].config.real_count; - } - handler->reals_count = real_count; - - // Collect all real identifiers for registry initialization - struct real_identifier *real_identifiers = - malloc(sizeof(struct real_identifier) * real_count); - if (real_identifiers == NULL && real_count > 0) { - NEW_ERROR("failed to allocate memory for real identifiers"); - return -1; - } - - size_t real_idx = 0; - for (size_t i = 0; i < config->vs_count; ++i) { - struct named_vs_config *vs_config = &config->vs[i]; - for (size_t j = 0; j < vs_config->config.real_count; ++j) { - struct named_real_config *real_config = - &vs_config->config.reals[j]; - real_identifiers[real_idx].vs_identifier = - vs_config->identifier; - real_identifiers[real_idx].relative = real_config->real; - ++real_idx; - } - } - - // Initialize reals registry - if (reals_registry_init( - &handler->reals_registry, - mctx, - real_identifiers, - real_count, - prev_handler ? &prev_handler->reals_registry : NULL - ) != 0) { - NEW_ERROR("failed to initialize reals registry"); - free(real_identifiers); - return -1; - } - free(real_identifiers); - - // Allocate reals array - struct real *reals = - memory_balloc(mctx, sizeof(struct real) * real_count); - if (reals == NULL && real_count > 0) { - NEW_ERROR("no memory for reals array"); - reals_registry_free(&handler->reals_registry); - return -1; - } - if (real_count > 0) { - memset(reals, 0, sizeof(struct real) * real_count); - } - SET_OFFSET_OF(&handler->reals, reals); - - // Initialize each real - size_t real_ph_idx = 0; - for (size_t i = 0; i < config->vs_count; ++i) { - struct named_vs_config *vs_config = &config->vs[i]; - for (size_t j = 0; j < vs_config->config.real_count; ++j) { - struct named_real_config *real_config = - &vs_config->config.reals[j]; - struct real *real = &reals[real_ph_idx]; - if (real_init( - real, - handler, - prev_handler, - &vs_config->identifier, - real_config, - registry, - workers, - mctx - ) != 0) { - PUSH_ERROR( - "service at index %zu: real at index " - "%zu", - initial_vs_idx[i], - j - ); - memory_bfree( - mctx, - reals, - sizeof(struct real) * real_count - ); - reals_registry_free(&handler->reals_registry); - return -1; - } - ++real_ph_idx; - } - } - - // Setup reals index map - if (setup_reals_index(handler, mctx, reals, real_count) != 0) { - PUSH_ERROR("failed to setup reals index"); - memory_bfree(mctx, reals, sizeof(struct real) * real_count); - reals_registry_free(&handler->reals_registry); - return -1; - } - - return 0; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/init.h b/modules/balancer/controlplane/handler/init.h deleted file mode 100644 index 9043c6cd6..000000000 --- a/modules/balancer/controlplane/handler/init.h +++ /dev/null @@ -1,82 +0,0 @@ -#pragma once - -#include -#include - -struct packet_handler; -struct counter_registry; -struct memory_context; -struct packet_handler_config; -struct balancer_state; - -struct real; - -/** - * Initialize packet handler counters. - * - * @param handler Packet handler instance - * @param registry Counter registry - * @return 0 on success, -1 on error - */ -int -init_counters( - struct packet_handler *handler, struct counter_registry *registry -); - -/** - * Initialize source addresses for the packet handler. - * - * @param handler Packet handler instance - * @param mctx Memory context - * @param config Packet handler configuration - * @return 0 on success, -1 on error - */ -int -init_sources( - struct packet_handler *handler, - struct memory_context *mctx, - struct packet_handler_config *config -); - -/** - * Initialize decap addresses for the packet handler. - * - * @param handler Packet handler instance - * @param mctx Memory context - * @param config Packet handler configuration - * @return 0 on success, -1 on error - */ -int -init_decaps( - struct packet_handler *handler, - struct memory_context *mctx, - struct packet_handler_config *config -); - -/** - * Setup reals index mapping. - * - * @param handler Packet handler instance - * @param mctx Memory context - * @param reals Array of reals - * @param reals_count Number of reals - * @return 0 on success, -1 on error - */ -int -setup_reals_index( - struct packet_handler *handler, - struct memory_context *mctx, - struct real *reals, - size_t reals_count -); - -int -init_reals( - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct memory_context *mctx, - struct packet_handler_config *config, - struct counter_registry *registry, - size_t *initial_vs_idx, - size_t workers -); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/inspect.c b/modules/balancer/controlplane/handler/inspect.c deleted file mode 100644 index 8f48944cb..000000000 --- a/modules/balancer/controlplane/handler/inspect.c +++ /dev/null @@ -1,70 +0,0 @@ -#include "inspect.h" -#include "api/inspect.h" -#include "api/stats.h" -#include "common/lpm.h" -#include "common/memory_address.h" -#include "compiler.h" -#include "map.h" -#include "registry.h" -#include "vs.h" -#include - -void -packet_handler_vs_inspect( - struct packet_handler_vs *handler_vs, - struct packet_handler_vs_inspect *inspect, - size_t workers -) { - inspect->matcher_usage = - filter_memory_usage(ADDR_OF(&handler_vs->filter)); - inspect->summary_vs_usage = 0; - inspect->vs_count = handler_vs->vs_count; - inspect->vs_inspects = - malloc(sizeof(struct named_vs_inspect) * handler_vs->vs_count); - struct vs *vs = ADDR_OF(&handler_vs->vs); - for (size_t vs_idx = 0; vs_idx < handler_vs->vs_count; ++vs_idx) { - struct named_vs_inspect *vs_inspect = - inspect->vs_inspects + vs_idx; - vs_inspect->identifier = vs[vs_idx].identifier; - vs_fill_inspect(vs + vs_idx, &vs_inspect->inspect, workers); - inspect->summary_vs_usage += vs_inspect->inspect.total_usage; - } - inspect->announce_usage = lpm_memory_usage(&handler_vs->announce); - inspect->index_usage = map_memory_usage(&handler_vs->index); - inspect->total_usage = inspect->matcher_usage + - inspect->summary_vs_usage + - inspect->announce_usage + inspect->index_usage; -} - -void -packet_handler_inspect( - struct packet_handler *handler, - struct packet_handler_inspect *inspect, - size_t workers -) { - packet_handler_vs_inspect( - &handler->vs_ipv4, &inspect->vs_ipv4_inspect, workers - ); - packet_handler_vs_inspect( - &handler->vs_ipv6, &inspect->vs_ipv6_inspect, workers - ); - inspect->summary_vs_usage = inspect->vs_ipv4_inspect.summary_vs_usage + - inspect->vs_ipv6_inspect.summary_vs_usage + - sizeof(struct vs) * handler->vs_count; - inspect->reals_index_usage = - map_memory_usage(&handler->reals_index) + - reals_registry_usage(&handler->reals_registry); - inspect->vs_index_usage = map_memory_usage(&handler->vs_index) + - vs_registry_usage(&handler->vs_registry); - inspect->counters_usage = (sizeof(struct balancer_icmp_stats) * 2 + - sizeof(struct balancer_common_stats) + - sizeof(struct balancer_l4_stats)) * - workers; - inspect->decap_usage = lpm_memory_usage(&handler->decap_ipv4) + - lpm_memory_usage(&handler->decap_ipv6); - inspect->total_usage = inspect->vs_ipv4_inspect.total_usage + - inspect->vs_ipv6_inspect.total_usage + - inspect->reals_index_usage + - inspect->vs_index_usage + - inspect->counters_usage + inspect->decap_usage; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/inspect.h b/modules/balancer/controlplane/handler/inspect.h deleted file mode 100644 index 423a05eb7..000000000 --- a/modules/balancer/controlplane/handler/inspect.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include "api/inspect.h" -#include "handler.h" - -//////////////////////////////////////////////////////////////////////////////// - -// TODO: docs -void -packet_handler_inspect( - struct packet_handler *handler, - struct packet_handler_inspect *inspect, - size_t workers -); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/map.c b/modules/balancer/controlplane/handler/map.c deleted file mode 100644 index d30f83da3..000000000 --- a/modules/balancer/controlplane/handler/map.c +++ /dev/null @@ -1,68 +0,0 @@ -#include "map.h" -#include "common/big_array.h" -#include "common/btree/u64.h" -#include -#include - -static int -cmp_kv(const void *left, const void *right) { - const struct key_value *left_entry = (const struct key_value *)left; - const struct key_value *right_entry = (const struct key_value *)right; - return left_entry->key - right_entry->key; -}; - -static inline void -set(struct big_array *array, size_t idx, size_t value) { - memcpy(big_array_get(array, idx * sizeof(size_t)), - &value, - sizeof(size_t)); -} - -int -map_init( - struct map *map, - struct memory_context *mctx, - struct key_value *entries, - size_t entry_count -) { - if (big_array_init(&map->keys, sizeof(size_t) * entry_count, mctx) != - 0) { - return -1; - } - if (big_array_init(&map->values, sizeof(size_t) * entry_count, mctx) != - 0) { - big_array_free(&map->keys); - return -1; - } - qsort((void *)entries, entry_count, sizeof(struct key_value), cmp_kv); - uint64_t *keys = malloc(sizeof(uint64_t) * entry_count); - for (size_t i = 0; i < entry_count; ++i) { - keys[i] = entries[i].key; - } - if (btree_u64_init(&map->btree, keys, entry_count, mctx) != 0) { - big_array_free(&map->keys); - big_array_free(&map->values); - free(keys); - return -1; - } - for (size_t i = 0; i < entry_count; ++i) { - set(&map->keys, i, entries[i].key); - set(&map->values, i, entries[i].value); - } - free(keys); - return 0; -} - -void -map_free(struct map *map) { - big_array_free(&map->keys); - big_array_free(&map->values); - btree_u64_free(&map->btree); -} - -size_t -map_memory_usage(struct map *map) { - return big_array_memory_usage(&map->keys) + - big_array_memory_usage(&map->values) + - btree_u64_memory_usage(&map->btree); -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/map.h b/modules/balancer/controlplane/handler/map.h deleted file mode 100644 index ea41467ae..000000000 --- a/modules/balancer/controlplane/handler/map.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include "common/big_array.h" -#include "common/btree/u64.h" -#include "common/likely.h" - -struct map { - struct btree_u64 btree; - struct big_array keys; - struct big_array values; -}; - -struct key_value { - size_t key; - size_t value; -}; - -int -map_init( - struct map *map, - struct memory_context *mctx, - struct key_value *entries, - size_t entry_count -); - -void -map_free(struct map *map); - -static inline int -map_find(struct map *map, size_t key, size_t *value) { - size_t lb = btree_u64_lower_bound(&map->btree, key); - if (unlikely(lb == map->btree.n)) { - return -1; - } - size_t lb_key = - *(size_t *)big_array_get(&map->keys, lb * sizeof(size_t)); - if (lb_key == key) { - *value = *(size_t *)big_array_get( - &map->values, lb * sizeof(size_t) - ); - return 0; - } else { - return -1; - } -} - -size_t -map_memory_usage(struct map *map); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/meson.build b/modules/balancer/controlplane/handler/meson.build deleted file mode 100644 index b9ae6a061..000000000 --- a/modules/balancer/controlplane/handler/meson.build +++ /dev/null @@ -1,38 +0,0 @@ -dependencies = [ - lib_common_dep, - lib_agent_cp_dep, - lib_balancer_state_dep, - lib_filter_compiler_dep, -] - -includes = include_directories('.', '../') - -sources = files( - 'handler.c', - 'real.c', - 'selector.c', - 'update.c', - 'vs.c', - 'stats.c', - 'info.c', - 'rules.c', - 'init.c', - 'services.c', - 'inspect.c', - 'registry.c', - 'map.c', -) - -lib_balancer_packet_handler = static_library( - 'balancer_packet_handler', - sources, - c_args: yanet_c_args, - link_args: yanet_link_args, - dependencies: dependencies, - include_directories: includes, - install: false, -) - -lib_balancer_packet_handler_dep = declare_dependency( - link_with: [lib_balancer_packet_handler], -) diff --git a/modules/balancer/controlplane/handler/real.c b/modules/balancer/controlplane/handler/real.c deleted file mode 100644 index fa952285b..000000000 --- a/modules/balancer/controlplane/handler/real.c +++ /dev/null @@ -1,165 +0,0 @@ -#include "real.h" - -#include "api/counter.h" -#include "api/real.h" - -#include "api/vs.h" -#include "common/memory.h" -#include "common/memory_address.h" -#include "common/network.h" -#include "handler.h" -#include "lib/controlplane/diag/diag.h" -#include "lib/counters/counters.h" -#include "modules/balancer/dataplane/active_sessions.h" -#include "registry.h" - -#include -#include -#include - -#include "modules/balancer/controlplane/state/active_sessions.h" - -static int -real_init_active_sessions_tracker( - struct real *real, - struct real *prev_real, - size_t workers, - struct memory_context *mctx -) { - real->tracker_reused = false; - - if (prev_real != NULL) { - EQUATE_OFFSET( - &real->tracker_shards, &prev_real->tracker_shards - ); - prev_real->tracker_reused = true; - return 0; - } - - struct active_sessions_tracker_shard *tracker_shards = - active_sessions_tracker_create(mctx, workers, 0); - if (tracker_shards == NULL) { - NEW_ERROR("no memory"); - return -1; - } - SET_OFFSET_OF(&real->tracker_shards, tracker_shards); - - return 0; -} - -int -real_init( - struct real *real, - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct vs_identifier *vs, - struct named_real_config *named_config, - struct counter_registry *registry, - size_t workers, - struct memory_context *mctx -) { - // Build full real identifier - struct real_identifier identifier; - memset(&identifier, 0, sizeof(identifier)); - identifier.vs_identifier = *vs; - identifier.relative = named_config->real; - - // Look up stable index in handler's registry (must already be - // initialized) - ssize_t stable_idx = - reals_registry_lookup(&handler->reals_registry, &identifier); - if (stable_idx < 0) { - NEW_ERROR("real not found in registry"); - return -1; - } - - // Register counter using stable index - char name[60]; - sprintf(name, "rl_%zu", (size_t)stable_idx); - uint64_t counter_id = counter_registry_register( - registry, name, sizeof(struct real_stats) / sizeof(uint64_t) - ); - if (counter_id == (size_t)-1) { - NEW_ERROR("failed to register counter"); - return -1; - } - - // Determine enabled and weight - preserve from previous config if - // exists - bool enabled = false; // default - uint16_t weight = named_config->config.weight; - - struct real *prev_real = NULL; - if (prev_handler) { - // Check if this real existed in previous handler - size_t prev_config_idx; - if (map_find( - &prev_handler->reals_index, - stable_idx, - &prev_config_idx - ) == 0) { - // Real existed - preserve its mutable state - struct real *prev_reals = ADDR_OF(&prev_handler->reals); - prev_real = &prev_reals[prev_config_idx]; - } - } - - if (prev_real) { - enabled = prev_real->enabled; - weight = prev_real->weight; - } - - // Mask the source address based on IP protocol version - struct net src = named_config->config.src; - if (named_config->real.ip_proto == IPPROTO_IP) { // IPv4 - uint8_t *src_addr = src.v4.addr; - const uint8_t *src_mask = src.v4.mask; - for (size_t i = 0; i < NET4_LEN; i++) { - src_addr[i] &= src_mask[i]; - } - } else { // IPv6 - uint8_t *src_addr = src.v6.addr; - const uint8_t *src_mask = src.v6.mask; - for (size_t i = 0; i < NET6_LEN; i++) { - src_addr[i] &= src_mask[i]; - } - } - - // Initialize the real structure - struct real r = { - .identifier = identifier, // Full identifier (includes VS) - .stable_idx = (size_t)stable_idx, - .counter_id = counter_id, - .src = src, - .enabled = enabled, - .weight = weight - }; - memcpy(real, &r, sizeof(struct real)); - - // Initialize active sessions tracker - if (real_init_active_sessions_tracker(real, prev_real, workers, mctx) != - 0) { - return -1; - } - assert(real->tracker_shards != NULL); - - return 0; -} - -ssize_t -counter_to_real_registry_idx(struct counter_handle *counter) { - if (strncmp(counter->name, "rl_", 3) == 0) { - return atoi(counter->name + 3); - } else { - return -1; - } -} - -void -real_free(struct real *real, size_t workers, struct memory_context *mctx) { - if (!real->tracker_reused) { - active_sessions_tracker_destroy( - ADDR_OF(&real->tracker_shards), workers, mctx - ); - } -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/real.h b/modules/balancer/controlplane/handler/real.h deleted file mode 100644 index 21e86b544..000000000 --- a/modules/balancer/controlplane/handler/real.h +++ /dev/null @@ -1,81 +0,0 @@ -#pragma once - -#include -#include - -#include "api/real.h" -#include "api/vs.h" -#include "common/network.h" -#include "counters/counters.h" -#include "modules/balancer/dataplane/active_sessions.h" - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Handler-side view of a real backend server. - * - * Represents a single backend server within a virtual service. All fields are - * const to ensure immutability after initialization. The structure is designed - * to be compact and cache-friendly for fast-path lookups. - * - * Memory Layout: - * This structure is part of the packet_handler's reals array. Each VS points - * to a contiguous subset of this array via vs->reals (relative pointer). - */ -struct real { - // Stable index in the handler's real registry - // Preserved across config updates for the same real - const size_t stable_idx; - - // Counter ID for tracking statistics for this real server - // Registered as "rl_" in the counter registry - const uint64_t counter_id; - - // Relative pointer to the - // array of per-worker sessions tracker - struct active_sessions_tracker_shard *tracker_shards; - - // Scheduler weight [0..MAX_REAL_WEIGHT] - uint16_t weight; - - // Source network used for encapsulation/routing to this backend - // The address is masked by the mask during initialization - const struct net src; - - // Full identifier of the real server - const struct real_identifier identifier; - - // Mutable state - preserved from previous config or set from config - // Whether traffic is allowed to this real. False by default - bool enabled; - - bool tracker_reused; -}; - -//////////////////////////////////////////////////////////////////////////////// - -struct balancer_state; -struct counter_registry; -struct packet_handler; - -int -real_init( - struct real *real, - struct packet_handler *handler, - struct packet_handler *prev_handler, - struct vs_identifier *vs, - struct named_real_config *named_config, - struct counter_registry *registry, - size_t workers, - struct memory_context *mctx -); - -void -real_free(struct real *real, size_t workers, struct memory_context *mctx); - -/** - * Resolve real registry index from a counter handle. - * Returns index on success, or -1 on error. - */ -ssize_t -counter_to_real_registry_idx(struct counter_handle *counter); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/registry.c b/modules/balancer/controlplane/handler/registry.c deleted file mode 100644 index 9ff427a74..000000000 --- a/modules/balancer/controlplane/handler/registry.c +++ /dev/null @@ -1,265 +0,0 @@ -#include "registry.h" -#include "common/big_array.h" -#include "common/memory.h" -#include "controlplane/diag/diag.h" -#include -#include -#include -#include -#include - -static int -vs_identifier_cmp(const void *left, const void *right); - -static int -real_identifier_cmp(const void *left, const void *right); - -int -service_registry_init( - struct service_registry *registry, - struct memory_context *mctx, - void *elems, - size_t elem_size, - size_t elems_count, - registry_cmp cmp, - struct service_registry *prev -) { - if (prev != NULL && prev->elem_size != elem_size) { - NEW_ERROR("internal error: incompatible registry: " - "prev->elem_size != elem_size"); - return -1; - } - - registry->elem_size = elem_size; - registry->elems_count = elems_count; - - if (big_array_init(®istry->elems, elem_size * elems_count, mctx) != - 0) { - NEW_ERROR("no memory"); - return -1; - } - - if (big_array_init( - ®istry->indices, elems_count * sizeof(size_t), mctx - ) != 0) { - big_array_free(®istry->elems); - NEW_ERROR("no memory"); - return -1; - } - - registry->next_stable_index = - prev != NULL ? prev->next_stable_index : 0; - - qsort(elems, elems_count, elem_size, cmp); - - uint8_t *elems_bytes = (uint8_t *)elems; - for (size_t idx = 0; idx < elems_count; ++idx) { - void *elem_ptr = elems_bytes + idx * elem_size; - - ssize_t stable_idx = - prev != NULL - ? service_registry_lookup(prev, elem_ptr, cmp) - : -1; - if (stable_idx == -1) { - stable_idx = (ssize_t)registry->next_stable_index++; - } - - size_t stable_idx_u = (size_t)stable_idx; - memcpy(big_array_get(®istry->elems, idx * elem_size), - elem_ptr, - elem_size); - memcpy(big_array_get(®istry->indices, idx * sizeof(size_t)), - &stable_idx_u, - sizeof(size_t)); - } - - return 0; -} - -void -service_registry_free(struct service_registry *registry) { - big_array_free(®istry->elems); - big_array_free(®istry->indices); - memset(registry, 0, sizeof(*registry)); -} - -size_t -service_registry_usage(struct service_registry *registry) { - return big_array_memory_usage(®istry->elems) + - big_array_memory_usage(®istry->indices); -} - -ssize_t -service_registry_lookup( - struct service_registry *registry, const void *elem, registry_cmp cmp -) { - ssize_t left = -1; - ssize_t right = (ssize_t)registry->elems_count; - while (left + 1 < right) { - ssize_t idx = (left + right) / 2; - - void *elem_at_idx = big_array_get( - ®istry->elems, (size_t)idx * registry->elem_size - ); - int cmp_res = cmp(elem, elem_at_idx); - if (cmp_res == 0) { - size_t stable_idx_u = *(size_t *)big_array_get( - ®istry->indices, (size_t)idx * sizeof(size_t) - ); - return (ssize_t)stable_idx_u; - } else if (cmp_res < 0) { - right = idx; - } else { - left = idx; - } - } - return -1; -} - -static int -cmp_u8(uint8_t a, uint8_t b) { - return (a > b) - (a < b); -} - -static int -cmp_u16(uint16_t a, uint16_t b) { - return (a > b) - (a < b); -} - -static int -vs_identifier_cmp(const void *left, const void *right) { - const struct vs_identifier *a = (const struct vs_identifier *)left; - const struct vs_identifier *b = (const struct vs_identifier *)right; - - int c = cmp_u8(a->ip_proto, b->ip_proto); - if (c != 0) { - return c; - } - - if (a->ip_proto == IPPROTO_IP) { - c = memcmp(a->addr.v4.bytes, b->addr.v4.bytes, NET4_LEN); - } else if (a->ip_proto == IPPROTO_IPV6) { - c = memcmp(a->addr.v6.bytes, b->addr.v6.bytes, NET6_LEN); - } else { - /* Fallback for unexpected protocol values */ - c = memcmp(&a->addr, &b->addr, sizeof(a->addr)); - } - if (c != 0) { - return c; - } - - c = cmp_u16(a->port, b->port); - if (c != 0) { - return c; - } - - return cmp_u8(a->transport_proto, b->transport_proto); -} - -static int -real_identifier_cmp(const void *left, const void *right) { - const struct real_identifier *a = (const struct real_identifier *)left; - const struct real_identifier *b = (const struct real_identifier *)right; - - int c = vs_identifier_cmp(&a->vs_identifier, &b->vs_identifier); - if (c != 0) { - return c; - } - - c = cmp_u8(a->relative.ip_proto, b->relative.ip_proto); - if (c != 0) { - return c; - } - - if (a->relative.ip_proto == IPPROTO_IP) { - c = - memcmp(a->relative.addr.v4.bytes, - b->relative.addr.v4.bytes, - NET4_LEN); - } else if (a->relative.ip_proto == IPPROTO_IPV6) { - c = - memcmp(a->relative.addr.v6.bytes, - b->relative.addr.v6.bytes, - NET6_LEN); - } else { - c = - memcmp(&a->relative.addr, - &b->relative.addr, - sizeof(a->relative.addr)); - } - if (c != 0) { - return c; - } - - return cmp_u16(a->relative.port, b->relative.port); -} - -int -vs_registry_init( - vs_registry_t *registry, - struct memory_context *mctx, - struct vs_identifier *vs, - size_t vs_count, - vs_registry_t *prev -) { - return service_registry_init( - registry, - mctx, - vs, - sizeof(struct vs_identifier), - vs_count, - vs_identifier_cmp, - prev - ); -} - -void -vs_registry_free(vs_registry_t *registry) { - service_registry_free(registry); -} - -ssize_t -vs_registry_lookup(vs_registry_t *registry, const struct vs_identifier *vs) { - return service_registry_lookup(registry, vs, vs_identifier_cmp); -} - -size_t -vs_registry_usage(vs_registry_t *registry) { - return service_registry_usage(registry); -} - -int -reals_registry_init( - reals_registry_t *registry, - struct memory_context *mctx, - struct real_identifier *reals, - size_t reals_count, - reals_registry_t *prev -) { - return service_registry_init( - registry, - mctx, - reals, - sizeof(struct real_identifier), - reals_count, - real_identifier_cmp, - prev - ); -} - -void -reals_registry_free(reals_registry_t *registry) { - service_registry_free(registry); -} - -ssize_t -reals_registry_lookup( - reals_registry_t *registry, const struct real_identifier *real -) { - return service_registry_lookup(registry, real, real_identifier_cmp); -} - -size_t -reals_registry_usage(reals_registry_t *registry) { - return service_registry_usage(registry); -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/registry.h b/modules/balancer/controlplane/handler/registry.h deleted file mode 100644 index 459c4e7ed..000000000 --- a/modules/balancer/controlplane/handler/registry.h +++ /dev/null @@ -1,84 +0,0 @@ -#pragma once - -#include "api/real.h" -#include "api/vs.h" -#include "common/big_array.h" -#include "common/memory.h" -#include - -typedef int (*registry_cmp)(const void *left, const void *right); - -struct service_registry { - size_t next_stable_index; - - size_t elems_count; - - size_t elem_size; - struct big_array elems; // array of elems - - struct big_array indices; // array of size_t -}; - -int -service_registry_init( - struct service_registry *registry, - struct memory_context *mctx, - void *elems, - size_t elem_size, - size_t elems_count, - registry_cmp cmp, - struct service_registry *prev -); - -void -service_registry_free(struct service_registry *registry); - -ssize_t -service_registry_lookup( - struct service_registry *registry, const void *elem, registry_cmp cmp -); - -size_t -service_registry_usage(struct service_registry *registry); - -typedef struct service_registry vs_registry_t; - -int -vs_registry_init( - vs_registry_t *registry, - struct memory_context *mctx, - struct vs_identifier *vs, - size_t vs_count, - vs_registry_t *prev -); - -void -vs_registry_free(vs_registry_t *registry); - -ssize_t -vs_registry_lookup(vs_registry_t *registry, const struct vs_identifier *vs); - -size_t -vs_registry_usage(vs_registry_t *registry); - -typedef struct service_registry reals_registry_t; - -int -reals_registry_init( - reals_registry_t *registry, - struct memory_context *mctx, - struct real_identifier *reals, - size_t reals_count, - reals_registry_t *prev -); - -void -reals_registry_free(reals_registry_t *registry); - -ssize_t -reals_registry_lookup( - reals_registry_t *registry, const struct real_identifier *real -); - -size_t -reals_registry_usage(reals_registry_t *registry); diff --git a/modules/balancer/controlplane/handler/rules.c b/modules/balancer/controlplane/handler/rules.c deleted file mode 100644 index 99da2f9c8..000000000 --- a/modules/balancer/controlplane/handler/rules.c +++ /dev/null @@ -1,228 +0,0 @@ -#include "rules.h" - -#include "api/vs.h" -#include "common/memory.h" -#include "filter/compiler.h" -#include "filter/rule.h" -#include "handler.h" -#include "lib/controlplane/diag/diag.h" - -#include -#include -#include - -// Declare filter compiler signatures for VS lookup tables -FILTER_COMPILER_DECLARE( - vs_lookup_ipv4, net4_fast_dst, port_fast_dst, proto_range_fast -); -FILTER_COMPILER_DECLARE( - vs_lookup_ipv6, net6_fast_dst, port_fast_dst, proto_range_fast -); - -static int -init_transport_rule( - struct filter_rule *rule, struct named_vs_config *vs_config -) { - rule->transport.dst_count = 1; - rule->transport.dsts = calloc(1, sizeof(struct filter_port_range)); - - // For PureL3 mode, match all ports (0-65535) - // Otherwise, match only the specific port - if (vs_config->config.flags & VS_PURE_L3_FLAG) { - rule->transport.dsts[0].from = 0; - rule->transport.dsts[0].to = 65535; - } else { - rule->transport.dsts[0].from = vs_config->identifier.port; - rule->transport.dsts[0].to = vs_config->identifier.port; - } - - if (vs_config->identifier.transport_proto != IPPROTO_TCP && - vs_config->identifier.transport_proto != IPPROTO_UDP) { - NEW_ERROR( - "unsupported transport protocol %d: only TCP (%d) and " - "UDP (%d) are supported", - vs_config->identifier.transport_proto, - IPPROTO_TCP, - IPPROTO_UDP - ); - return -1; - } - - rule->transport.proto_count = 1; - rule->transport.protos = calloc(1, sizeof(struct filter_proto_range)); - rule->transport.protos[0].from = - vs_config->identifier.transport_proto * 256; - rule->transport.protos[0].to = - vs_config->identifier.transport_proto * 256 + 255; - return 0; -} - -static void -init_dst_rule(struct filter_rule *rule, struct named_vs_config *vs_config) { - memset(&rule->net6, 0, sizeof(rule->net6)); - memset(&rule->net4, 0, sizeof(rule->net4)); - if (vs_config->identifier.ip_proto == IPPROTO_IPV6) { - rule->net6.dst_count = 1; - rule->net6.dsts = malloc(sizeof(struct net6)); - struct net6 *n = &rule->net6.dsts[0]; - memcpy(n->addr, vs_config->identifier.addr.v6.bytes, NET6_LEN); - memset(n->mask, 0xFF, NET6_LEN); - } else { // ipv4 - rule->net4.dst_count = 1; - rule->net4.dsts = malloc(sizeof(struct net4)); - struct net4 *n = &rule->net4.dsts[0]; - memcpy(n->addr, vs_config->identifier.addr.v4.bytes, NET4_LEN); - memset(n->mask, 0xFF, NET4_LEN); - } -} - -int -make_filter_rules( - struct filter_rule **result_rules, - size_t count, - struct named_vs_config *vs_configs, - size_t *vs_initial_idx -) { - *result_rules = NULL; - struct filter_rule *rules = malloc(sizeof(struct filter_rule) * count); - for (size_t rule_idx = 0; rule_idx < count; ++rule_idx) { - const size_t vs_idx = rule_idx; - init_dst_rule(rules + rule_idx, vs_configs + vs_idx); - if (init_transport_rule( - rules + rule_idx, vs_configs + vs_idx - ) != 0) { - free(rules); - PUSH_ERROR( - "service at index %zu", vs_initial_idx[vs_idx] - ); - return -1; - } - } - *result_rules = rules; - return 0; -} - -void -free_rules(size_t rules_count, struct filter_rule *rules) { - for (size_t rule_idx = 0; rule_idx < rules_count; ++rule_idx) { - struct filter_rule *rule = rules + rule_idx; - free(rule->net4.dsts); - free(rule->net6.dsts); - free(rule->transport.dsts); - free(rule->transport.protos); - } - free(rules); -} - -int -build_filter( - struct packet_handler_vs *packet_handler_vs, - size_t *initial_vs_idx, - struct named_vs_config *vs_configs, - struct memory_context *mctx, - int proto -) { - struct filter *filter = memory_balloc(mctx, sizeof(struct filter)); - if (filter == NULL) { - NEW_ERROR("no memory"); - return -1; - } - - struct filter_rule *rules = NULL; - const size_t vs_count = packet_handler_vs->vs_count; - if (make_filter_rules(&rules, vs_count, vs_configs, initial_vs_idx) != - 0) { - PUSH_ERROR("invalid VS configs"); - memory_bfree(mctx, filter, sizeof(struct filter)); - return -1; - } - - const size_t rules_count = vs_count; - const struct filter_rule **rule_ptrs = (const struct filter_rule **) - malloc(sizeof(struct filter_rule *) * rules_count); - if (rule_ptrs == NULL) { - memory_bfree(mctx, filter, sizeof(struct filter)); - free_rules(rules_count, rules); - NEW_ERROR("no memory"); - return -1; - } - for (size_t idx = 0; idx < rules_count; ++idx) - rule_ptrs[idx] = rules + idx; - - if (proto == IPPROTO_IPV6) { - if (filter_init( - filter, vs_lookup_ipv6, rule_ptrs, rules_count, mctx - ) != 0) { - memory_bfree(mctx, filter, sizeof(struct filter)); - free_rules(rules_count, rules); - free(rule_ptrs); - NEW_ERROR("no memory"); - return -1; - } - } else { - if (filter_init( - filter, vs_lookup_ipv4, rule_ptrs, rules_count, mctx - ) != 0) { - memory_bfree(mctx, filter, sizeof(struct filter)); - free_rules(rules_count, rules); - free(rule_ptrs); - NEW_ERROR("no memory"); - return -1; - } - } - free(rule_ptrs); - - SET_OFFSET_OF(&packet_handler_vs->filter, filter); - - free_rules(rules_count, rules); - - return 0; -} - -void -free_filter_ipv4( - struct packet_handler_vs *packet_handler_vs, struct memory_context *mctx -) { - if (packet_handler_vs->filter_reused) { - return; - } - struct filter *filter = ADDR_OF(&packet_handler_vs->filter); - if (filter == NULL) { - return; - } - filter_free(filter, vs_lookup_ipv4); - memory_bfree(mctx, filter, sizeof(struct filter)); -} - -void -free_filter_ipv6( - struct packet_handler_vs *packet_handler_vs, struct memory_context *mctx -) { - if (packet_handler_vs->filter_reused) { - return; - } - struct filter *filter = ADDR_OF(&packet_handler_vs->filter); - if (filter == NULL) { - return; - } - filter_free(filter, vs_lookup_ipv6); - memory_bfree(mctx, filter, sizeof(struct filter)); -} - -uint64_t -rules_memory_usage(size_t rules_count, struct filter_rule *rules) { - uint64_t result = sizeof(struct filter_rule) * rules_count; - for (size_t rule_idx = 0; rule_idx < rules_count; ++rule_idx) { - struct filter_rule *rule = rules + rule_idx; - result += sizeof(struct net6) * - (rule->net6.dst_count + rule->net6.src_count); - result += sizeof(struct net4) * - (rule->net4.dst_count + rule->net4.src_count); - result += - sizeof(struct filter_port_range) * - (rule->transport.dst_count + rule->transport.src_count); - result += sizeof(struct filter_proto_range) * - rule->transport.proto_count; - } - return result; -} diff --git a/modules/balancer/controlplane/handler/rules.h b/modules/balancer/controlplane/handler/rules.h deleted file mode 100644 index 4b020c694..000000000 --- a/modules/balancer/controlplane/handler/rules.h +++ /dev/null @@ -1,86 +0,0 @@ -#pragma once - -#include -#include - -struct filter_rule; -struct named_vs_config; -struct packet_handler_vs; -struct memory_context; - -/** - * Create filter rules from VS configurations. - * - * @param result_rules Output pointer to allocated filter rules array - * @param count Number of VS configurations - * @param vs_configs Array of VS configurations - * @param vs_initial_idx Array of initial VS indices for error reporting - * @return 0 on success, -1 on error - */ -int -make_filter_rules( - struct filter_rule **result_rules, - size_t count, - struct named_vs_config *vs_configs, - size_t *vs_initial_idx -); - -/** - * Free filter rules and their associated resources. - * - * @param rules_count Number of rules to free - * @param rules Array of filter rules - */ -void -free_rules(size_t rules_count, struct filter_rule *rules); - -/** - * Build a filter from rules for a packet handler VS. - * - * @param packet_handler_vs Packet handler VS structure - * @param initial_vs_idx Array of initial VS indices for error reporting - * @param vs_configs Array of VS configurations - * @param mctx Memory context for allocations - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @return 0 on success, -1 on error - */ -int -build_filter( - struct packet_handler_vs *packet_handler_vs, - size_t *initial_vs_idx, - struct named_vs_config *vs_configs, - struct memory_context *mctx, - int proto -); - -/** - * Free IPv4 VS lookup filter resources. - * - * Releases filter resources for an IPv4 packet handler VS. - * Does nothing if the filter was reused from a previous handler. - * - * @param packet_handler_vs Packet handler VS structure - * @param mctx Memory context for deallocations - */ -void -free_filter_ipv4( - struct packet_handler_vs *packet_handler_vs, struct memory_context *mctx -); - -/** - * Free IPv6 VS lookup filter resources. - * - * Releases filter resources for an IPv6 packet handler VS. - * Does nothing if the filter was reused from a previous handler. - * - * @param packet_handler_vs Packet handler VS structure - * @param mctx Memory context for deallocations - */ -void -free_filter_ipv6( - struct packet_handler_vs *packet_handler_vs, struct memory_context *mctx -); - -// TODO: docs -uint64_t -rules_memory_usage(size_t rules_count, struct filter_rule *rules); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/selector.c b/modules/balancer/controlplane/handler/selector.c deleted file mode 100644 index c5c1f156a..000000000 --- a/modules/balancer/controlplane/handler/selector.c +++ /dev/null @@ -1,117 +0,0 @@ -#include "selector.h" - -#include "common/memory.h" -#include "common/memory_address.h" -#include "common/rng.h" - -#include "lib/controlplane/diag/diag.h" - -#include "real.h" -#include - -static int -ring_init( - struct ring *ring, - struct memory_context *mctx, - size_t reals_count, - const struct real *reals -) { - memset(ring, 0, sizeof(struct ring)); - ring->enabled_len = (reals_count + 7) / 8; - uint8_t *enabled = memory_balloc(mctx, ring->enabled_len); - if (enabled == NULL && ring->enabled_len > 0) { - NEW_ERROR("failed to allocate enabled bits"); - return -1; - } - if (ring->enabled_len > 0) { - memset(enabled, 0, ring->enabled_len); - } - size_t len = 0; - for (size_t i = 0; i < reals_count; ++i) { - const struct real *real = &reals[i]; - uint16_t weight = real->enabled ? real->weight : 0; - len += weight; - if (real->enabled) { - enabled[i / 8] |= 1 << (i % 8); - } - } - uint32_t *ids = memory_balloc(mctx, len * sizeof(uint32_t)); - if (ids == NULL && len > 0) { - memory_bfree(mctx, enabled, ring->enabled_len); - NEW_ERROR("failed to allocate weighted reals list"); - return -1; - } - size_t idx = 0; - for (size_t i = 0; i < reals_count; ++i) { - const struct real *real = &reals[i]; - uint16_t weight = real->enabled ? real->weight : 0; - for (size_t copy = 0; copy < weight; ++copy) { - ids[idx++] = i; - } - } - uint64_t rng = 0xdeadbeef; - for (size_t i = 1; i < len; ++i) { - // swap with random before me - size_t j = rng_next(&rng) % i; - uint32_t tmp = ids[i]; - ids[i] = ids[j]; - ids[j] = tmp; - } - SET_OFFSET_OF(&ring->ids, ids); - SET_OFFSET_OF(&ring->enabled, enabled); - ring->len = len; - return 0; -} - -static void -ring_free(struct ring *ring, struct memory_context *mctx) { - memory_bfree(mctx, ADDR_OF(&ring->ids), ring->len * sizeof(uint32_t)); - memory_bfree(mctx, ADDR_OF(&ring->enabled), ring->enabled_len); -} - -//////////////////////////////////////////////////////////////////////////////// - -int -selector_update( - struct real_selector *selector, - size_t reals_count, - const struct real *reals -) { - size_t cur_ring_id = selector->ring_id; - size_t new_ring_id = cur_ring_id ^ 1; - struct ring *new_ring = &selector->rings[new_ring_id]; - if (ring_init(new_ring, &selector->mctx, reals_count, reals) != 0) { - PUSH_ERROR("failed to init ring"); - return -1; - } - rcu_update(&selector->rcu, &selector->ring_id, new_ring_id); - ring_free(&selector->rings[cur_ring_id], &selector->mctx); - return 0; -} - -int -selector_init( - struct real_selector *selector, - struct memory_context *mctx, - enum vs_scheduler scheduler -) { - memory_context_init_from(&selector->mctx, mctx, "real_selector"); - rcu_init(&selector->rcu); - selector->use_rr = scheduler == round_robin ? 1 : 0; - selector->ring_id = 0; - if (ring_init(&selector->rings[0], &selector->mctx, 0, NULL) != 0) { - PUSH_ERROR("failed to init ring"); - return -1; - } - uint64_t rng = 0xdeadbeef; - for (size_t i = 0; i < MAX_WORKERS_NUM; ++i) { - selector->workers[i].rr_counter = rng_next(&rng); - } - return 0; -} - -void -selector_free(struct real_selector *selector) { - size_t cur_ring_id = selector->ring_id; - ring_free(&selector->rings[cur_ring_id], &selector->mctx); -} diff --git a/modules/balancer/controlplane/handler/selector.h b/modules/balancer/controlplane/handler/selector.h deleted file mode 100644 index 743452a97..000000000 --- a/modules/balancer/controlplane/handler/selector.h +++ /dev/null @@ -1,101 +0,0 @@ -#pragma once - -#include "common/memory.h" - -#include -#include - -#include "common/rcu.h" - -#include "api/vs.h" -#include "state/worker.h" - -#include "real.h" - -//////////////////////////////////////////////////////////////////////////////// - -#define SELECTOR_VALUE_INVALID ((uint32_t)-1) - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Compact ring of backend identifiers for selection. - */ -struct ring { - // Number of entries in ids - uint32_t len; - - // Relative pointer to per-backend identifiers (packet-handler indices) - uint32_t *ids; - - uint32_t enabled_len; - - // Maps local real index to its enabled state - uint8_t *enabled; -}; - -/** - * Per-worker selector state. - */ -struct selector_worker { - uint64_t rr_counter; // Round-robin position -} __attribute__((aligned(64))); - -/** - * Real backend selector. - * - * Maintains two rings for RCU-swapped updates and per-worker RR counters. - * Uses either round-robin or hash-based selection depending on VS scheduler. - */ -struct real_selector { - struct memory_context mctx; // Memory context for rings - rcu_t rcu; // RCU guard for ring swaps - struct selector_worker workers[MAX_WORKERS_NUM]; // Per-worker state - struct ring rings[2]; // Double-buffered rings - _Atomic size_t ring_id; // Active ring index - int use_rr; // Non-zero for RR, zero for hash -}; - -/** - * Initialize selector with desired scheduling mode. - * Returns 0 on success, -1 on error. - */ -int -selector_init( - struct real_selector *selector, - struct memory_context *mctx, - enum vs_scheduler scheduler -); - -/** - * Free resources held by the selector. - */ -void -selector_free(struct real_selector *selector); - -/** - * Rebuild selector rings from provided real views. - * Returns 0 on success, -1 on error. - */ -int -selector_update( - struct real_selector *selector, - size_t reals_count, - const struct real *reals -); - -static inline bool -selector_real_enabled(struct real_selector *selector, size_t local_real_idx) { - uint32_t current_ring_idx = - atomic_load_explicit(&selector->ring_id, memory_order_relaxed); - struct ring *current_ring = &selector->rings[current_ring_idx]; - uint8_t *enabled = ADDR_OF(¤t_ring->enabled); - return enabled[local_real_idx / 8] & (1 << (local_real_idx % 8)); -} - -static inline uint64_t -selector_memory_usage(struct real_selector *selector) { - struct memory_context *mctx = &selector->mctx; - assert(mctx->balloc_size >= mctx->bfree_size); - return mctx->balloc_size - mctx->bfree_size; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/services.c b/modules/balancer/controlplane/handler/services.c deleted file mode 100644 index c3afcab83..000000000 --- a/modules/balancer/controlplane/handler/services.c +++ /dev/null @@ -1,476 +0,0 @@ -#include "services.h" - -#include "api/balancer.h" -#include "api/vs.h" -#include "common/lpm.h" -#include "common/memory.h" -#include "common/memory_address.h" -#include "common/swap.h" -#include "handler.h" -#include "lib/controlplane/diag/diag.h" -#include "map.h" -#include "registry.h" -#include "rules.h" -#include "vs.h" - -#include -#include -#include -#include - -struct vs * -find_vs_in_packet_handler_vs( - struct packet_handler_vs *packet_handler_vs, struct vs *vs -) { - if (packet_handler_vs == NULL) { - return NULL; - } - - size_t vs_idx; - if (map_find(&packet_handler_vs->index, vs->stable_idx, &vs_idx) != 0) { - return NULL; - } - - struct vs *services = ADDR_OF(&packet_handler_vs->vs); - return &services[vs_idx]; -} - -struct packet_handler_vs * -get_packet_handler_vs(struct packet_handler *handler, int proto) { - return handler == NULL ? NULL - : (proto == IPPROTO_IP ? &handler->vs_ipv4 - : &handler->vs_ipv6); -} - -int -can_reuse_filter(int current_vs_count, int prev_vs_count, int match_count) { - // all virtual services are unique, it is validated on packet handler - // update - return current_vs_count == prev_vs_count && - current_vs_count == match_count; -} - -static int -validate_vs_config(struct named_vs_config *config) { - int proto = config->identifier.ip_proto; - if (proto != IPPROTO_IP && proto != IPPROTO_IPV6) { - NEW_ERROR( - "network protocol is invalid: got %d, but only IPv4 " - "(%d) and IPv6 (%d) are supported", - proto, - IPPROTO_IP, - IPPROTO_IPV6 - ); - return -1; - } - - if (config->identifier.transport_proto != IPPROTO_TCP && - config->identifier.transport_proto != IPPROTO_UDP) { - NEW_ERROR( - "transport protocol is invalid: got %d, but only TCP " - "(%d) and UDP (%d) are supported", - config->identifier.transport_proto, - IPPROTO_TCP, - IPPROTO_UDP - ); - return -1; - } - - // TODO: better validation - - return 0; -} - -static void -swap_vs_configs( - size_t *initial_vs_idx, - struct named_vs_config *configs, - size_t left_idx, - size_t right_idx -) { - SWAP(configs + left_idx, configs + right_idx); - SWAP(initial_vs_idx + left_idx, initial_vs_idx + right_idx); -} - -int -validate_and_reorder_vs_configs( - size_t *initial_vs_idx, - size_t count, - struct named_vs_config *configs, - size_t *ipv4_count, - size_t *ipv6_count -) { - // move ipv4 services first, and ipv6 then. - - ssize_t last_ipv6 = -1; - for (size_t idx = 0; idx < count; ++idx) { - struct named_vs_config *current = &configs[idx]; - - // validate service - if (validate_vs_config(current) != 0) { - PUSH_ERROR("at index %zu", idx); - return -1; - } - - int proto = current->identifier.ip_proto; - - if (proto == IPPROTO_IPV6) { - // IPv6 service - *ipv6_count += 1; - if (last_ipv6 == -1) { - last_ipv6 = idx; - } - continue; - } - - // IPv4 service - *ipv4_count += 1; - if (last_ipv6 == -1) { - continue; - } - - swap_vs_configs(initial_vs_idx, configs, idx, last_ipv6); - - last_ipv6 += 1; - } - - return 0; -} - -int -register_virtual_services( - struct packet_handler *handler, - size_t vs_count, - const size_t *initial_vs_idx, - struct named_vs_config *configs, - struct packet_handler *prev_handler, - size_t *match -) { - for (size_t vs_idx = 0; vs_idx < vs_count; ++vs_idx) { - struct named_vs_config *vs_config = &configs[vs_idx]; - - // Look up stable index in current registry - ssize_t stable_idx = vs_registry_lookup( - &handler->vs_registry, &vs_config->identifier - ); - if (stable_idx == -1) { - PUSH_ERROR( - "VS not found in registry at index %zu", - initial_vs_idx[vs_idx] - ); - return -1; - } - - // Check if this VS existed in previous config - if (prev_handler != NULL) { - size_t prev_config_idx; - if (map_find( - &prev_handler->vs_index, - stable_idx, - &prev_config_idx - ) == 0) { - *match += 1; - } - } - } - - return 0; -} - -int -register_and_prepare_vs( - struct packet_handler *handler, - struct packet_handler *prev_handler, - int proto, - size_t vs_count, - struct named_vs_config *vs_configs, - size_t *initial_vs_idx, - struct vs *virtual_services, - struct balancer_update_info *update_info, - int *reuse_filter -) { - // only IPv4 and IPv6 are supported - assert(proto == IPPROTO_IP || proto == IPPROTO_IPV6); - - // Check how many services match with previous config - size_t match = 0; - if (register_virtual_services( - handler, - vs_count, - initial_vs_idx, - vs_configs, - prev_handler, - &match - ) != 0) { - PUSH_ERROR("registration failed"); - return -1; - } - - // init some fields of the packet_handler_vs for this protocol: - // - vs_count - // - vs - struct packet_handler_vs *packet_handler_vs = - get_packet_handler_vs(handler, proto); - packet_handler_vs->vs_count = vs_count; - SET_OFFSET_OF(&packet_handler_vs->vs, virtual_services); - - // prev handler is optional - struct packet_handler_vs *prev_packet_handler_vs = - get_packet_handler_vs(prev_handler, proto); - - // check if VS filter for this protocol can be reused - *reuse_filter = prev_packet_handler_vs == NULL - ? 0 - : can_reuse_filter( - vs_count, - prev_packet_handler_vs->vs_count, - match - ); - if (update_info != NULL) { - *(proto == IPPROTO_IPV6 ? &update_info->vs_ipv6_matcher_reused - : &update_info->vs_ipv4_matcher_reused - ) = *reuse_filter; - } - - // to reuse filter for network protocol, the VS indices in - // packet_handler_vs MUST match with the corresponding indices in the - // previous config. this is because the VS matching mechanism - if (*reuse_filter) { - // permute VS configs according to indices in the previous - // config - for (size_t vs_idx = 0; vs_idx < vs_count; ++vs_idx) { - ssize_t stable_idx = vs_registry_lookup( - &handler->vs_registry, - &vs_configs[vs_idx].identifier - ); - assert(stable_idx != -1); - - size_t position; - int found = map_find( - &prev_packet_handler_vs->index, - stable_idx, - &position - ); - assert(found == 0); - - swap_vs_configs( - initial_vs_idx, vs_configs, vs_idx, position - ); - } - } - - return 0; -} - -int -init_packet_handler_vs( - struct packet_handler *handler, - int proto, - struct memory_context *mctx, - struct named_vs_config *vs_configs, - struct counter_registry *registry, - struct packet_handler *prev_handler, - struct real *reals, - size_t *reals_counter, - struct balancer_update_info *update_info, - size_t *initial_vs_idx -) { - // only IPv4 and IPv6 are supported - assert(proto == IPPROTO_IP || proto == IPPROTO_IPV6); - - // prev packet handler is optional - struct packet_handler_vs *prev_packet_handler_vs = - get_packet_handler_vs(prev_handler, proto); - - // find packet handler vs for this protocol - struct packet_handler_vs *packet_handler_vs = - get_packet_handler_vs(handler, proto); - size_t vs_count = packet_handler_vs->vs_count; - struct vs *virtual_services = ADDR_OF(&packet_handler_vs->vs); - - // Build key-value pairs for the index map (stable_idx -> config_idx) - struct key_value *entries = malloc(sizeof(struct key_value) * vs_count); - if (entries == NULL && vs_count > 0) { - NEW_ERROR("failed to allocate memory for VS index entries"); - return -1; - } - - // initialize virtual services - for (size_t vs_idx = 0; vs_idx < vs_count; ++vs_idx) { - struct vs *current_vs = virtual_services + vs_idx; - struct named_vs_config *current_vs_config = vs_configs + vs_idx; - - // set identifier - current_vs->identifier = current_vs_config->identifier; - - // Look up stable index from registry - ssize_t stable_idx = vs_registry_lookup( - &handler->vs_registry, ¤t_vs->identifier - ); - assert(stable_idx != -1); - current_vs->stable_idx = stable_idx; - - // Add to index map entries - entries[vs_idx].key = stable_idx; - entries[vs_idx].value = vs_idx; - - // try to find this virtual service in previous config, can be - // null - struct vs *prev_vs = find_vs_in_packet_handler_vs( - prev_packet_handler_vs, current_vs - ); - - size_t first_real_idx = *reals_counter; - if (vs_with_identifier_and_registry_idx_init( - current_vs, - prev_vs, - first_real_idx, - reals + first_real_idx, - current_vs_config, - registry, - mctx, - update_info - ) != 0) { - PUSH_ERROR( - "service at index %zu", initial_vs_idx[vs_idx] - ); - free(entries); - return -1; - } - - // increase reals counter - *reals_counter += current_vs->reals_count; - } - - // Initialize the index map - if (map_init(&packet_handler_vs->index, mctx, entries, vs_count) != 0) { - NEW_ERROR("failed to initialize VS index map"); - free(entries); - return -1; - } - - free(entries); - return 0; -} - -int -init_vs_filter( - struct packet_handler_vs *packet_handler_vs, - struct packet_handler_vs *prev_packet_handler_vs, - struct named_vs_config *vs_configs, - int reuse_filter, - struct memory_context *mctx, - size_t *initial_vs_idx, - int proto -) { - packet_handler_vs->filter_reused = 0; - if (reuse_filter) { - // just reuse filter from the current packet handler - EQUATE_OFFSET( - &packet_handler_vs->filter, - &prev_packet_handler_vs->filter - ); - prev_packet_handler_vs->filter_reused = 1; - } else { - if (build_filter( - packet_handler_vs, - initial_vs_idx, - vs_configs, - mctx, - proto - ) != 0) { - PUSH_ERROR("build failed"); - return -1; - } - } - return 0; -} - -int -init_announce( - struct packet_handler_vs *handler, - struct memory_context *mctx, - struct named_vs_config *vs_configs, - int proto -) { - struct lpm *lpm = &handler->announce; - if (lpm_init(lpm, mctx) != 0) { - NEW_ERROR("no memory"); - return -1; - } - - for (size_t vs_idx = 0; vs_idx < handler->vs_count; ++vs_idx) { - struct named_vs_config *vs_config = vs_configs + vs_idx; - int res; - if (proto == IPPROTO_IP) { - res = lpm4_insert( - lpm, - vs_config->identifier.addr.v4.bytes, - vs_config->identifier.addr.v4.bytes, - 1 - ); - } else { - res = lpm8_insert( - lpm, - vs_config->identifier.addr.v6.bytes, - vs_config->identifier.addr.v6.bytes, - 1 - ); - } - if (res != 0) { - lpm_free(lpm); - NEW_ERROR("no memory"); - return -1; - } - } - - return 0; -} - -int -setup_vs_index( - struct packet_handler *handler, - struct vs *virtual_services, - size_t *initial_vs_idx, - struct memory_context *mctx -) { - // Build key-value pairs for the map (stable_idx -> config_idx) - struct key_value *entries = - malloc(sizeof(struct key_value) * handler->vs_count); - if (entries == NULL && handler->vs_count > 0) { - NEW_ERROR("failed to allocate memory for VS index entries"); - return -1; - } - - for (size_t vs_idx = 0; vs_idx < handler->vs_count; vs_idx++) { - struct vs *vs = virtual_services + vs_idx; - - // Check for duplicates - for (size_t i = 0; i < vs_idx; i++) { - if (entries[i].key == vs->stable_idx) { - NEW_ERROR( - "service at index %zu matches with " - "service at index %zu", - initial_vs_idx[vs_idx], - initial_vs_idx[i] - ); - free(entries); - return -1; - } - } - - entries[vs_idx].key = vs->stable_idx; - entries[vs_idx].value = vs_idx; - } - - // Initialize the map - if (map_init(&handler->vs_index, mctx, entries, handler->vs_count) != - 0) { - NEW_ERROR("failed to initialize VS index map"); - free(entries); - return -1; - } - - free(entries); - return 0; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/services.h b/modules/balancer/controlplane/handler/services.h deleted file mode 100644 index 7746e2455..000000000 --- a/modules/balancer/controlplane/handler/services.h +++ /dev/null @@ -1,200 +0,0 @@ -#pragma once - -#include - -struct packet_handler; -struct packet_handler_vs; -struct named_vs_config; -struct balancer_state; -struct counter_registry; -struct memory_context; -struct balancer_update_info; -struct vs; -struct real; - -/** - * Find VS in packet handler VS structure. - * - * @param packet_handler_vs Packet handler VS structure - * @param vs VS to find - * @return Pointer to VS if found, NULL otherwise - */ -struct vs * -find_vs_in_packet_handler_vs( - struct packet_handler_vs *packet_handler_vs, struct vs *vs -); - -/** - * Get packet handler VS for a specific protocol. - * - * @param handler Packet handler instance - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @return Pointer to packet handler VS structure - */ -struct packet_handler_vs * -get_packet_handler_vs(struct packet_handler *handler, int proto); - -/** - * Check if VS filter can be reused from previous configuration. - * - * @param current_vs_count Current VS count - * @param prev_vs_count Previous VS count - * @param match_count Number of matching VS - * @return 1 if filter can be reused, 0 otherwise - */ -int -can_reuse_filter(int current_vs_count, int prev_vs_count, int match_count); - -/** - * Validate and reorder VS configurations. - * Moves IPv4 services first, then IPv6 services. - * - * @param initial_vs_idx Array of initial VS indices (will be reordered) - * @param count Number of VS configurations - * @param configs Array of VS configurations (will be reordered) - * @param ipv4_count Output: number of IPv4 services - * @param ipv6_count Output: number of IPv6 services - * @return 0 on success, -1 on error - */ -int -validate_and_reorder_vs_configs( - size_t *initial_vs_idx, - size_t count, - struct named_vs_config *configs, - size_t *ipv4_count, - size_t *ipv6_count -); - -/** - * Register virtual services in handler registry. - * - * @param handler Packet handler instance - * @param vs_count Number of virtual services - * @param initial_vs_idx Array of initial VS indices for error reporting - * @param configs Array of VS configurations - * @param prev_handler Previous packet handler (may be NULL) - * @param match Output: number of matching VS with previous config - * @return 0 on success, -1 on error - */ -int -register_virtual_services( - struct packet_handler *handler, - size_t vs_count, - const size_t *initial_vs_idx, - struct named_vs_config *configs, - struct packet_handler *prev_handler, - size_t *match -); - -/** - * Register and prepare virtual services for a specific protocol. - * - * @param handler Packet handler instance - * @param prev_handler Previous packet handler (may be NULL) - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @param vs_count Number of virtual services - * @param vs_configs Array of VS configurations - * @param initial_vs_idx Array of initial VS indices for error reporting - * @param virtual_services Array of VS structures to initialize - * @param update_info Update information structure (may be NULL) - * @param reuse_filter Output: 1 if filter can be reused, 0 otherwise - * @return 0 on success, -1 on error - */ -int -register_and_prepare_vs( - struct packet_handler *handler, - struct packet_handler *prev_handler, - int proto, - size_t vs_count, - struct named_vs_config *vs_configs, - size_t *initial_vs_idx, - struct vs *virtual_services, - struct balancer_update_info *update_info, - int *reuse_filter -); - -/** - * Initialize packet handler VS for a specific protocol. - * - * @param handler Packet handler instance - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @param mctx Memory context - * @param vs_configs Array of VS configurations - * @param registry Counter registry - * @param prev_handler Previous packet handler (may be NULL) - * @param reals Array of reals - * @param reals_counter Counter for reals (will be incremented) - * @param update_info Update information structure (may be NULL) - * @param initial_vs_idx Array of initial VS indices for error reporting - * @return 0 on success, -1 on error - */ -int -init_packet_handler_vs( - struct packet_handler *handler, - int proto, - struct memory_context *mctx, - struct named_vs_config *vs_configs, - struct counter_registry *registry, - struct packet_handler *prev_handler, - struct real *reals, - size_t *reals_counter, - struct balancer_update_info *update_info, - size_t *initial_vs_idx -); - -/** - * Initialize VS filter for a packet handler VS. - * - * @param packet_handler_vs Packet handler VS structure - * @param prev_packet_handler_vs Previous packet handler VS (may be NULL) - * @param vs_configs Array of VS configurations - * @param reuse_filter 1 if filter should be reused, 0 otherwise - * @param mctx Memory context - * @param initial_vs_idx Array of initial VS indices for error reporting - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @return 0 on success, -1 on error - */ -int -init_vs_filter( - struct packet_handler_vs *packet_handler_vs, - struct packet_handler_vs *prev_packet_handler_vs, - struct named_vs_config *vs_configs, - int reuse_filter, - struct memory_context *mctx, - size_t *initial_vs_idx, - int proto -); - -/** - * Initialize announce LPM for a packet handler VS. - * - * @param handler Packet handler VS structure - * @param mctx Memory context - * @param vs_configs Array of VS configurations - * @param proto IP protocol (IPPROTO_IP or IPPROTO_IPV6) - * @return 0 on success, -1 on error - */ -int -init_announce( - struct packet_handler_vs *handler, - struct memory_context *mctx, - struct named_vs_config *vs_configs, - int proto -); - -/** - * Setup VS index mapping. - * - * @param handler Packet handler instance - * @param virtual_services Array of VS structures - * @param initial_vs_idx Array of initial VS indices for error reporting - * @param mctx Memory context - * @return 0 on success, -1 on error - */ -int -setup_vs_index( - struct packet_handler *handler, - struct vs *virtual_services, - size_t *initial_vs_idx, - struct memory_context *mctx -); \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/stats.c b/modules/balancer/controlplane/handler/stats.c deleted file mode 100644 index c574b1ae9..000000000 --- a/modules/balancer/controlplane/handler/stats.c +++ /dev/null @@ -1,346 +0,0 @@ -#include "api/balancer.h" -#include "api/real.h" -#include "api/vs.h" -#include "common/memory_address.h" -#include "controlplane/agent/agent.h" -#include "handler.h" - -#include "api/counter.h" - -#include "lib/controlplane/diag/diag.h" -#include "vs.h" -#include - -#include "api/stats.h" - -//////////////////////////////////////////////////////////////////////////////// - -const char *common_module_counter_name = "cmn"; -const char *icmp_v4_module_counter_name = "iv4"; -const char *icmp_v6_module_counter_name = "iv6"; -const char *l4_module_counter_name = "l4"; - -//////////////////////////////////////////////////////////////////////////////// - -uint64_t -register_common_counter(struct counter_registry *registry) { - uint64_t res = counter_registry_register( - registry, - common_module_counter_name, - sizeof(struct balancer_common_stats) / sizeof(uint64_t) - ); - - if (res == (uint64_t)-1) { - PUSH_ERROR("failed to register counter in registry"); - return -1; - } - - return res; -} - -uint64_t -register_icmp_v4_counter(struct counter_registry *registry) { - uint64_t res = counter_registry_register( - registry, - icmp_v4_module_counter_name, - sizeof(struct balancer_icmp_stats) / sizeof(uint64_t) - ); - - if (res == (uint64_t)-1) { - PUSH_ERROR("failed to register counter in registry"); - return -1; - } - - return res; -} - -uint64_t -register_icmp_v6_counter(struct counter_registry *registry) { - uint64_t res = counter_registry_register( - registry, - icmp_v6_module_counter_name, - sizeof(struct balancer_icmp_stats) / sizeof(uint64_t) - ); - - if (res == (uint64_t)-1) { - PUSH_ERROR("failed to register counter in registry"); - return -1; - } - - return res; -} - -uint64_t -register_l4_counter(struct counter_registry *registry) { - uint64_t res = counter_registry_register( - registry, - l4_module_counter_name, - sizeof(struct balancer_l4_stats) / sizeof(uint64_t) - ); - - if (res == (uint64_t)-1) { - PUSH_ERROR("failed to register counter in registry"); - return -1; - } - - return res; -} - -//////////////////////////////////////////////////////////////////////////////// - -static void -setup_real_stats( - struct real_stats *real_stats, - const size_t instances, - struct counter_handle *counter -) { - counter_handle_accum( - (uint64_t *)real_stats, - instances, - counter->size, - counter->value_handle - ); -} - -static void -setup_vs_stats( - struct vs_stats *stats, - const size_t instances, - struct counter_handle *counter -) { - counter_handle_accum( - (uint64_t *)stats, - instances, - counter->size, - counter->value_handle - ); -} - -static void -setup_vs_acl_stats( - struct allowed_sources_stats *stats, - const char *tag, - const size_t instances, - struct counter_handle *counter -) { - stats->tag = strdup(tag); - counter_handle_accum( - (uint64_t *)&stats->passes, - instances, - counter->size, - counter->value_handle - ); -} - -static void -inc_balancer_stats( - struct balancer_stats *stats, - const size_t workers, - struct counter_handle *counter -) { - if (strcmp(counter->name, common_module_counter_name) == - 0) { // common module counter - counter_handle_accum( - (uint64_t *)&stats->common, - workers, - counter->size, - counter->value_handle - ); - } else if (strcmp(counter->name, icmp_v4_module_counter_name) == - 0) { // icmp module counter - counter_handle_accum( - (uint64_t *)&stats->icmp_ipv4, - workers, - counter->size, - counter->value_handle - ); - } else if (strcmp(counter->name, icmp_v6_module_counter_name) == 0) { - counter_handle_accum( - (uint64_t *)&stats->icmp_ipv6, - workers, - counter->size, - counter->value_handle - ); - } else if (strcmp(counter->name, l4_module_counter_name) == - 0) { // l4 module counter - counter_handle_accum( - (uint64_t *)&stats->l4, - workers, - counter->size, - counter->value_handle - ); - } -} - -static void -init_real_stats( - size_t reals_count, - struct named_real_stats *real_stats, - struct real *reals -) { - for (size_t i = 0; i < reals_count; ++i) { - real_stats[i].real = reals[i].identifier.relative; - memset(&real_stats[i].stats, 0, sizeof(struct real_stats)); - } -} - -static void -init_vs_stats( - struct packet_handler *handler, - struct balancer_stats *stats, - struct named_real_stats *real_stats -) { - stats->vs_count = handler->vs_count; - stats->vs = malloc(sizeof(struct named_vs_stats) * stats->vs_count); - - // init virtual services - struct vs *vss = ADDR_OF(&handler->vs); - size_t reals_counter = 0; - for (size_t i = 0; i < stats->vs_count; ++i) { - struct named_vs_stats *vs_stats = &stats->vs[i]; - struct vs *vs = &vss[i]; - vs_stats->identifier = vs->identifier; - memset(&vs_stats->stats, 0, sizeof(struct vs_stats)); - vs_stats->reals_count = vs->reals_count; - vs_stats->reals = real_stats + reals_counter; - reals_counter += vs->reals_count; - struct allowed_sources_stats *stats = - malloc(sizeof(struct allowed_sources_stats) * - vs->rules_count); - vs_stats->allowed_sources = stats; - vs_stats->allowed_sources_count = 0; - } -} - -static void -calculate_stats( - struct packet_handler *handler, - struct balancer_stats *stats, - struct named_real_stats *real_stats, - struct counter_handle_list *counter_handles -) { - const size_t instances = counter_handles->instance_count; - - // calculate virtual service, real and common balancer stats - - for (size_t i = 0; i < counter_handles->count; ++i) { - struct counter_handle *counter = &counter_handles->counters[i]; - ssize_t vs_stable_idx = counter_to_vs_registry_idx(counter); - if (vs_stable_idx != -1) { - size_t vs_config_idx; - if (map_find( - &handler->vs_index, - vs_stable_idx, - &vs_config_idx - ) != 0) { - // virtual service not present in packet handler - // config - continue; - } - struct named_vs_stats *vs_stats = - &stats->vs[vs_config_idx]; - setup_vs_stats(&vs_stats->stats, instances, counter); - continue; - } - - // else, if it is not virtual service counter - // check if it is real counter - - ssize_t real_stable_idx = counter_to_real_registry_idx(counter); - if (real_stable_idx != -1) { - size_t real_config_idx; - if (map_find( - &handler->reals_index, - real_stable_idx, - &real_config_idx - ) != 0) { - // real not present in packet handler config - continue; - } - setup_real_stats( - &real_stats[real_config_idx].stats, - instances, - counter - ); - continue; - } - - const char *rule_tag; - vs_stable_idx = parse_vs_acl_counter(counter, &rule_tag); - if (vs_stable_idx != -1) { - size_t vs_config_idx; - if (map_find( - &handler->vs_index, - vs_stable_idx, - &vs_config_idx - ) != 0) { - // virtual service not present in packet handler - // config - continue; - } - struct named_vs_stats *vs_stats = - &stats->vs[vs_config_idx]; - size_t allowed_sources_stats_count = - vs_stats->allowed_sources_count++; - struct allowed_sources_stats *stats = - &vs_stats->allowed_sources - [allowed_sources_stats_count]; - setup_vs_acl_stats(stats, rule_tag, instances, counter); - } - - // else, it is common balancer counter - inc_balancer_stats(stats, instances, counter); - } -} - -int -packet_handler_fill_stats( - struct packet_handler *handler, - struct balancer_stats *stats, - struct packet_handler_ref *ref -) { - struct agent *agent = ADDR_OF(&handler->cp_module.agent); - struct dp_config *dp_config = ADDR_OF(&agent->dp_config); - - const char *module = handler->cp_module.name; - - struct counter_handle_list *counter_handles = yanet_get_module_counters( - dp_config, - ref->device, - ref->pipeline, - ref->function, - ref->chain, - "balancer", - module, - NULL, - (size_t)-1 - ); - if (counter_handles == NULL) { - NEW_ERROR("failed to find stats"); - return -1; - } - - // Initialize all stats to zero - memset(&stats->common, 0, sizeof(struct balancer_common_stats)); - memset(&stats->icmp_ipv4, 0, sizeof(struct balancer_icmp_stats)); - memset(&stats->icmp_ipv6, 0, sizeof(struct balancer_icmp_stats)); - memset(&stats->l4, 0, sizeof(struct balancer_l4_stats)); - - // init real stats - struct real *reals = ADDR_OF(&handler->reals); - - // layout of reals corresponds to the - // layout in packet handler - struct named_real_stats *real_stats = - malloc(sizeof(struct named_real_stats) * handler->reals_count); - - init_real_stats(handler->reals_count, real_stats, reals); - - // init vs stats - init_vs_stats(handler, stats, real_stats); - - // calculate stats - calculate_stats(handler, stats, real_stats, counter_handles); - - return 0; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/update.c b/modules/balancer/controlplane/handler/update.c deleted file mode 100644 index 1526b9176..000000000 --- a/modules/balancer/controlplane/handler/update.c +++ /dev/null @@ -1,219 +0,0 @@ -#include "common/memory_address.h" -#include "lib/controlplane/diag/diag.h" - -#include "real.h" -#include "registry.h" -#include "vs.h" - -#include "handler.h" - -#define VS_COUNT 1024 * 1024 -static uint64_t updated_vs[VS_COUNT / 64] = {0}; - -static inline void -mark_vs_updated(size_t ph_idx) { - if (ph_idx < VS_COUNT) { - updated_vs[ph_idx / 64] |= 1ULL << (ph_idx % 64); - } -} - -static inline int -is_vs_updated(size_t ph_idx) { - if (ph_idx < VS_COUNT) { - return updated_vs[ph_idx / 64] & (1ULL << (ph_idx % 64)); - } else { - return 1; - } -} - -static inline void -unmark_vs_updated(size_t ph_idx) { - if (ph_idx < VS_COUNT) { - updated_vs[ph_idx / 64] &= ~(1ULL << (ph_idx % 64)); - } -} - -static int -validate_update(struct packet_handler *handler, struct real_update *update) { - // validate real - check if it exists in handler's registry - ssize_t real_stable_idx = reals_registry_lookup( - &handler->reals_registry, &update->identifier - ); - if (real_stable_idx < 0) { - NEW_ERROR("real not found in registry"); - return -1; - } - - // check if real is present in current handler config - size_t real_config_idx; - if (map_find( - &handler->reals_index, real_stable_idx, &real_config_idx - ) != 0) { - NEW_ERROR("real is not present in current handler configuration" - ); - return -1; - } - - // validate virtual service - ssize_t vs_stable_idx = vs_registry_lookup( - &handler->vs_registry, &update->identifier.vs_identifier - ); - if (vs_stable_idx < 0) { - NEW_ERROR("virtual service not found in registry"); - return -1; - } - - // check if VS is present in current handler config - size_t vs_config_idx; - if (map_find(&handler->vs_index, vs_stable_idx, &vs_config_idx) != 0) { - NEW_ERROR("virtual service is not present in current handler " - "configuration"); - return -1; - } - - // check update params - if (update->enabled != DONT_UPDATE_REAL_ENABLED) { - if (update->enabled != 0 && update->enabled != 1) { - NEW_ERROR( - "incorrect enabled field: %u (0, 1 or -1 " - "expected)", - update->enabled - ); - return -1; - } - } - - if (update->weight == DONT_UPDATE_REAL_WEIGHT && - update->enabled == DONT_UPDATE_REAL_ENABLED) { - // update changes nothing, and it is ok - return 0; - } - - if (update->weight != DONT_UPDATE_REAL_WEIGHT && - update->weight > MAX_REAL_WEIGHT) { - NEW_ERROR( - "weight %u is too big (max is %u)", - update->weight, - MAX_REAL_WEIGHT - ); - return -1; - } - - return 0; -} - -static void -update_real(struct packet_handler *handler, struct real_update *update) { - // Find real's stable index - ssize_t real_stable_idx = reals_registry_lookup( - &handler->reals_registry, &update->identifier - ); - assert(real_stable_idx >= 0); - - // Find real's config index - size_t real_config_idx; - int res = map_find( - &handler->reals_index, real_stable_idx, &real_config_idx - ); - assert(res == 0); - - // Get the real structure - struct real *reals = ADDR_OF(&handler->reals); - struct real *real = &reals[real_config_idx]; - - // Find VS stable index - ssize_t vs_stable_idx = vs_registry_lookup( - &handler->vs_registry, &update->identifier.vs_identifier - ); - assert(vs_stable_idx >= 0); - - // Find VS config index - size_t vs_config_idx; - res = map_find(&handler->vs_index, vs_stable_idx, &vs_config_idx); - assert(res == 0); - - // Update real fields - int updated = 0; - if (update->enabled != DONT_UPDATE_REAL_ENABLED && - real->enabled != update->enabled) { - real->enabled = update->enabled; - updated = 1; - } - - if (update->weight != DONT_UPDATE_REAL_WEIGHT && - real->weight != update->weight) { - real->weight = update->weight; - updated = 1; - } - - if (updated) { - mark_vs_updated(vs_config_idx); - } -} - -static int -update_vs(struct packet_handler *handler, struct real_update *update) { - // Find VS stable index - ssize_t vs_stable_idx = vs_registry_lookup( - &handler->vs_registry, &update->identifier.vs_identifier - ); - assert(vs_stable_idx >= 0); - - // Find VS config index - size_t vs_config_idx; - int res = map_find(&handler->vs_index, vs_stable_idx, &vs_config_idx); - assert(res == 0); - - if (!is_vs_updated(vs_config_idx)) { - return 0; - } - - struct vs *vss = ADDR_OF(&handler->vs); - struct vs *vs = &vss[vs_config_idx]; - if (vs_update_reals(vs) != 0) { - PUSH_ERROR("failed to update reals"); - return -1; - } - - unmark_vs_updated(vs_config_idx); - - return 0; -} - -int -packet_handler_update_reals( - struct packet_handler *handler, - size_t count, - struct real_update *updates -) { - // validate - for (size_t i = 0; i < count; ++i) { - struct real_update *update = &updates[i]; - if (validate_update(handler, update) != 0) { - PUSH_ERROR("update at index %lu is invalid", i); - return -1; - } - } - - // update - for (size_t i = 0; i < count; ++i) { - struct real_update *update = &updates[i]; - update_real(handler, update); - } - - // update virtual services that were marked as updated - for (size_t i = 0; i < count; ++i) { - struct real_update *update = &updates[i]; - if (update_vs(handler, update) != 0) { - PUSH_ERROR( - "failed to update virtual service for update " - "at " - "index %lu", - i - ); - return -1; - } - } - - return 0; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/handler/vs.c b/modules/balancer/controlplane/handler/vs.c deleted file mode 100644 index ebd8781c4..000000000 --- a/modules/balancer/controlplane/handler/vs.c +++ /dev/null @@ -1,1048 +0,0 @@ -#include "api/vs.h" -#include "api/balancer.h" -#include "api/counter.h" -#include "common/memory.h" -#include "common/memory_address.h" -#include "common/network.h" - -#include "compiler/declare.h" -#include "counters/counters.h" -#include "lib/controlplane/diag/diag.h" - -#include "rule.h" -#include "rules.h" -#include "selector.h" -#include "vs.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "filter/compiler.h" - -#define MAX_TAG_LENGTH 240 - -static int -validate_tag(const char *tag) { - if (tag == NULL) { - return 0; // NULL is valid (means no tracking) - } - size_t len = strnlen(tag, MAX_TAG_LENGTH + 1); - if (len == 0) { - NEW_ERROR("tag must be at least 1 character long"); - return -1; - } - if (len > MAX_TAG_LENGTH) { - NEW_ERROR( - "tag length %zu exceeds maximum %d", len, MAX_TAG_LENGTH - ); - return -1; - } - return 0; -} - -static int -setup_reals( - struct vs *vs, - struct vs_config *config, - size_t first_real_idx, - struct real *reals -) { - vs->reals_count = config->real_count; - vs->first_real_idx = first_real_idx; - SET_OFFSET_OF(&vs->reals, reals); - return 0; -} - -static int -setup_selector( - struct vs *vs, struct memory_context *mctx, struct vs_config *config -) { - const struct real *reals = ADDR_OF(&vs->reals); - if (selector_init(&vs->selector, mctx, config->scheduler) != 0) { - PUSH_ERROR("failed to setup selector"); - return -1; - } - if (selector_update(&vs->selector, vs->reals_count, reals) != 0) { - selector_free(&vs->selector); - PUSH_ERROR("failed to setup selector reals"); - return -1; - } - return 0; -} - -static int -register_counter(struct vs *vs, struct counter_registry *registry) { - char name[60]; - sprintf(name, "vs_%zu", vs->stable_idx); - vs->counter_id = counter_registry_register( - registry, name, sizeof(struct vs_stats) / sizeof(uint64_t) - ); - if (vs->counter_id == (size_t)-1) { - PUSH_ERROR("failed to register counter in the counter registry" - ); - return -1; - } - return 0; -} - -static int -setup_peers( - struct vs *vs, struct memory_context *mctx, struct vs_config *config -) { - vs->peers_v4_count = config->peers_v4_count; - vs->peers_v6_count = config->peers_v6_count; - - void *peers_v4_ptr = memory_balloc( - mctx, sizeof(struct net4_addr) * vs->peers_v4_count - ); - if (peers_v4_ptr == NULL && vs->peers_v4_count > 0) { - NEW_ERROR("failed to allocate memory for IPv4 peers"); - return -1; - } - SET_OFFSET_OF(&vs->peers_v4, peers_v4_ptr); - - // Copy IPv4 peer addresses from config (config uses normal pointers) - if (vs->peers_v4_count > 0) { - memcpy(peers_v4_ptr, - config->peers_v4, - sizeof(struct net4_addr) * vs->peers_v4_count); - } - - void *peers_v6_ptr = memory_balloc( - mctx, sizeof(struct net6_addr) * vs->peers_v6_count - ); - if (peers_v6_ptr == NULL && vs->peers_v6_count > 0) { - NEW_ERROR("failed to allocate memory for IPv6 peers"); - memory_bfree( - mctx, - peers_v4_ptr, - sizeof(struct net4_addr) * vs->peers_v4_count - ); - return -1; - } - SET_OFFSET_OF(&vs->peers_v6, peers_v6_ptr); - - // Copy IPv6 peer addresses from config (config uses normal pointers) - if (vs->peers_v6_count > 0) { - memcpy(peers_v6_ptr, - config->peers_v6, - sizeof(struct net6_addr) * vs->peers_v6_count); - } - - return 0; -} - -static int -setup_flags(struct vs *vs, struct named_vs_config *config) { - if ((config->config.flags & VS_PURE_L3_FLAG) && - config->identifier.port != 0) { - NEW_ERROR( - "PureL3 mode " - "requires port=0, but port=%u was specified", - config->identifier.port - ); - return -1; - } - vs->flags = config->config.flags; - return 0; -} - -static int -validate_net4(struct net4 *net4) { - int prev = 1; - for (int bit = 31; bit >= 0; --bit) { - int byte = (31 - bit) / 8; - int inner_bit = bit % 8; - int cur = net4->mask[byte] & (1 << inner_bit); - if (cur && !prev) { - NEW_ERROR("mask bits must be consecutive"); - return -1; - } - prev = cur != 0; - } - - return 0; -} - -static int -validate_net6_half(const uint8_t *mask) { // bytes are in big-endian - int prev = 1; - for (int bit = 63; bit >= 0; --bit) { - int byte = (63 - bit) / 8; - int inner_bit = bit % 8; - int cur = mask[byte] & (1 << inner_bit); - if (cur && !prev) { - NEW_ERROR("mask bits must be consecutive"); - return -1; - } - prev = cur != 0; - } - return 0; -} - -static int -validate_net6(struct net6 *net6) { - if (validate_net6_half(net6->mask) != 0) { - PUSH_ERROR("high mask bits are invalid"); - return -1; - } - if (validate_net6_half(net6->mask + 8) != 0) { - PUSH_ERROR("low mask bits are invalid"); - return -1; - } - return 0; -} - -// fill_rule creates a filter rule with ABSOLUTE pointers (not relative). -// This allows safe sorting and comparison. Pointers are converted to relative -// offsets later in setup_acl_rules after sorting and deduplication. -static int -fill_rule( - struct vs *vs, - struct filter_rule *rule, - size_t src_idx, - struct allowed_sources *src, - struct memory_context *mctx -) { - rule->action = (uint32_t)src_idx; - - if (vs->identifier.ip_proto == IPPROTO_IP) { - rule->net4.dst_count = 0; - rule->net4.dsts = NULL; - - for (size_t net_idx = 0; net_idx < src->nets_count; ++net_idx) { - if (validate_net4(&src->nets[net_idx].v4) != 0) { - PUSH_ERROR( - "IPv4 network at index %zu is invalid", - net_idx - ); - return -1; - } - } - - rule->net4.src_count = src->nets_count; - struct net4 *net4_srcs = memory_balloc( - mctx, sizeof(struct net4) * src->nets_count - ); - if (net4_srcs == NULL) { - NEW_ERROR("failed to allocate net4 srcs"); - return -1; - } - - for (size_t net_idx = 0; net_idx < src->nets_count; ++net_idx) { - net4_srcs[net_idx] = src->nets[net_idx].v4; - } - - rule->net4.srcs = net4_srcs; // Store absolute pointer - } else if (vs->identifier.ip_proto == IPPROTO_IPV6) { - rule->net6.dst_count = 0; - rule->net6.dsts = NULL; - - for (size_t net_idx = 0; net_idx < src->nets_count; ++net_idx) { - if (validate_net6(&src->nets[net_idx].v6) != 0) { - PUSH_ERROR( - "IPv6 network at index %zu is invalid", - net_idx - ); - return -1; - } - } - - rule->net6.src_count = src->nets_count; - struct net6 *net6_srcs = memory_balloc( - mctx, sizeof(struct net6) * src->nets_count - ); - if (net6_srcs == NULL) { - NEW_ERROR("failed to allocate net6 srcs"); - return -1; - } - - for (size_t net_idx = 0; net_idx < src->nets_count; ++net_idx) { - net6_srcs[net_idx] = src->nets[net_idx].v6; - } - - rule->net6.srcs = net6_srcs; // Store absolute pointer - } - - // Handle port ranges: if none specified, use default [0, 65535] - if (src->port_ranges_count == 0) { - rule->transport.src_count = 1; - struct filter_port_range *port_srcs = - memory_balloc(mctx, sizeof(struct filter_port_range)); - if (port_srcs == NULL) { - NEW_ERROR("failed to allocate port srcs"); - return -1; - } - port_srcs[0].from = 0; - port_srcs[0].to = 65535; - rule->transport.srcs = port_srcs; // Store absolute pointer - } else { - rule->transport.src_count = src->port_ranges_count; - struct filter_port_range *port_srcs = memory_balloc( - mctx, - sizeof(struct filter_port_range) * - src->port_ranges_count - ); - if (port_srcs == NULL) { - NEW_ERROR("failed to allocate port srcs"); - return -1; - } - for (size_t port_range_idx = 0; - port_range_idx < src->port_ranges_count; - ++port_range_idx) { - struct filter_port_range *filter_port_range = - &port_srcs[port_range_idx]; - struct ports_range *port_range = - &src->port_ranges[port_range_idx]; - filter_port_range->from = port_range->from; - filter_port_range->to = port_range->to; - if (filter_port_range->from > filter_port_range->to) { - PUSH_ERROR("port range is invalid"); - return -1; - } - } - rule->transport.srcs = port_srcs; // Store absolute pointer - } - return 0; -} - -// src_filter_rules creates filter rules with ABSOLUTE pointers. -// The rules can be safely sorted and compared. Conversion to relative -// pointers happens later in setup_acl_rules. -static int -src_filter_rules( - struct vs *vs, - struct vs_config *config, - struct filter_rule **rules, - size_t *rule_count, - struct memory_context *mctx -) { - if (vs->identifier.ip_proto != IPPROTO_IP && - vs->identifier.ip_proto != IPPROTO_IPV6) { - NEW_ERROR( - "virtual service IP protocol is incorrect: %u " - "(expected IPv4 %u or IPv6 %u)", - vs->identifier.ip_proto, - IPPROTO_IP, - IPPROTO_IPV6 - ); - return -1; - } - - size_t count = config->allowed_src_count; - if (count == 0) { - *rule_count = 0; - *rules = NULL; - return 0; - } - - struct filter_rule *r = - memory_balloc(mctx, sizeof(struct filter_rule) * count); - if (r == NULL) { - NEW_ERROR("failed to allocate rules"); - return -1; - } - memset(r, 0, sizeof(struct filter_rule) * count); - for (size_t rule_idx = 0; rule_idx < config->allowed_src_count; - ++rule_idx) { - if (fill_rule( - vs, - &r[rule_idx], - rule_idx, - &config->allowed_src[rule_idx], - mctx - ) != 0) { - PUSH_ERROR("rule at index %zu is invalid", rule_idx); - // Free already allocated rules (using absolute - // pointers) - for (size_t j = 0; j < rule_idx; ++j) { - struct filter_rule *rule = &r[j]; - if (rule->net4.src_count > 0) { - memory_bfree( - mctx, - rule->net4.srcs, - sizeof(struct net4) * - rule->net4.src_count - ); - } - if (rule->net6.src_count > 0) { - memory_bfree( - mctx, - rule->net6.srcs, - sizeof(struct net6) * - rule->net6.src_count - ); - } - if (rule->transport.src_count > 0) { - memory_bfree( - mctx, - rule->transport.srcs, - sizeof(struct filter_port_range - ) * rule->transport.src_count - ); - } - } - memory_bfree( - mctx, r, sizeof(struct filter_rule) * count - ); - return -1; - } - } - - *rule_count = count; - *rules = r; - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -FILTER_COMPILER_DECLARE(vs_acl_ipv4, net4_fast_src, port_fast_src); -FILTER_COMPILER_DECLARE(vs_acl_ipv6, net6_fast_src, port_fast_src); - -//////////////////////////////////////////////////////////////////////////////// -// Helper functions for rule comparison and sorting -//////////////////////////////////////////////////////////////////////////////// - -static int -compare_net4(const void *va, const void *vb) { - const struct net4 *a = (const struct net4 *)va; - const struct net4 *b = (const struct net4 *)vb; - // Compare address first - int addr_cmp = memcmp(a->addr, b->addr, NET4_LEN); - if (addr_cmp != 0) { - return addr_cmp; - } - // Then compare mask - return memcmp(a->mask, b->mask, NET4_LEN); -} - -static int -compare_net6(const void *va, const void *vb) { - const struct net6 *a = (const struct net6 *)va; - const struct net6 *b = (const struct net6 *)vb; - // Compare address first - int addr_cmp = memcmp(a->addr, b->addr, NET6_LEN); - if (addr_cmp != 0) { - return addr_cmp; - } - // Then compare mask - return memcmp(a->mask, b->mask, NET6_LEN); -} - -static int -compare_port_range(const void *va, const void *vb) { - const struct filter_port_range *a = - (const struct filter_port_range *)va; - const struct filter_port_range *b = - (const struct filter_port_range *)vb; - if (a->from != b->from) { - return (a->from < b->from) ? -1 : 1; - } - if (a->to != b->to) { - return (a->to < b->to) ? -1 : 1; - } - return 0; -} - -static int -compare_proto_range(const void *va, const void *vb) { - const struct filter_proto_range *a = - (const struct filter_proto_range *)va; - const struct filter_proto_range *b = - (const struct filter_proto_range *)vb; - if (a->from != b->from) { - return (a->from < b->from) ? -1 : 1; - } - if (a->to != b->to) { - return (a->to < b->to) ? -1 : 1; - } - return 0; -} - -// Normalize a single rule by sorting its internal arrays. -// This function expects ABSOLUTE pointers in the rule. -static void -normalize_rule(struct filter_rule *rule) { - // Sort net4 sources (already absolute pointer) - if (rule->net4.src_count > 1 && rule->net4.srcs != NULL) { - qsort(rule->net4.srcs, - rule->net4.src_count, - sizeof(struct net4), - compare_net4); - } - - // Sort net6 sources (already absolute pointer) - if (rule->net6.src_count > 1 && rule->net6.srcs != NULL) { - qsort(rule->net6.srcs, - rule->net6.src_count, - sizeof(struct net6), - compare_net6); - } - - // Sort transport source ports (already absolute pointer) - if (rule->transport.src_count > 1 && rule->transport.srcs != NULL) { - qsort(rule->transport.srcs, - rule->transport.src_count, - sizeof(struct filter_port_range), - compare_port_range); - } - - // Sort proto ranges (already absolute pointer) - if (rule->transport.proto_count > 1 && rule->transport.protos != NULL) { - qsort(rule->transport.protos, - rule->transport.proto_count, - sizeof(struct filter_proto_range), - compare_proto_range); - } -} - -// compare_filter_rules compares two filter rules with ABSOLUTE pointers. -// Used for sorting rules during setup. -static int -compare_filter_rules(const void *va, const void *vb) { - const struct filter_rule *a = (const struct filter_rule *)va; - const struct filter_rule *b = (const struct filter_rule *)vb; - - // Compare IPv4 source networks - if (a->net4.src_count != b->net4.src_count) { - return (a->net4.src_count < b->net4.src_count) ? -1 : 1; - } - if (a->net4.src_count > 0) { - // Use absolute pointers directly - for (size_t i = 0; i < a->net4.src_count; ++i) { - int cmp = compare_net4( - &a->net4.srcs[i], &b->net4.srcs[i] - ); - if (cmp != 0) { - return cmp; - } - } - } - - // Compare IPv6 source networks - if (a->net6.src_count != b->net6.src_count) { - return (a->net6.src_count < b->net6.src_count) ? -1 : 1; - } - if (a->net6.src_count > 0) { - // Use absolute pointers directly - for (size_t i = 0; i < a->net6.src_count; ++i) { - int cmp = compare_net6( - &a->net6.srcs[i], &b->net6.srcs[i] - ); - if (cmp != 0) { - return cmp; - } - } - } - - // Compare transport source port ranges - if (a->transport.src_count != b->transport.src_count) { - return (a->transport.src_count < b->transport.src_count) ? -1 - : 1; - } - if (a->transport.src_count > 0) { - // Use absolute pointers directly - for (size_t i = 0; i < a->transport.src_count; ++i) { - int cmp = compare_port_range( - &a->transport.srcs[i], &b->transport.srcs[i] - ); - if (cmp != 0) { - return cmp; - } - } - } - - return 0; -} - -// compare_filter_rules_relative compares two filter rules with RELATIVE -// pointers. Used for comparing rules from previous VS (which are already stored -// with relative pointers). -static int -compare_filter_rules_relative( - const struct filter_rule *a, const struct filter_rule *b -) { - // Compare IPv4 source networks - if (a->net4.src_count != b->net4.src_count) { - return (a->net4.src_count < b->net4.src_count) ? -1 : 1; - } - if (a->net4.src_count > 0) { - // Convert relative pointers to absolute for comparison - const struct net4 *a_srcs = ADDR_OF(&a->net4.srcs); - const struct net4 *b_srcs = ADDR_OF(&b->net4.srcs); - for (size_t i = 0; i < a->net4.src_count; ++i) { - int cmp = compare_net4(&a_srcs[i], &b_srcs[i]); - if (cmp != 0) { - return cmp; - } - } - } - - // Compare IPv6 source networks - if (a->net6.src_count != b->net6.src_count) { - return (a->net6.src_count < b->net6.src_count) ? -1 : 1; - } - if (a->net6.src_count > 0) { - // Convert relative pointers to absolute for comparison - const struct net6 *a_srcs = ADDR_OF(&a->net6.srcs); - const struct net6 *b_srcs = ADDR_OF(&b->net6.srcs); - for (size_t i = 0; i < a->net6.src_count; ++i) { - int cmp = compare_net6(&a_srcs[i], &b_srcs[i]); - if (cmp != 0) { - return cmp; - } - } - } - - // Compare transport source port ranges - if (a->transport.src_count != b->transport.src_count) { - return (a->transport.src_count < b->transport.src_count) ? -1 - : 1; - } - if (a->transport.src_count > 0) { - // Convert relative pointers to absolute for comparison - const struct filter_port_range *a_srcs = - ADDR_OF(&a->transport.srcs); - const struct filter_port_range *b_srcs = - ADDR_OF(&b->transport.srcs); - for (size_t i = 0; i < a->transport.src_count; ++i) { - int cmp = compare_port_range(&a_srcs[i], &b_srcs[i]); - if (cmp != 0) { - return cmp; - } - } - } - - return 0; -} - -// rules_equal_relative compares two rule arrays where both have RELATIVE -// pointers. Used for comparing current VS rules with previous VS rules. -static bool -rules_equal_relative( - const struct filter_rule *rules1, - size_t count1, - const struct filter_rule *rules2, - size_t count2 -) { - if (count1 != count2) { - return false; - } - - for (size_t i = 0; i < count1; ++i) { - if (compare_filter_rules_relative(&rules1[i], &rules2[i]) != - 0) { - return false; - } - } - - return true; -} - -//////////////////////////////////////////////////////////////////////////////// - -static int -setup_acl( - struct vs *vs, - struct vs *prev_vs, - struct memory_context *mctx, - struct balancer_update_info *update_info -) { - // Check if we can reuse ACL from previous VS - // Both current and previous VS rules have relative pointers at this - // point - if (prev_vs != NULL) { - const struct filter_rule *prev_rules = ADDR_OF(&prev_vs->rules); - const struct filter_rule *curr_rules = ADDR_OF(&vs->rules); - - if (rules_equal_relative( - curr_rules, - vs->rules_count, - prev_rules, - prev_vs->rules_count - )) { - // Reuse ACL - EQUATE_OFFSET(&vs->acl, &prev_vs->acl); - prev_vs->acl_reused = 1; - - // Track reuse in update_info - if (update_info != NULL) { - size_t idx = update_info->vs_acl_reused_count++; - update_info->vs_acl_reused[idx] = - vs->identifier; - } - - return 0; - } - } - - // Need to create new ACL - vs->acl = memory_balloc(mctx, sizeof(struct filter)); - if (vs->acl == NULL) { - PUSH_ERROR("no memory"); - return -1; - } - vs->acl_reused = 0; - - // Get rules and convert relative pointers to absolute for filter_init - struct filter_rule *rules = ADDR_OF(&vs->rules); - size_t rule_count = vs->rules_count; - - for (size_t i = 0; i < rule_count; ++i) { - struct filter_rule *rule = &rules[i]; - rule->net4.srcs = ADDR_OF(&rule->net4.srcs); - rule->net6.srcs = ADDR_OF(&rule->net6.srcs); - rule->transport.srcs = ADDR_OF(&rule->transport.srcs); - } - - const struct filter_rule **rule_ptrs = (const struct filter_rule **) - malloc(sizeof(struct filter_rule *) * rule_count); - if (rule_ptrs == NULL) { - PUSH_ERROR("no memory"); - return -1; - } - for (size_t idx = 0; idx < rule_count; ++idx) { - rule_ptrs[idx] = rules + idx; - } - - // Initialize filter with absolute pointers - int res; - if (vs->identifier.ip_proto == IPPROTO_IP) { - res = filter_init( - vs->acl, vs_acl_ipv4, rule_ptrs, rule_count, mctx - ); - } else { // IPPROTO_IPV6 - res = filter_init( - vs->acl, vs_acl_ipv6, rule_ptrs, rule_count, mctx - ); - } - free(rule_ptrs); - - // Restore relative pointers - for (size_t i = 0; i < rule_count; ++i) { - struct filter_rule *rule = &rules[i]; - SET_OFFSET_OF(&rule->net4.srcs, rule->net4.srcs); - SET_OFFSET_OF(&rule->net6.srcs, rule->net6.srcs); - SET_OFFSET_OF(&rule->transport.srcs, rule->transport.srcs); - } - - if (res != 0) { - NEW_ERROR("no memory"); - return -1; - } - - SET_OFFSET_OF(&vs->acl, vs->acl); - - return 0; -} - -static void -vs_free_acl_rules(struct vs *vs, struct memory_context *mctx) { - if (vs->rules_count == 0) { - return; - } - - struct filter_rule *rules = ADDR_OF(&vs->rules); - if (rules == NULL) { - return; - } - - // Free nested arrays in each rule (only sources matter) - for (size_t i = 0; i < vs->rules_count; ++i) { - struct filter_rule *rule = &rules[i]; - - // Free net4 source arrays - if (rule->net4.src_count > 0) { - memory_bfree( - mctx, - ADDR_OF(&rule->net4.srcs), - sizeof(struct net4) * rule->net4.src_count - ); - } - - // Free net6 source arrays - if (rule->net6.src_count > 0) { - memory_bfree( - mctx, - ADDR_OF(&rule->net6.srcs), - sizeof(struct net6) * rule->net6.src_count - ); - } - - // Free transport source port arrays - if (rule->transport.src_count > 0) { - memory_bfree( - mctx, - ADDR_OF(&rule->transport.srcs), - sizeof(struct filter_port_range) * - rule->transport.src_count - ); - } - } - - // Free the rules array itself - memory_bfree(mctx, rules, sizeof(struct filter_rule) * vs->rules_count); - vs->rules_count = 0; - SET_OFFSET_OF(&vs->rules, NULL); -} - -static void -rule_to_relative_addresses(struct filter_rule *rule) { - // net4 src - SET_OFFSET_OF(&rule->net4.srcs, rule->net4.srcs); - - // net6 src - SET_OFFSET_OF(&rule->net6.srcs, rule->net6.srcs); - - // transport src - SET_OFFSET_OF(&rule->transport.srcs, rule->transport.srcs); -} - -static int -setup_acl_rules( - struct vs *vs, - struct counter_registry *counters, - struct vs_config *config, - struct memory_context *mctx -) { - // Create filter rules from config (already uses memory_balloc and - // relative pointers) - struct filter_rule *rules = NULL; - size_t rules_count = 0; - if (src_filter_rules(vs, config, &rules, &rules_count, mctx) != 0) { - PUSH_ERROR("failed to create filter rules"); - return -1; - } - - // Normalize each rule (sort internal arrays) - for (size_t i = 0; i < rules_count; ++i) { - normalize_rule(&rules[i]); - } - - // Sort the rules array - if (rules_count > 1) { - qsort(rules, - rules_count, - sizeof(struct filter_rule), - compare_filter_rules); - } - - // Remove duplicates - int last_rule_idx = -1; - for (size_t rule_idx = 0; rule_idx < rules_count; ++rule_idx) { - if (last_rule_idx != -1 && - compare_filter_rules( - &rules[rule_idx], &rules[last_rule_idx] - ) == 0) { - continue; - } - rules[++last_rule_idx] = rules[rule_idx]; - } - - char counter_name[256]; - uint64_t *rule_counters = - memory_balloc(mctx, sizeof(uint64_t) * rules_count); - if (rule_counters == NULL && rules_count > 0) { - NEW_ERROR("failed to allocate rule counters: no memory"); - return -1; - } - - // Change to relative offsets - rules_count = last_rule_idx + 1; - for (size_t i = 0; i < rules_count; ++i) { - rule_to_relative_addresses(&rules[i]); - - uint32_t allowed_src_idx = rules[i].action; - const char *rule_tag = config->allowed_src[allowed_src_idx].tag; - - // Validate tag before using it - if (validate_tag(rule_tag) != 0) { - PUSH_ERROR( - "invalid tag at allowed_src index %u", - allowed_src_idx - ); - return -1; - } - - if (rule_tag != NULL) { - // register counter - sprintf(counter_name, - "acl_%zu_%s", - vs->stable_idx, - rule_tag); - uint64_t counter_id = counter_registry_register( - counters, counter_name, 1 - ); - if (counter_id == (uint64_t)-1) { - NEW_ERROR("failed to register counter for " - "rule: no memory"); - return -1; - } - - rule_counters[i] = counter_id; - } else { - // counter is undefined, because tag not specified - rule_counters[i] = (uint64_t)-1; - } - } - - // Store using relative pointers - SET_OFFSET_OF(&vs->rules, rules); - SET_OFFSET_OF(&vs->rule_counters, rule_counters); - vs->rules_count = rules_count; - - return 0; -} - -int -vs_with_identifier_and_registry_idx_init( - struct vs *vs, - struct vs *prev_vs, - size_t first_real_idx, - struct real *reals, - struct named_vs_config *config, - struct counter_registry *counters, - struct memory_context *mctx, - struct balancer_update_info *update_info -) { - if (setup_flags(vs, config) != 0) { - PUSH_ERROR("failed to setup flags"); - return -1; - } - - if (setup_peers(vs, mctx, &config->config) != 0) { - PUSH_ERROR("failed to setup peers"); - return -1; - } - - if (setup_reals(vs, &config->config, first_real_idx, reals) != 0) { - PUSH_ERROR("failed to setup reals"); - goto free_peers; - } - - if (setup_selector(vs, mctx, &config->config) != 0) { - PUSH_ERROR("failed to setup selector"); - goto free_peers; - } - - if (register_counter(vs, counters) != 0) { - PUSH_ERROR("failed to register counter"); - goto free_selector; - } - - if (setup_acl_rules(vs, counters, &config->config, mctx) != 0) { - PUSH_ERROR("failed to store acl rules"); - goto free_selector; - } - - if (setup_acl(vs, prev_vs, mctx, update_info) != 0) { - PUSH_ERROR("failed to setup acl"); - goto free_acl_rules; - } - - return 0; - -free_acl_rules: - vs_free_acl_rules(vs, mctx); - -free_selector: - selector_free(&vs->selector); - -free_peers: - memory_bfree( - mctx, - ADDR_OF(&vs->peers_v4), - sizeof(struct net4_addr) * vs->peers_v4_count - ); - memory_bfree( - mctx, - ADDR_OF(&vs->peers_v6), - sizeof(struct net6_addr) * vs->peers_v6_count - ); - - return -1; -} - -void -vs_free(struct vs *vs, struct memory_context *mctx) { - memory_bfree( - mctx, - ADDR_OF(&vs->peers_v4), - sizeof(struct net4_addr) * vs->peers_v4_count - ); - memory_bfree( - mctx, - ADDR_OF(&vs->peers_v6), - sizeof(struct net6_addr) * vs->peers_v6_count - ); - selector_free(&vs->selector); -} - -int -vs_update_reals(struct vs *vs) { - if (selector_update( - &vs->selector, vs->reals_count, ADDR_OF(&vs->reals) - ) != 0) { - PUSH_ERROR("failed to update real selector"); - return -1; - } - return 0; -} - -ssize_t -counter_to_vs_registry_idx(struct counter_handle *counter) { - if (strncmp(counter->name, "vs_", 3) == 0) { // vs counter - return atoi(counter->name + 3); - } else { - return -1; - } -} - -//////////////////////////////////////////////////////////////////////////////// - -static void -setup_reals_usage( - struct reals_usage *reals_usage, size_t workers, size_t reals_count -) { - reals_usage->counters_usage = - sizeof(struct real_stats) * workers * reals_count; - reals_usage->data_usage = sizeof(struct real) * reals_count; - reals_usage->total_usage = - reals_usage->counters_usage + reals_usage->data_usage; -} - -void -vs_fill_inspect(struct vs *vs, struct vs_inspect *inspect, size_t workers) { - inspect->acl_usage = filter_memory_usage(ADDR_OF(&vs->acl)); - inspect->ring_usage = selector_memory_usage(&vs->selector); - inspect->counters_usage = sizeof(struct vs_stats) * workers; - setup_reals_usage(&inspect->reals_usage, workers, vs->reals_count); - inspect->other_usage = - rules_memory_usage(vs->rules_count, ADDR_OF(&vs->rules)) + - sizeof(struct filter_rule) * vs->rules_count; - inspect->other_usage += vs->peers_v4_count * sizeof(struct net4_addr); - inspect->other_usage += vs->peers_v6_count * sizeof(struct net6_addr); - inspect->total_usage = inspect->acl_usage + inspect->ring_usage + - inspect->counters_usage + - inspect->reals_usage.total_usage + - inspect->other_usage; -} - -ssize_t -parse_vs_acl_counter(struct counter_handle *counter, const char **tag) { - // in format acl__ - if (strncmp(counter->name, "acl_", 4) == 0) { // vs acl counter - char *end_ptr = NULL; - size_t vs_registry_idx = - strtoull(counter->name + 4, &end_ptr, 10); - *tag = end_ptr + 1; // Point to the tag string in counter name - return vs_registry_idx; - } else { - return -1; - } -} diff --git a/modules/balancer/controlplane/handler/vs.h b/modules/balancer/controlplane/handler/vs.h deleted file mode 100644 index f4bb3b161..000000000 --- a/modules/balancer/controlplane/handler/vs.h +++ /dev/null @@ -1,133 +0,0 @@ -#pragma once - -#include "api/inspect.h" -#include "api/vs.h" - -#include "common/memory.h" -#include "filter.h" -#include "selector.h" -#include - -struct vs_state; -struct real; -struct named_vs_config; -struct balancer_update_info; - -/** - * Handler-side view of a virtual service. - * - * Holds selection policy, backend views and source filters for fast-path - * lookup. - */ -struct vs { - struct vs_identifier identifier; // Address + Port + Proto - - size_t stable_idx; // Index in the registry - - uint8_t flags; // VS_* flags describing behavior/scheduling - - // Can be modified atomically via real_update method - struct real_selector selector; - - size_t reals_count; // Number of elements in 'reals' - const struct real *reals; // Array of reals belongs to Virtual Service - - // Index of the first real in the reals array - size_t first_real_idx; - - // Access Control List (ACL) filter for source IP/port filtering - // Compiled from the rules array below, used for fast packet matching - // (relative pointer) - struct filter *acl; - - // Set to 1 when ACL filter is reused from previous VS configuration - // Prevents double-free during configuration updates - // Set to 0 when a new ACL is built - // Reuse occurs when filter rules are identical between old and new - // config - int acl_reused; - - // Number of filter rules in the rules array - // Each rule specifies allowed source networks and port ranges - size_t rules_count; - - // Array of filter rules defining allowed sources (relative pointer) - // Rules are normalized (sorted, deduplicated) and stored with relative - // pointers to their internal arrays (net4.srcs, net6.srcs, - // transport.srcs) - struct filter_rule *rules; - - // Indices of counters for rules - uint64_t *rule_counters; - - // Number of IPv4 peer balancer addresses - size_t peers_v4_count; - - // Array of IPv4 peer balancer addresses (relative pointer) - // Used for coordinating with other balancer instances - struct net4_addr *peers_v4; - - // Number of IPv6 peer balancer addresses - size_t peers_v6_count; - - // Array of IPv6 peer balancer addresses (relative pointer) - // Used for coordinating with other balancer instances - struct net6_addr *peers_v6; - - uint64_t counter_id; // Per-VS counter id -}; - -/** - * Initialize handler-side VS view. - * Returns 0 on success, -1 on error. - */ -int -vs_with_identifier_and_registry_idx_init( - struct vs *vs, - struct vs *prev_vs, - size_t first_real_idx, - struct real *reals, - struct named_vs_config *config, - struct counter_registry *registry, - struct memory_context *mctx, - struct balancer_update_info *update_info -); - -/** - * Free resources bound to the VS view. - */ -void -vs_free(struct vs *vs, struct memory_context *mctx); - -/** - * Refresh real selector and related data from the current state. - * Returns 0 on success, -1 on error. - */ -int -vs_update_reals(struct vs *vs); - -/** - * Resolve VS registry index from a counter handle. - * Returns index on success, or -1 on error. - */ -ssize_t -counter_to_vs_registry_idx(struct counter_handle *counter); - -//////////////////////////////////////////////////////////////////////////////// - -static inline bool -vs_real_enabled(struct vs *vs, uint32_t real_idx) { - return selector_real_enabled( - &vs->selector, real_idx - vs->first_real_idx - ); -} - -//////////////////////////////////////////////////////////////////////////////// - -void -vs_fill_inspect(struct vs *vs, struct vs_inspect *inspect, size_t workers); - -//////////////////////////////////////////////////////////////////////////////// - -ssize_t -parse_vs_acl_counter(struct counter_handle *counter, const char **tag); \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/agent.c b/modules/balancer/controlplane/helpers/agent.c new file mode 100644 index 000000000..34343d7ee --- /dev/null +++ b/modules/balancer/controlplane/helpers/agent.c @@ -0,0 +1,169 @@ +#include + +#include "agent.h" + +#include "api/agent.h" + +#include "common/memory_address.h" +#include "common/ttlmap/ttlmap.h" + +#include "lib/controlplane/agent/agent.h" +#include "lib/dataplane/config/zone.h" + +#include "modules/balancer/dataplane/dataplane.h" +#include "modules/balancer/dataplane/types/session.h" + +static const char *storage_name = "balancer_storage"; + +struct balancer_storage { + size_t count; + struct balancer_packet_handler **handlers; +}; + +static struct balancer_storage * +get_storage(struct agent *agent) { + return agent_storage_read(agent, storage_name); +} + +static int +ensure_storage(struct agent *agent) { + if (get_storage(agent) != NULL) { + return 0; + } + + struct balancer_storage empty = {0}; + return agent_storage_put(agent, storage_name, &empty, sizeof(empty)); +} + +int +balancer_agent_install( + struct agent *agent, struct balancer_packet_handler *handler +) { + struct cp_module *module = &handler->cp_module; + return agent_update_modules(agent, 1, &module); +} + +int +balancer_agent_register( + struct agent *agent, struct balancer_packet_handler *handler +) { + if (ensure_storage(agent) != 0) { + return -1; + } + + /* Try to find free slot. */ + struct balancer_storage *s = get_storage(agent); + struct balancer_packet_handler **handlers = ADDR_OF(&s->handlers); + for (size_t i = 0; i < s->count; ++i) { + if (handlers[i] == NULL) { + SET_OFFSET_OF(handlers + i, handler); + return 0; + } + } + + /* No free slot found, allocate new one. */ + struct memory_context *mctx = &agent->memory_context; + size_t new_count = s->count + 1; + struct balancer_packet_handler **new_arr = memory_balloc( + mctx, sizeof(struct balancer_packet_handler *) * new_count + ); + if (new_arr == NULL) { + return -1; + } + + /* Copy existing entries. */ + if (s->count > 0) { + struct balancer_packet_handler **old_arr = + ADDR_OF(&s->handlers); + for (size_t i = 0; i < s->count; ++i) { + EQUATE_OFFSET(new_arr + i, old_arr + i); + } + + memory_bfree( + mctx, + old_arr, + sizeof(struct balancer_packet_handler *) * s->count + ); + } + + /* Append new handler. */ + SET_OFFSET_OF(new_arr + s->count, handler); + SET_OFFSET_OF(&s->handlers, new_arr); + s->count = new_count; + + return 0; +} + +struct balancer_packet_handler ** +balancer_agent_list(struct agent *agent, size_t *count) { + struct balancer_storage *s = get_storage(agent); + if (s == NULL) { + return NULL; + } + + *count = s->count; + + return ADDR_OF(&s->handlers); +} + +void +balancer_agent_forget( + struct agent *agent, struct balancer_packet_handler *handler +) { + struct balancer_storage *s = get_storage(agent); + if (s == NULL) { + return; + } + + struct balancer_packet_handler **handlers = ADDR_OF(&s->handlers); + + for (size_t i = 0; i < s->count; ++i) { + if (ADDR_OF(handlers + i) == handler) { + handlers[i] = NULL; + break; + } + } +} + +struct balancer_session_table * +balancer_agent_create_st(struct agent *agent, size_t capacity) { + struct memory_context *mctx = &agent->memory_context; + + struct balancer_session_table *st = + memory_balloc(mctx, sizeof(struct balancer_session_table)); + if (st == NULL) { + return NULL; + } + memset(st, 0, sizeof(*st)); + + memory_context_init_from(&st->mctx, mctx, "session_table"); + + int res = TTLMAP_INIT( + &st->maps[0], + &st->mctx, + struct balancer_session_id, + struct balancer_session_state, + capacity + ); + if (res != 0) { + memory_bfree(mctx, st, sizeof(*st)); + return NULL; + } + + ttlmap_init_empty(&st->maps[1]); + rcu_init(&st->rcu); + st->current_gen = 0; + st->workers = ADDR_OF(&agent->dp_config)->worker_count; + + return st; +} + +void +balancer_agent_destroy_st( + struct agent *agent, struct balancer_session_table *st +) { + TTLMAP_FREE(&st->maps[0]); + TTLMAP_FREE(&st->maps[1]); + struct memory_context *mctx = &agent->memory_context; + memory_bfree(mctx, st, sizeof(*st)); +} diff --git a/modules/balancer/controlplane/helpers/agent.h b/modules/balancer/controlplane/helpers/agent.h new file mode 100644 index 000000000..8aa9ce43a --- /dev/null +++ b/modules/balancer/controlplane/helpers/agent.h @@ -0,0 +1,50 @@ +/* + * Agent-level helpers: handler installation/registration, handler + * storage management, and session table lifecycle. + * + * Error convention: functions returning int use 0 for success + * and -1 for failure (allocation or storage error). + */ +#pragma once + +#include +#include + +struct agent; +struct balancer_packet_handler; +struct balancer_session_table; + +/* Push handler's cp_module into the dataplane module list. + * Returns 0 on success, -1 on failure. */ +int +balancer_agent_install( + struct agent *agent, struct balancer_packet_handler *handler +); + +/* Register handler in the per-agent balancer_storage array (reuses + * NULL slots before growing). Returns 0 on success, -1 on failure. */ +int +balancer_agent_register( + struct agent *agent, struct balancer_packet_handler *handler +); + +/* Remove handler from balancer_storage by NULLing its slot. */ +void +balancer_agent_forget( + struct agent *agent, struct balancer_packet_handler *handler +); + +/* Return the array of registered handlers and set *count. + * Returns NULL and leaves *count unchanged when no storage exists. */ +struct balancer_packet_handler ** +balancer_agent_list(struct agent *agent, size_t *count); + +/* Allocate and initialize a session table with the given capacity. + * Returns NULL on allocation failure. */ +struct balancer_session_table * +balancer_agent_create_st(struct agent *agent, size_t capacity); + +void +balancer_agent_destroy_st( + struct agent *agent, struct balancer_session_table *st +); \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/balancer.c b/modules/balancer/controlplane/helpers/balancer.c new file mode 100644 index 000000000..384a589f1 --- /dev/null +++ b/modules/balancer/controlplane/helpers/balancer.c @@ -0,0 +1,567 @@ +#include "balancer.h" + +#include +#include +#include + +#include "common/memory.h" +#include "common/memory_address.h" +#include "common/rcu.h" + +#include "filter/compiler.h" +#include "filter/rule.h" + +#include "lib/controlplane/agent/agent.h" +#include "lib/controlplane/config/cp_module.h" +#include "lib/counters/counters.h" + +#include "modules/balancer/dataplane/dataplane.h" +#include "modules/balancer/dataplane/types/real.h" +#include "modules/balancer/dataplane/types/stats.h" +#include "modules/balancer/dataplane/types/vs.h" + +FILTER_COMPILER_DECLARE( + vs_matcher_ipv4, net4_fast_dst, port_fast_dst, proto_range_fast +); +FILTER_COMPILER_DECLARE( + vs_matcher_ipv6, net6_fast_dst, port_fast_dst, proto_range_fast +); + +FILTER_COMPILER_DECLARE(decap_ipv4, net4_fast_dst); +FILTER_COMPILER_DECLARE(decap_ipv6, net6_fast_dst); + +int +balancer_initial_setup( + struct agent *agent, + struct balancer_packet_handler *handler, + const char *name, + struct balancer_session_table *session_table +) { + memset(handler, 0, sizeof(*handler)); + + struct memory_context *mctx = &agent->memory_context; + + if (cp_module_init(&handler->cp_module, agent, "balancer", name) != 0) { + memory_bfree(mctx, handler, sizeof(*handler)); + return -1; + } + + SET_OFFSET_OF(&handler->session_table, session_table); + rcu_init(&handler->rcu); + + return 0; +} + +const char * +balancer_name(struct balancer_packet_handler *handler) { + return handler->cp_module.name; +} + +static int +register_vs_counter(struct balancer_vs *vs, struct counter_registry *registry) { + char name[64]; + snprintf(name, sizeof(name), "vs_%zu", vs->stable_idx); + + vs->counter_id = counter_registry_register( + registry, + name, + sizeof(struct balancer_vs_stats) / sizeof(uint64_t) + ); + if (vs->counter_id == (uint64_t)-1) { + return -1; + } + + return 0; +} + +static int +register_vs_rule_counters( + struct balancer_vs *vs, + struct counter_registry *registry, + struct memory_context *mctx +) { + uint32_t allowed_src_count = vs->allowed_sources_count; + if (allowed_src_count == 0) { + return 0; + } + + struct balancer_vs_allowed_source *allowed_srcs = + ADDR_OF(&vs->allowed_sources); + + uint64_t *ids = + memory_balloc(mctx, allowed_src_count * sizeof(uint64_t)); + if (ids == NULL) { + return -1; + } + + for (uint32_t i = 0; i < allowed_src_count; ++i) { + if (allowed_srcs[i].tag[0] == 0) { + ids[i] = (uint64_t)-1; + continue; + } + + char name[64]; + snprintf( + name, + sizeof(name), + "acl_%zu_%s", + vs->stable_idx, + allowed_srcs[i].tag + ); + + ids[i] = counter_registry_register(registry, name, 1); + if (ids[i] == (uint64_t)-1) { + memory_bfree( + mctx, ids, allowed_src_count * sizeof(uint64_t) + ); + return -1; + } + } + + SET_OFFSET_OF(&vs->rule_counter_ids, ids); + + return 0; +} + +static int +register_real_counter( + struct balancer_real *real, + size_t vs_stable_ix, + struct counter_registry *registry +) { + char name[64]; + snprintf( + name, sizeof(name), "rl_%zu_%zu", vs_stable_ix, real->stable_idx + ); + + real->counter_id = counter_registry_register( + registry, + name, + sizeof(struct balancer_real_stats) / sizeof(uint64_t) + ); + if (real->counter_id == (uint64_t)-1) { + return -1; + } + + return 0; +} + +static int +register_module_counters( + struct balancer_packet_handler *handler, + struct counter_registry *registry +) { + handler->common_counter_id = counter_registry_register( + registry, + "cmn", + sizeof(struct balancer_common_stats) / sizeof(uint64_t) + ); + if (handler->common_counter_id == (uint64_t)-1) { + return -1; + } + + handler->icmp_v4_counter_id = counter_registry_register( + registry, + "iv4", + sizeof(struct balancer_icmp_stats) / sizeof(uint64_t) + ); + if (handler->icmp_v4_counter_id == (uint64_t)-1) { + return -1; + } + + handler->icmp_v6_counter_id = counter_registry_register( + registry, + "iv6", + sizeof(struct balancer_icmp_stats) / sizeof(uint64_t) + ); + if (handler->icmp_v6_counter_id == (uint64_t)-1) { + return -1; + } + + handler->l4_counter_id = counter_registry_register( + registry, + "l4", + sizeof(struct balancer_l4_stats) / sizeof(uint64_t) + ); + if (handler->l4_counter_id == (uint64_t)-1) { + return -1; + } + + return 0; +} + +int +balancer_register_counters(struct balancer_packet_handler *handler) { + struct agent *agent = ADDR_OF(&handler->cp_module.agent); + struct memory_context *mctx = &agent->memory_context; + struct counter_registry *registry = + &handler->cp_module.counter_registry; + + if (register_module_counters(handler, registry) != 0) { + return -1; + } + + struct balancer_vs *services = ADDR_OF(&handler->vs); + for (uint32_t vs_idx = 0; vs_idx < handler->vs_count; ++vs_idx) { + struct balancer_vs *vs = &services[vs_idx]; + if (vs->flags & balancer_vs_removed) { + continue; + } + + if (register_vs_counter(vs, registry) != 0) { + return -1; + } + + if (register_vs_rule_counters(vs, registry, mctx) != 0) { + return -1; + } + + struct balancer_real *reals = ADDR_OF(&vs->reals); + for (uint32_t real_idx = 0; real_idx < vs->reals_count; + ++real_idx) { + struct balancer_real *real = &reals[real_idx]; + if (real->flags & balancer_real_removed) { + continue; + } + if (register_real_counter( + real, vs->stable_idx, registry + ) != 0) { + return -1; + } + } + } + + return 0; +} + +static void +free_rules(struct filter_rule *rules, size_t count) { + for (size_t i = 0; i < count; ++i) { + free(rules[i].net4.dsts); + free(rules[i].net6.dsts); + free(rules[i].transport.srcs); + free(rules[i].transport.dsts); + free(rules[i].transport.protos); + } + free(rules); +} + +void +balancer_free_decap_filters(struct balancer_packet_handler *handler) { + struct agent *agent = ADDR_OF(&handler->cp_module.agent); + struct memory_context *mctx = &agent->memory_context; + + struct filter *filter_ipv4 = ADDR_OF(&handler->decap_ipv4_filter); + if (filter_ipv4 != NULL) { + filter_free(filter_ipv4, decap_ipv4); + memory_bfree(mctx, filter_ipv4, sizeof(struct filter)); + SET_OFFSET_OF(&handler->decap_ipv4_filter, NULL); + } + + struct filter *filter_ipv6 = ADDR_OF(&handler->decap_ipv6_filter); + if (filter_ipv6 != NULL) { + filter_free(filter_ipv6, decap_ipv6); + memory_bfree(mctx, filter_ipv6, sizeof(struct filter)); + SET_OFFSET_OF(&handler->decap_ipv6_filter, NULL); + } +} + +void +balancer_free_vs_matchers(struct balancer_packet_handler *handler) { + struct agent *agent = ADDR_OF(&handler->cp_module.agent); + struct memory_context *mctx = &agent->memory_context; + + struct filter *ipv4 = ADDR_OF(&handler->ipv4_vs_matcher); + if (ipv4 != NULL) { + filter_free(ipv4, vs_matcher_ipv4); + memory_bfree(mctx, ipv4, sizeof(struct filter)); + SET_OFFSET_OF(&handler->ipv4_vs_matcher, NULL); + } + + struct filter *ipv6 = ADDR_OF(&handler->ipv6_vs_matcher); + if (ipv6 != NULL) { + filter_free(ipv6, vs_matcher_ipv6); + memory_bfree(mctx, ipv6, sizeof(struct filter)); + SET_OFFSET_OF(&handler->ipv6_vs_matcher, NULL); + } +} + +static int +make_dst_addr_rule(struct filter_rule *rule, uint8_t *dst, uint8_t ipproto) { + if (ipproto == IPPROTO_IPV6) { + rule->net6.dst_count = 1; + rule->net6.dsts = malloc(sizeof(struct net6)); + if (rule->net6.dsts == NULL) { + return -1; + } + memcpy(rule->net6.dsts[0].addr, dst, NET6_LEN); + memset(rule->net6.dsts[0].mask, 0xFF, NET6_LEN); + } else { + rule->net4.dst_count = 1; + rule->net4.dsts = malloc(sizeof(struct net4)); + if (rule->net4.dsts == NULL) { + return -1; + } + memcpy(rule->net4.dsts[0].addr, dst, NET4_LEN); + memset(rule->net4.dsts[0].mask, 0xFF, NET4_LEN); + } + return 0; +} + +static ssize_t +make_decap_rules( + struct balancer_packet_handler *handler, + struct filter_rule **out, + int is_ipv6 +) { + size_t count = + is_ipv6 ? handler->decap_v6_count : handler->decap_v4_count; + struct filter_rule *rules = calloc(count, sizeof(struct filter_rule)); + if (rules == NULL && count > 0) { + return -1; + } + // Calloc initializes the memory to zero. + + struct net4_addr *decap_v4 = ADDR_OF(&handler->decap_v4); + struct net6_addr *decap_v6 = ADDR_OF(&handler->decap_v6); + + for (size_t i = 0; i < count; ++i) { + struct filter_rule *rule = &rules[i]; + + int res; + if (is_ipv6) { + res = make_dst_addr_rule( + rule, decap_v6[i].bytes, IPPROTO_IPV6 + ); + } else { + res = make_dst_addr_rule( + rule, decap_v4[i].bytes, IPPROTO_IP + ); + } + if (res != 0) { + free_rules(rules, count); + return -1; + } + + rule->action = i; + } + + *out = rules; + return count; +} + +static int +build_decap_filter(struct balancer_packet_handler *handler, int is_ipv6) { + struct agent *agent = ADDR_OF(&handler->cp_module.agent); + struct memory_context *mctx = &agent->memory_context; + + struct filter *filter = memory_balloc(mctx, sizeof(struct filter)); + if (filter == NULL) { + return -1; + } + + struct filter_rule *rules = NULL; + ssize_t res = make_decap_rules(handler, &rules, is_ipv6); + if (res == -1) { + memory_bfree(mctx, filter, sizeof(struct filter)); + return -2; + } + size_t count = (size_t)res; + + const struct filter_rule **ptrs = calloc(count, sizeof(*ptrs)); + if (ptrs == NULL && count > 0) { + free_rules(rules, count); + memory_bfree(mctx, filter, sizeof(struct filter)); + return -1; + } + for (size_t i = 0; i < count; ++i) { + ptrs[i] = &rules[i]; + } + if (is_ipv6) { + res = filter_init(filter, decap_ipv6, ptrs, count, mctx); + } else { + res = filter_init(filter, decap_ipv4, ptrs, count, mctx); + } + free(ptrs); + free_rules(rules, count); + + if (res != 0) { + memory_bfree(mctx, filter, sizeof(struct filter)); + return -1; + } + + if (is_ipv6) { + SET_OFFSET_OF(&handler->decap_ipv6_filter, filter); + } else { + SET_OFFSET_OF(&handler->decap_ipv4_filter, filter); + } + + return 0; +} + +int +balancer_set_ipv4_decap_filter(struct balancer_packet_handler *handler) { + return build_decap_filter(handler, 0); +} + +int +balancer_set_ipv6_decap_filter(struct balancer_packet_handler *handler) { + return build_decap_filter(handler, 1); +} + +static int +make_transport_rule(struct filter_rule *rule, struct balancer_vs *vs) { + rule->transport.dst_count = 1; + rule->transport.dsts = malloc(sizeof(struct filter_port_range)); + if (rule->transport.dsts == NULL) { + return -1; + } + + if (vs->flags & balancer_vs_pure_l3) { + rule->transport.dsts[0].from = 0; + rule->transport.dsts[0].to = 65535; + } else { + rule->transport.dsts[0].from = vs->port; + rule->transport.dsts[0].to = vs->port; + } + + rule->transport.proto_count = 1; + rule->transport.protos = calloc(1, sizeof(struct filter_proto_range)); + if (rule->transport.protos == NULL) { + free(rule->transport.dsts); + return -1; + } + + rule->transport.protos[0].from = vs->transport_proto * 256; + rule->transport.protos[0].to = vs->transport_proto * 256 + 255; + + return 0; +} + +static int +make_vs_matcher_rule(struct filter_rule *rule, struct balancer_vs *vs) { + if (make_dst_addr_rule(rule, vs->addr.v6.bytes, vs->ip_proto) != 0) { + return -1; + } + + if (make_transport_rule(rule, vs) != 0) { + return -1; + } + + return 0; +} + +static ssize_t +make_vs_matcher_rules( + struct filter_rule **out, + uint8_t ipproto, + struct balancer_vs *services, + uint32_t service_count +) { + struct filter_rule *rules = + calloc(service_count, sizeof(struct filter_rule)); + if (rules == NULL && service_count > 0) { + return -1; + } + // Calloc initializes the memory to zero. + + for (size_t i = 0; i < service_count; ++i) { + struct balancer_vs *vs = &services[i]; + if (vs->ip_proto != ipproto || + (vs->flags & balancer_vs_removed)) { + continue; + } + + struct filter_rule *rule = &rules[i]; + + if (make_vs_matcher_rule(rule, vs) != 0) { + free_rules(rules, service_count); + return -1; + } + } + + *out = rules; + return service_count; +} + +static int +build_vs_matcher( + struct balancer_packet_handler *handler, + uint8_t ipproto, + struct filter **result +) { + struct agent *agent = ADDR_OF(&handler->cp_module.agent); + struct memory_context *mctx = &agent->memory_context; + struct balancer_vs *vs = ADDR_OF(&handler->vs); + size_t vs_count = handler->vs_count; + + struct filter *filter = memory_balloc(mctx, sizeof(struct filter)); + if (filter == NULL) { + return -1; + } + + struct filter_rule *rules = NULL; + ssize_t res = make_vs_matcher_rules(&rules, ipproto, vs, vs_count); + if (res == -1) { + memory_bfree(mctx, filter, sizeof(struct filter)); + return -2; + } + size_t count = (size_t)res; + + const struct filter_rule **ptrs = calloc(count, sizeof(*ptrs)); + if (ptrs == NULL && count > 0) { + free_rules(rules, count); + memory_bfree(mctx, filter, sizeof(struct filter)); + return -1; + } + for (size_t i = 0; i < count; ++i) { + if (vs[i].ip_proto != ipproto || + (vs[i].flags & balancer_vs_removed)) { + ptrs[i] = NULL; + } else { + ptrs[i] = &rules[i]; + } + } + if (ipproto == IPPROTO_IPV6) { + res = filter_init(filter, vs_matcher_ipv6, ptrs, count, mctx); + } else { + res = filter_init(filter, vs_matcher_ipv4, ptrs, count, mctx); + } + free(ptrs); + free_rules(rules, count); + + if (res != 0) { + memory_bfree(mctx, filter, sizeof(struct filter)); + return -1; + } + + *result = filter; + + return 0; +} + +int +balancer_set_ipv4_vs_matcher(struct balancer_packet_handler *handler) { + struct filter *filter; + int res = build_vs_matcher(handler, IPPROTO_IP, &filter); + if (res != 0) { + return res; + } + + SET_OFFSET_OF(&handler->ipv4_vs_matcher, filter); + + return 0; +} + +int +balancer_set_ipv6_vs_matcher(struct balancer_packet_handler *handler) { + struct filter *filter; + int res = build_vs_matcher(handler, IPPROTO_IPV6, &filter); + if (res != 0) { + return res; + } + + SET_OFFSET_OF(&handler->ipv6_vs_matcher, filter); + + return 0; +} \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/balancer.h b/modules/balancer/controlplane/helpers/balancer.h new file mode 100644 index 000000000..6f70d11a0 --- /dev/null +++ b/modules/balancer/controlplane/helpers/balancer.h @@ -0,0 +1,58 @@ +/* + * Balancer packet handler helpers. + * + * Error convention: functions returning int use 0 for success and + * non-zero for failure. Specific codes: + * -1 shared memory allocation failure + * -2 heap allocation failure + */ +#pragma once + +#include +#include + +struct balancer_packet_handler; +struct balancer_session_table; +struct agent; + +/* Zero-initializes handler, sets up cp_module and session table pointer. + * Returns 0 on success, -1 on failure. */ +int +balancer_initial_setup( + struct agent *agent, + struct balancer_packet_handler *handler, + const char *name, + struct balancer_session_table *session_table +); + +/* Registers per-VS, per-real, and module-level counters. + * Returns 0 on success, -1 on failure. */ +int +balancer_register_counters(struct balancer_packet_handler *handler); + +const char * +balancer_name(struct balancer_packet_handler *handler); + +/* Compile and set the IPv4/IPv6 VS matcher filter. + * Returns 0 on success, -1 on shared memory allocation failure, + * -2 on heap allocation failure. */ +int +balancer_set_ipv4_vs_matcher(struct balancer_packet_handler *handler); + +int +balancer_set_ipv6_vs_matcher(struct balancer_packet_handler *handler); + +void +balancer_free_vs_matchers(struct balancer_packet_handler *handler); + +/* Compile and set the IPv4/IPv6 decap address filter. + * Returns 0 on success, -1 on shared memory allocation failure, + * -2 on heap allocation failure. */ +int +balancer_set_ipv4_decap_filter(struct balancer_packet_handler *handler); + +int +balancer_set_ipv6_decap_filter(struct balancer_packet_handler *handler); + +void +balancer_free_decap_filters(struct balancer_packet_handler *handler); \ No newline at end of file diff --git a/modules/balancer/agent/meson.build b/modules/balancer/controlplane/helpers/meson.build similarity index 50% rename from modules/balancer/agent/meson.build rename to modules/balancer/controlplane/helpers/meson.build index 691f8a092..1def1c99f 100644 --- a/modules/balancer/agent/meson.build +++ b/modules/balancer/controlplane/helpers/meson.build @@ -1,23 +1,25 @@ -subdir('balancerpb') - dependencies = [ lib_common_dep, lib_agent_cp_dep, - lib_balancer_cp_dep, + lib_filter_compiler_dep, ] +includes = include_directories('.', '../../../../') + sources = files( - 'agent.c', - 'config.c', - 'manager.c' + 'balancer.c', + 'agent.c', + 'sessions.c', + 'vs.c', + 'real.c', ) static_library( - 'balancer_agent', + 'balancer_helpers', sources, c_args: yanet_c_args, link_args: yanet_link_args, dependencies: dependencies, - include_directories: [yanet_rootdir], + include_directories: includes, install: false, -) \ No newline at end of file +) diff --git a/modules/balancer/controlplane/helpers/real.c b/modules/balancer/controlplane/helpers/real.c new file mode 100644 index 000000000..860bf63b6 --- /dev/null +++ b/modules/balancer/controlplane/helpers/real.c @@ -0,0 +1,67 @@ +#include "common/memory_address.h" + +#include "modules/balancer/dataplane/types/interval_counter.h" +#include "modules/balancer/dataplane/types/real.h" +#include "modules/balancer/dataplane/types/sessions_tracker.h" + +#include "real.h" + +static uint32_t +calc_active_sessions( + const struct balancer_sessions_tracker_shard *shard, + uint32_t now, + uint32_t *last_packet_timestamp +) { + const struct balancer_interval_counter *counter = &shard->counter; + + const uint32_t initial_last_timestamp = shard->last_timestamp; + const uint32_t last_dp_tick = counter->last_timestamp; + uint32_t current_cp_tick = now / BALANCER_SESSIONS_TRACKER_PRECISION; + + uint32_t next_tick = counter->last_timestamp + 1; + + int32_t add = 0; + while (next_tick <= current_cp_tick && + next_tick - last_dp_tick < BALANCER_IC_RING_SIZE) { + add += counter->diff[next_tick & BALANCER_IC_RING_MASK]; + ++next_tick; + } + + // Prevent reuse of the last_timestamp value + // and force reload from memory. + __asm__ volatile("" ::: "memory"); + + uint32_t count = shard->count; + uint32_t last_timestamp = shard->last_timestamp; + if (last_timestamp != initial_last_timestamp) { + add = 0; + } + + *last_packet_timestamp = last_timestamp; + + return ((int64_t)count + add >= 0 ? count + add : 0); +} + +void +balancer_real_sessions( + struct balancer_real *real, + size_t workers, + uint64_t *active_sessions, + uint32_t *last_packet_timestamp, + uint32_t now +) { + struct balancer_sessions_tracker_shard *tracker_shards = + ADDR_OF(&real->tracker_shards); + *active_sessions = 0; + *last_packet_timestamp = 0; + for (size_t worker_idx = 0; worker_idx < workers; ++worker_idx) { + const struct balancer_sessions_tracker_shard *shard = + &tracker_shards[worker_idx]; + uint32_t last_timestamp; + *active_sessions += + calc_active_sessions(shard, now, &last_timestamp); + if (*last_packet_timestamp < last_timestamp) { + *last_packet_timestamp = last_timestamp; + } + } +} \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/real.h b/modules/balancer/controlplane/helpers/real.h new file mode 100644 index 000000000..6b6f44dc4 --- /dev/null +++ b/modules/balancer/controlplane/helpers/real.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +struct balancer_real; + +void +balancer_real_sessions( + struct balancer_real *real, + size_t workers, + uint64_t *active_sessions, + uint32_t *last_packet_timestamp, + uint32_t now +); \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/sessions.c b/modules/balancer/controlplane/helpers/sessions.c new file mode 100644 index 000000000..10790809c --- /dev/null +++ b/modules/balancer/controlplane/helpers/sessions.c @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include + +#include "common/ttlmap/ttlmap.h" + +#include "modules/balancer/dataplane/dataplane.h" +#include "modules/balancer/dataplane/types/session.h" + +#include "sessions.h" + +struct move_ctx { + struct ttlmap *dst; + uint32_t now; +}; + +static int +move_session_cb(void *key, void *value, void *userdata) { + struct move_ctx *ctx = userdata; + struct balancer_session_state *state = value; + struct balancer_session_id *id = key; + struct balancer_session_state *new_state; + + ttlmap_lock_t *lock; + int res = TTLMAP_GET( + ctx->dst, + id, + &new_state, + &lock, + state->last_packet_timestamp, + state->timeout + ); + int status = TTLMAP_STATUS(res); + + if (status == TTLMAP_INSERTED || status == TTLMAP_REPLACED) { + memcpy(new_state, state, sizeof(*state)); + ttlmap_release_lock(lock); + } else if (status == TTLMAP_FOUND) { + ttlmap_release_lock(lock); + } + + return 0; +} + +int +balancer_st_resize( + struct balancer_session_table *st, size_t new_size, uint32_t now +) { + uint32_t gen = + atomic_load_explicit(&st->current_gen, memory_order_acquire); + + struct ttlmap *next = balancer_st_prev_map(st, gen); + int res = TTLMAP_INIT( + next, + &st->mctx, + struct balancer_session_id, + struct balancer_session_state, + new_size + ); + if (res != 0) { + return -1; + } + + /* Begin transition: odd gen signals dataplane to use both maps. */ + struct ttlmap *cur = balancer_st_cur_map(st, gen); + gen++; + atomic_store_explicit(&st->current_gen, gen, memory_order_release); + + /* Migrate sessions from old map to new map. */ + struct move_ctx ctx = {.dst = next, .now = now}; + TTLMAP_ITER( + cur, + struct balancer_session_id, + struct balancer_session_state, + now, + move_session_cb, + &ctx + ); + + /* End transition: even gen, dataplane uses new map only. */ + gen++; + atomic_store_explicit(&st->current_gen, gen, memory_order_release); + + TTLMAP_FREE(cur); + return 0; +} + +size_t +balancer_st_capacity(struct balancer_session_table *st) { + uint32_t gen = + atomic_load_explicit(&st->current_gen, memory_order_acquire); + return ttlmap_capacity(balancer_st_cur_map(st, gen)); +} + +struct buf_ctx { + struct balancer_session_entry *buf; + int count; +}; + +static int +fill_buf_cb( + struct balancer_session_id *id, + struct balancer_session_state *state, + void *userdata +) { + struct buf_ctx *ctx = userdata; + ctx->buf[ctx->count].id = *id; + ctx->buf[ctx->count].state = *state; + ctx->count++; + return 0; +} + +int +balancer_st_iter_next_bucket_buf( + struct balancer_session_table_iter *iter, + uint32_t now, + struct balancer_session_entry *buf, + int *count +) { + struct buf_ctx ctx = {.buf = buf, .count = 0}; + int ret = balancer_st_iter_next_bucket(iter, now, fill_buf_cb, &ctx); + *count = ctx.count; + return ret; +} + +void +balancer_st_iter_init( + struct balancer_session_table_iter *iter, + struct balancer_session_table *st +) { + iter->gen = + atomic_load_explicit(&st->current_gen, memory_order_acquire); + struct ttlmap *map = balancer_st_cur_map(st, iter->gen); + ttlmap_bucket_iter_init(&iter->ttlmap_iter, map); +} + +int +balancer_st_iter_next_bucket( + struct balancer_session_table_iter *iter, + uint32_t now, + balancer_st_iter_callback cb, + void *userdata +) { + return TTLMAP_ITER_NEXT( + &iter->ttlmap_iter, + struct balancer_session_id, + struct balancer_session_state, + now, + cb, + userdata + ); +} \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/sessions.h b/modules/balancer/controlplane/helpers/sessions.h new file mode 100644 index 000000000..3b94f1e4a --- /dev/null +++ b/modules/balancer/controlplane/helpers/sessions.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include "common/ttlmap/ttlmap.h" + +#include "modules/balancer/dataplane/types/session.h" + +struct agent; +struct balancer_packet_handler; + +/* + * Resize the session table by allocating a new ttlmap and migrating + * existing sessions into it. + * + * Uses a two-phase gen-bump protocol coordinated with the dataplane: + * + * gen (even) — steady state, dataplane uses map selected by gen. + * gen+1 (odd) — transition: the new map is active for writes, + * dataplane must look up both maps for existing sessions. + * gen+2 (even) — migration complete, old map freed. + * + * The dataplane checks current_gen atomically on every packet: + * - Even gen: use balancer_st_cur_map(st, gen) only. + * - Odd gen: write to the new map, but fall back to both maps for + * lookups so in-flight sessions are not lost. + * + * Returns 0 on success, -1 on allocation failure. + */ +int +balancer_st_resize( + struct balancer_session_table *st, size_t new_size, uint32_t now +); + +size_t +balancer_st_capacity(struct balancer_session_table *st); + +struct balancer_session_table_iter { + struct ttlmap_bucket_iter ttlmap_iter; + uint32_t gen; +}; + +void +balancer_st_iter_init( + struct balancer_session_table_iter *iter, + struct balancer_session_table *st +); + +typedef int (*balancer_st_iter_callback)( + struct balancer_session_id *id, + struct balancer_session_state *state, + void *userdata +); + +// Returns 0 on error +int +balancer_st_iter_next_bucket( + struct balancer_session_table_iter *iter, + uint32_t now, + balancer_st_iter_callback cb, + void *userdata +); + +struct balancer_session_entry { + struct balancer_session_id id; + struct balancer_session_state state; +}; + +// Advances iterator to next bucket, fills buf with non-expired entries. +// Sets *count to number of valid entries. Returns 0 when no more buckets. +int +balancer_st_iter_next_bucket_buf( + struct balancer_session_table_iter *iter, + uint32_t now, + struct balancer_session_entry *buf, + int *count +); \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/vs.c b/modules/balancer/controlplane/helpers/vs.c new file mode 100644 index 000000000..acabecdd9 --- /dev/null +++ b/modules/balancer/controlplane/helpers/vs.c @@ -0,0 +1,417 @@ +#include +#include +#include +#include +#include + +#include "api/agent.h" + +#include "common/big_array.h" +#include "common/memory.h" +#include "common/memory_address.h" +#include "common/rcu.h" +#include "common/rng.h" + +#include "filter/compiler.h" +#include "filter/rule.h" + +#include "lib/controlplane/agent/agent.h" +#include "lib/dataplane/config/zone.h" + +#include "modules/balancer/dataplane/dataplane.h" +#include "modules/balancer/dataplane/types/real.h" +#include "modules/balancer/dataplane/types/selector.h" +#include "modules/balancer/dataplane/types/sessions_tracker.h" +#include "modules/balancer/dataplane/types/vs.h" + +FILTER_COMPILER_DECLARE(ipv4_vs_acl, net4_fast_src, port_fast_src); +FILTER_COMPILER_DECLARE(ipv6_vs_acl, net6_fast_src, port_fast_src); + +static void +free_rules(struct filter_rule *rules, size_t count) { + for (size_t i = 0; i < count; ++i) { + free(rules[i].net4.dsts); + free(rules[i].net6.dsts); + free(rules[i].transport.srcs); + free(rules[i].transport.dsts); + free(rules[i].transport.protos); + } + free(rules); +} + +static int +compile_acl( + struct filter *filter, + struct filter_rule *rules, + size_t count, + struct memory_context *mctx, + int ipv6 +) { + const struct filter_rule **ptrs = calloc(count, sizeof(*ptrs)); + if (ptrs == NULL && count > 0) { + return -1; + } + for (size_t i = 0; i < count; ++i) { + ptrs[i] = &rules[i]; + } + + int res; + if (ipv6) { + res = filter_init(filter, ipv6_vs_acl, ptrs, count, mctx); + } else { + res = filter_init(filter, ipv4_vs_acl, ptrs, count, mctx); + } + + free(ptrs); + return res; +} + +static int +make_acl_net_rule( + struct filter_rule *rule, struct net *nets, uint32_t count, int ipv6 +) { + if (count == 0) { + return 0; + } + + if (ipv6) { + rule->net6.src_count = count; + rule->net6.srcs = calloc(count, sizeof(struct net6)); + if (rule->net6.srcs == NULL) { + return -1; + } + for (uint32_t j = 0; j < count; ++j) { + memcpy(&rule->net6.srcs[j], + &nets[j].v6, + sizeof(struct net6)); + } + } else { + rule->net4.src_count = count; + rule->net4.srcs = calloc(count, sizeof(struct net4)); + if (rule->net4.srcs == NULL) { + return -1; + } + for (uint32_t j = 0; j < count; ++j) { + memcpy(&rule->net4.srcs[j], + &nets[j].v4, + sizeof(struct net4)); + } + } + + return 0; +} + +static int +make_acl_port_rule( + struct filter_rule *rule, struct filter_port_range *prs, uint32_t count +) { + if (count > 0) { + rule->transport.src_count = count; + rule->transport.srcs = + calloc(count, sizeof(struct filter_port_range)); + if (rule->transport.srcs == NULL) { + return -1; + } + memcpy(rule->transport.srcs, + prs, + count * sizeof(struct filter_port_range)); + } else { + rule->transport.src_count = 1; + rule->transport.srcs = + calloc(1, sizeof(struct filter_port_range)); + if (rule->transport.srcs == NULL) { + return -1; + } + rule->transport.srcs[0].from = 0; + rule->transport.srcs[0].to = 65535; + } + + return 0; +} + +static int +make_acl_rules(struct filter_rule **out, struct balancer_vs *vs) { + uint32_t allowed_src_count = vs->allowed_sources_count; + struct balancer_vs_allowed_source *allowed_srcs = + ADDR_OF(&vs->allowed_sources); + + struct filter_rule *rules = + calloc(allowed_src_count, sizeof(struct filter_rule)); + if (rules == NULL && allowed_src_count > 0) { + return -1; + } + + int ipv6 = vs->ip_proto == IPPROTO_IPV6; + + for (uint32_t i = 0; i < allowed_src_count; ++i) { + struct balancer_vs_allowed_source *src = &allowed_srcs[i]; + struct filter_rule *rule = &rules[i]; + + if (make_acl_net_rule( + rule, ADDR_OF(&src->nets), src->nets_count, ipv6 + ) != 0) { + free_rules(rules, allowed_src_count); + return -1; + } + + if (make_acl_port_rule( + rule, + ADDR_OF(&src->port_ranges), + src->port_ranges_count + ) != 0) { + free_rules(rules, allowed_src_count); + return -1; + } + + rule->action = i; + } + + *out = rules; + return 0; +} + +int +balancer_vs_set_acl(struct balancer_vs *vs, struct agent *agent) { + struct memory_context *mctx = &agent->memory_context; + + uint32_t src_count = vs->allowed_sources_count; + if (src_count == 0) { + SET_OFFSET_OF(&vs->acl, NULL); + return 0; + } + + struct filter_rule *rules = NULL; + if (make_acl_rules(&rules, vs) != 0) { + return -2; + } + + struct filter *filter = memory_balloc(mctx, sizeof(struct filter)); + if (filter == NULL) { + free_rules(rules, src_count); + return -1; + } + + int ipv6 = vs->ip_proto == IPPROTO_IPV6; + if (compile_acl(filter, rules, src_count, mctx, ipv6) != 0) { + memory_bfree(mctx, filter, sizeof(struct filter)); + free_rules(rules, src_count); + return -1; + } + + free_rules(rules, src_count); + SET_OFFSET_OF(&vs->acl, filter); + + return 0; +} + +static int +selector_ensure(struct balancer_vs *vs, struct memory_context *mctx) { + if (ADDR_OF(&vs->selector) != NULL) { + return 0; + } + + struct balancer_real_selector *sel = + memory_balloc(mctx, sizeof(struct balancer_real_selector)); + if (sel == NULL) { + return -1; + } + + memset(sel, 0, sizeof(*sel)); + SET_OFFSET_OF(&vs->selector, sel); + + return 0; +} + +static int +ring_fill( + struct balancer_ring *ring, + struct balancer_real *reals, + uint32_t reals_count, + uint64_t total_weight, + struct memory_context *mctx, + size_t seed +) { + size_t ring_bytes = total_weight * sizeof(uint32_t); + if (big_array_init(&ring->real_ids, ring_bytes, mctx) != 0) { + return -1; + } + + /* Fill: each enabled real gets weight copies. */ + size_t pos = 0; + for (uint32_t i = 0; i < reals_count; ++i) { + if (!(reals[i].flags & balancer_real_enabled) || + (reals[i].flags & balancer_real_removed)) + continue; + + uint32_t w = reals[i].effective_weight; + + for (uint32_t j = 0; j < w; ++j) { + uint32_t *slot = big_array_get( + &ring->real_ids, pos * sizeof(uint32_t) + ); + *slot = i; + pos++; + } + } + + /* + * Shuffle the ring using a deterministic PRNG seeded by vs->stable_idx. + * Determinism is intentional: it guarantees that the same VS always + * produces the same ring layout across restarts and config updates, + * preserving consistent load distribution. + */ + uint64_t rng = 0xdeadbeef ^ seed; + for (size_t i = pos; i > 1; --i) { + uint32_t *a = big_array_get( + &ring->real_ids, (i - 1) * sizeof(uint32_t) + ); + uint32_t *b = big_array_get( + &ring->real_ids, (rng % i) * sizeof(uint32_t) + ); + uint32_t tmp = *a; + *a = *b; + *b = tmp; + + rng = rng_next(&rng); + } + + return 0; +} + +void +balancer_vs_free_acl(struct balancer_vs *vs, struct agent *agent) { + struct memory_context *mctx = &agent->memory_context; + + struct filter *filter = ADDR_OF(&vs->acl); + if (filter == NULL) { + return; + } + + int ipv6 = vs->ip_proto == IPPROTO_IPV6; + if (ipv6) { + filter_free(filter, ipv6_vs_acl); + } else { + filter_free(filter, ipv4_vs_acl); + } + memory_bfree(mctx, filter, sizeof(struct filter)); + SET_OFFSET_OF(&vs->acl, NULL); +} + +int +balancer_vs_update_real_selector( + struct balancer_vs *vs, rcu_t *rcu, struct agent *agent +) { + struct memory_context *mctx = &agent->memory_context; + + if (selector_ensure(vs, mctx) != 0) { + return -1; + } + + struct balancer_real_selector *sel = ADDR_OF(&vs->selector); + struct balancer_real *reals = ADDR_OF(&vs->reals); + uint32_t reals_count = vs->reals_count; + + sel->use_rr = (vs->flags & balancer_vs_round_robin) ? 1 : 0; + + /* Total weight of enabled reals. */ + uint64_t total_weight = 0; + for (uint32_t i = 0; i < reals_count; ++i) { + if ((reals[i].flags & balancer_real_enabled) && + !(reals[i].flags & balancer_real_removed)) { + total_weight += reals[i].effective_weight; + } + } + + /* Prepare new ring in the inactive slot. */ + size_t cur_ring = + atomic_load_explicit(&sel->ring_id, memory_order_acquire); + size_t new_ring = cur_ring ^ 1; + big_array_free(&sel->rings[new_ring].real_ids); + + if (total_weight == 0) { + memset(&sel->rings[new_ring], 0, sizeof(struct balancer_ring)); + } else if (ring_fill( + &sel->rings[new_ring], + reals, + reals_count, + total_weight, + mctx, + vs->stable_idx + ) != 0) { + return -1; + } + + /* Swap using packet handler's RCU. */ + rcu_update(rcu, (atomic_ulong *)&sel->ring_id, new_ring); + + /* Free old ring. */ + big_array_free(&sel->rings[cur_ring].real_ids); + return 0; +} + +void +balancer_vs_free_real_selector(struct balancer_vs *vs, struct agent *agent) { + struct memory_context *mctx = &agent->memory_context; + + struct balancer_real_selector *sel = ADDR_OF(&vs->selector); + if (sel == NULL) { + return; + } + + big_array_free(&sel->rings[0].real_ids); + big_array_free(&sel->rings[1].real_ids); + memory_bfree(mctx, sel, sizeof(struct balancer_real_selector)); + SET_OFFSET_OF(&vs->selector, NULL); +} + +int +balancer_vs_set_session_trackers(struct balancer_vs *vs, struct agent *agent) { + struct memory_context *mctx = &agent->memory_context; + struct balancer_real *reals = ADDR_OF(&vs->reals); + + const size_t workers = ADDR_OF(&agent->dp_config)->worker_count; + + for (size_t real_idx = 0; real_idx < vs->reals_count; ++real_idx) { + struct balancer_real *real = &reals[real_idx]; + if (real->tracker_shards != NULL) { + continue; + } + struct balancer_sessions_tracker_shard *shards = memory_balloc( + mctx, + sizeof(struct balancer_sessions_tracker_shard) * workers + ); + if (shards == NULL) { + return -1; + } + memset(shards, + 0, + sizeof(struct balancer_sessions_tracker_shard) * workers + ); + SET_OFFSET_OF(&real->tracker_shards, shards); + } + + return 0; +} + +void +balancer_vs_free_session_trackers(struct balancer_vs *vs, struct agent *agent) { + struct memory_context *mctx = &agent->memory_context; + struct balancer_real *reals = ADDR_OF(&vs->reals); + + const size_t workers = ADDR_OF(&agent->dp_config)->worker_count; + + for (size_t real_idx = 0; real_idx < vs->reals_count; ++real_idx) { + struct balancer_real *real = &reals[real_idx]; + struct balancer_sessions_tracker_shard *shards = + ADDR_OF(&real->tracker_shards); + if (shards != NULL) { + memory_bfree( + mctx, + shards, + sizeof(struct balancer_sessions_tracker_shard) * + workers + ); + SET_OFFSET_OF(&real->tracker_shards, NULL); + } + } +} \ No newline at end of file diff --git a/modules/balancer/controlplane/helpers/vs.h b/modules/balancer/controlplane/helpers/vs.h new file mode 100644 index 000000000..07de552b8 --- /dev/null +++ b/modules/balancer/controlplane/helpers/vs.h @@ -0,0 +1,46 @@ +/* + * Per-VS helpers: ACL compilation, real selector (ring) management, + * and per-real session tracker allocation. + * + * Error convention: functions returning int use 0 for success and + * non-zero for failure. + * -1 shared-memory allocation failure + * -2 heap allocation failure + */ +#pragma once + +#include + +struct balancer_vs; +struct agent; +struct rcu; + +/* Compile an ACL filter from the VS's allowed_sources list. + * Sets vs->acl to NULL when allowed_sources_count == 0. + * Returns 0 on success, -1 on shared memory allocation failure, + * -2 on heap allocation failure. */ +int +balancer_vs_set_acl(struct balancer_vs *vs, struct agent *agent); + +void +balancer_vs_free_acl(struct balancer_vs *vs, struct agent *agent); + +/* Rebuild the weighted ring selector from current effective_weight values. + * Uses RCU to swap the active ring without disrupting dataplane lookups. + * Returns 0 on success, -1 on failure. */ +int +balancer_vs_update_real_selector( + struct balancer_vs *vs, struct rcu *rcu, struct agent *agent +); + +void +balancer_vs_free_real_selector(struct balancer_vs *vs, struct agent *agent); + +/* Allocate per-worker session tracker shards for each real that doesn't + * already have them (tracker_shards != NULL). + * Returns 0 on success, -1 on allocation failure. */ +int +balancer_vs_set_session_trackers(struct balancer_vs *vs, struct agent *agent); + +void +balancer_vs_free_session_trackers(struct balancer_vs *vs, struct agent *agent); \ No newline at end of file diff --git a/modules/balancer/controlplane/meson.build b/modules/balancer/controlplane/meson.build index 38b81c082..5f7d9916c 100644 --- a/modules/balancer/controlplane/meson.build +++ b/modules/balancer/controlplane/meson.build @@ -1,3 +1,2 @@ -subdir('state') -subdir('handler') -subdir('api') +subdir('balancerpb') +subdir('helpers') \ No newline at end of file diff --git a/modules/balancer/controlplane/metrics.go b/modules/balancer/controlplane/metrics.go new file mode 100644 index 000000000..c8af818e4 --- /dev/null +++ b/modules/balancer/controlplane/metrics.go @@ -0,0 +1,383 @@ +package balancer + +import ( + "strings" + "time" + + "github.com/yanet-platform/yanet2/common/commonpb" + "github.com/yanet-platform/yanet2/common/go/metrics" + "github.com/yanet-platform/yanet2/common/go/relptr" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" +) + +// commonCounters maps per-position metric names to getters over the four +// "global" counter groups (cmn, l4, iv4, iv6). +var commonCounters = []struct { + name string + getter func(*CommonStats, *L4Stats, *IcmpStats, *IcmpStats) uint64 +}{ + { + name: "incoming_bits", + getter: func(c *CommonStats, _ *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return c.Incoming_bytes * 8 + }, + }, + { + name: "incoming_packets", + getter: func(c *CommonStats, _ *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return c.Incoming_packets + }, + }, + { + name: "outgoing_bits", + getter: func(c *CommonStats, _ *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return c.Outgoing_bytes * 8 + }, + }, + { + name: "outgoing_packets", + getter: func(c *CommonStats, _ *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return c.Outgoing_packets + }, + }, + { + name: "l4_incoming_packets", + getter: func(_ *CommonStats, l *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return l.Incoming_packets + }, + }, + { + name: "l4_outgoing_packets", + getter: func(_ *CommonStats, l *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return l.Outgoing_packets + }, + }, + { + name: "l4_select_vs_failed", + getter: func(_ *CommonStats, l *L4Stats, _ *IcmpStats, _ *IcmpStats) uint64 { + return l.Select_vs_failed + }, + }, + { + name: "icmp_ipv4_incoming_packets", + getter: func(_ *CommonStats, _ *L4Stats, i4 *IcmpStats, _ *IcmpStats) uint64 { + return i4.Incoming_packets + }, + }, + { + name: "icmp_ipv4_forwarded_packets", + getter: func(_ *CommonStats, _ *L4Stats, i4 *IcmpStats, _ *IcmpStats) uint64 { + return i4.Forwarded_packets + }, + }, + { + name: "icmp_ipv4_packet_clones_sent", + getter: func(_ *CommonStats, _ *L4Stats, i4 *IcmpStats, _ *IcmpStats) uint64 { + return i4.Packet_clones_sent + }, + }, + { + name: "icmp_ipv4_packet_clones_received", + getter: func(_ *CommonStats, _ *L4Stats, i4 *IcmpStats, _ *IcmpStats) uint64 { + return i4.Packet_clones_received + }, + }, + { + name: "icmp_ipv4_packet_clone_failures", + getter: func(_ *CommonStats, _ *L4Stats, i4 *IcmpStats, _ *IcmpStats) uint64 { + return i4.Packet_clone_failures + }, + }, + { + name: "icmp_ipv6_incoming_packets", + getter: func(_ *CommonStats, _ *L4Stats, _ *IcmpStats, i6 *IcmpStats) uint64 { + return i6.Incoming_packets + }, + }, + { + name: "icmp_ipv6_forwarded_packets", + getter: func(_ *CommonStats, _ *L4Stats, _ *IcmpStats, i6 *IcmpStats) uint64 { + return i6.Forwarded_packets + }, + }, + { + name: "icmp_ipv6_packet_clones_sent", + getter: func(_ *CommonStats, _ *L4Stats, _ *IcmpStats, i6 *IcmpStats) uint64 { + return i6.Packet_clones_sent + }, + }, + { + name: "icmp_ipv6_packet_clones_received", + getter: func(_ *CommonStats, _ *L4Stats, _ *IcmpStats, i6 *IcmpStats) uint64 { + return i6.Packet_clones_received + }, + }, + { + name: "icmp_ipv6_packet_clone_failures", + getter: func(_ *CommonStats, _ *L4Stats, _ *IcmpStats, i6 *IcmpStats) uint64 { + return i6.Packet_clone_failures + }, + }, +} + +// vsCounters maps per-VS metric names to getters over VsStats. +var vsCounters = []struct { + name string + getter func(*VsStats) uint64 +}{ + {"vs_incoming_bits", func(s *VsStats) uint64 { return s.Incoming_bytes * 8 }}, + {"vs_incoming_packets", func(s *VsStats) uint64 { return s.Incoming_packets }}, + {"vs_outgoing_bits", func(s *VsStats) uint64 { return s.Outgoing_bytes * 8 }}, + {"vs_outgoing_packets", func(s *VsStats) uint64 { return s.Outgoing_packets }}, + {"vs_created_sessions", func(s *VsStats) uint64 { return s.Created_sessions }}, + {"vs_packet_src_not_allowed", func(s *VsStats) uint64 { return s.Packet_src_not_allowed }}, + {"vs_no_reals", func(s *VsStats) uint64 { return s.No_reals }}, + {"vs_session_table_overflow", func(s *VsStats) uint64 { return s.Session_table_overflow }}, + {"vs_echo_icmp_packets", func(s *VsStats) uint64 { return s.Echo_icmp_packets }}, + {"vs_error_icmp_packets", func(s *VsStats) uint64 { return s.Error_icmp_packets }}, + {"vs_real_is_disabled", func(s *VsStats) uint64 { return s.Real_is_disabled }}, + {"vs_real_is_removed", func(s *VsStats) uint64 { return s.Real_is_removed }}, + {"vs_not_rescheduled_packets", func(s *VsStats) uint64 { return s.Not_rescheduled_packets }}, + {"vs_broadcasted_icmp_packets", func(s *VsStats) uint64 { return s.Broadcasted_icmp_packets }}, +} + +// realCounters maps per-real metric names to getters over RealStats. +var realCounters = []struct { + name string + getter func(*RealStats) uint64 +}{ + {"real_incoming_bits", func(s *RealStats) uint64 { return s.Bytes * 8 }}, + {"real_incoming_packets", func(s *RealStats) uint64 { return s.Packets }}, + {"real_created_sessions", func(s *RealStats) uint64 { return s.Created_sessions }}, + {"real_icmp_error_packets", func(s *RealStats) uint64 { return s.Error_icmp_packets }}, + {"packets_real_disabled", func(s *RealStats) uint64 { return s.Packets_real_disabled }}, +} + +type methodMetrics struct { + latencies *metrics.MetricMap[*metrics.Histogram] +} + +func newMethodMetrics() methodMetrics { + return methodMetrics{ + latencies: metrics.NewMetricMap[*metrics.Histogram](), + } +} + +func (m *methodMetrics) collect() []*commonpb.Metric { + return commonpb.MetricRefsToProto(m.latencies.Metrics()) +} + +var defaultLatencyBoundsMS = []float64{ + 1, + 2, + 5, + 10, + 25, + 50, + 75, + 100, + 150, + 200, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + 1500, + 2000, + 3000, + 4000, + 5000, +} + +type methodMetricsTracker struct { + metricID metrics.MetricID + startTime time.Time + metrics methodMetrics + latencies []float64 +} + +func newMetricsTracker( + handlerName string, + methodMetrics methodMetrics, + latencies []float64, + labels metrics.Labels, +) *methodMetricsTracker { + id := metrics.MetricID{ + Name: handlerName, + Labels: labels, + } + return &methodMetricsTracker{ + metricID: id, + startTime: time.Now(), + metrics: methodMetrics, + latencies: latencies, + } +} + +func (m *methodMetricsTracker) Fix() { + duration := time.Since(m.startTime) + m.metrics.latencies.GetOrCreate(m.metricID, func() *metrics.Histogram { + return metrics.NewHistogram(m.latencies) + }).Observe(float64(duration.Milliseconds())) +} + +func collectCounterMetrics( + services []VS, + counters []yanet.CounterInfo, + refLabels []*commonpb.Label, +) []*commonpb.Metric { + var result []*commonpb.Metric + + var ( + cmn *CommonStats + l4s *L4Stats + iv4 *IcmpStats + iv6 *IcmpStats + ) + + for _, counter := range counters { + name := counter.Name + switch { + case name == "cmn": + cmn = commonStats(counter.Values) + case name == "l4": + l4s = l4Stats(counter.Values) + case name == "iv4": + iv4 = icmpStats(counter.Values) + case name == "iv6": + iv6 = icmpStats(counter.Values) + case strings.HasPrefix(name, "vs_"): + result = append(result, collectVSMetrics(services, counter, refLabels)...) + case strings.HasPrefix(name, "rl_"): + result = append(result, collectRealMetrics(services, counter, refLabels)...) + case strings.HasPrefix(name, "acl_"): + result = append(result, collectACLMetrics(services, counter, refLabels)...) + } + } + + for _, c := range commonCounters { + result = append(result, &commonpb.Metric{ + Name: c.name, + Labels: refLabels, + Value: &commonpb.Metric_Counter{Counter: c.getter(cmn, l4s, iv4, iv6)}, + }) + } + + return result +} + +func collectVSMetrics( + services []VS, + counter yanet.CounterInfo, + refLabels []*commonpb.Label, +) []*commonpb.Metric { + vsStableIndex, ok := vsIndexFromCounterName(counter.Name) + if !ok { + return nil + } + vs, ok := resolveVS(services, vsStableIndex) + if !ok { + return nil + } + vsLabels := append(vs.labels(), refLabels...) + stats := vsStats(counter.Values) + result := make([]*commonpb.Metric, 0, len(vsCounters)) + for _, c := range vsCounters { + result = append(result, &commonpb.Metric{ + Name: c.name, + Labels: vsLabels, + Value: &commonpb.Metric_Counter{Counter: c.getter(stats)}, + }) + } + return result +} + +func collectRealMetrics( + services []VS, + counter yanet.CounterInfo, + refLabels []*commonpb.Label, +) []*commonpb.Metric { + vsStableIndex, realStableIndex, ok := realIndexFromCounterName(counter.Name) + if !ok { + return nil + } + vs, ok := resolveVS(services, vsStableIndex) + if !ok { + return nil + } + r, ok := resolveReal(vs, realStableIndex) + if !ok { + return nil + } + realLabels := append(r.labels(), refLabels...) + realLabels = append(realLabels, vs.labels()...) + stats := realStats(counter.Values) + result := make([]*commonpb.Metric, 0, len(realCounters)) + for _, c := range realCounters { + result = append(result, &commonpb.Metric{ + Name: c.name, + Labels: realLabels, + Value: &commonpb.Metric_Counter{Counter: c.getter(stats)}, + }) + } + return result +} + +func collectACLMetrics( + services []VS, + counter yanet.CounterInfo, + refLabels []*commonpb.Label, +) []*commonpb.Metric { + vsStableIndex, tag, ok := aclTagFromCounterName(counter.Name) + if !ok { + return nil + } + vs, ok := resolveVS(services, vsStableIndex) + if !ok { + return nil + } + aclLabels := append(vs.labels(), refLabels...) + aclLabels = append(aclLabels, &commonpb.Label{Name: "acl_tag", Value: tag}) + return []*commonpb.Metric{{ + Name: "acl_passes", + Labels: aclLabels, + Value: &commonpb.Metric_Counter{Counter: aggregateACLPasses(counter.Values)}, + }} +} + +func collectSessionMetrics( + services []VS, + workers uint32, + now time.Time, + refLabels []*commonpb.Label, +) []*commonpb.Metric { + var result []*commonpb.Metric + + for vsIdx := range services { + vs := &services[vsIdx] + if vs.isRemoved() { + continue + } + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + for realIdx := range reals { + r := &reals[realIdx] + if r.isRemoved() { + continue + } + active, _ := r.sessions(workers, now) + realLabels := append(vs.labels(), refLabels...) + realLabels = append(realLabels, r.labels()...) + result = append(result, &commonpb.Metric{ + Name: "real_active_sessions", + Labels: realLabels, + Value: &commonpb.Metric_Gauge{Gauge: float64(active)}, + }) + } + } + + return result +} diff --git a/modules/balancer/controlplane/mod.go b/modules/balancer/controlplane/mod.go new file mode 100644 index 000000000..67eade1b5 --- /dev/null +++ b/modules/balancer/controlplane/mod.go @@ -0,0 +1,65 @@ +package balancer + +import ( + "fmt" + + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +type Module struct { + cfg *Config + shm *yanet.SharedMemory + service *Service +} + +func NewModule( + cfg *Config, + log *zap.SugaredLogger, +) (*Module, error) { + log = log.With(zap.String("module", "balancerpb.Balancer")) + + shm, err := yanet.AttachSharedMemory(cfg.MemoryPath.Unwrap()) + if err != nil { + return nil, fmt.Errorf("failed to attach shared memory: %w", err) + } + + svc, err := NewService(shm, cfg.InstanceID, cfg.MemoryRequirements.Unwrap(), log) + if err != nil { + _ = shm.Detach() + return nil, fmt.Errorf("failed to create balancer service: %w", err) + } + + return &Module{ + cfg: cfg, + shm: shm, + service: svc, + }, nil +} + +func (m *Module) Name() string { + return "balancer" +} + +func (m *Module) Endpoint() string { + return m.cfg.Endpoint.Unwrap() +} + +func (m *Module) ServicesNames() []string { + return []string{"balancerpb.Balancer"} +} + +func (m *Module) RegisterService(server *grpc.Server) { + balancerpb.RegisterBalancerServer(server, m.service) +} + +func (m *Module) Close() error { + m.service.mu.Lock() + defer m.service.mu.Unlock() + + m.service.agent.Close() + + return m.shm.Detach() +} diff --git a/modules/balancer/controlplane/net.go b/modules/balancer/controlplane/net.go new file mode 100644 index 000000000..9b6cb0c14 --- /dev/null +++ b/modules/balancer/controlplane/net.go @@ -0,0 +1,86 @@ +package balancer + +import ( + "unsafe" + + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +// writeNet4Addr writes a 4-byte address to a shared-memory Net4Addr. +func writeNet4Addr(dst *Net4Addr, addr []byte) { + copy(dst.Bytes[:], addr) +} + +// writeNet6Addr writes a 16-byte address to a shared-memory Net6Addr. +func writeNet6Addr(dst *Net6Addr, addr []byte) { + copy(dst.Bytes[:], addr) +} + +// Bytes returns the address bytes for the given IP protocol. +// For IPv6, this reinterprets the C union (V4 [4]byte + 12 byte padding = 16 bytes) +// as a contiguous 16-byte array via unsafe, matching the C union layout. +func (a *NetAddr) Bytes(ipProto int) []byte { + if ipProto == ipprotoIP { + return a.V4.Bytes[:] + } + return (*[16]byte)(unsafe.Pointer(&a.V4.Bytes[0]))[:] +} + +func (a *Net) AddrBytes(ipProto int) []byte { + if ipProto == ipprotoIP { + return a.V4.Addr[:] + } + // Net is defined as Net4 (8 bytes) + 24 bytes padding = 32 bytes total, + // which is the same layout as Net6 (addr[16] + mask[16]). We write addr + // and mask directly into the 32-byte buffer via an unsafe cast instead of + // going through the Net4 fields, which would only cover the first 8 bytes. + b := (*[32]byte)(unsafe.Pointer(&a.V4.Addr[0])) + return b[:16] +} + +func (a *Net) MaskBytes(ipProto int) []byte { + if ipProto == ipprotoIP { + return a.V4.Mask[:] + } + b := (*[32]byte)(unsafe.Pointer(&a.V4.Addr[0])) + return b[16:] +} + +func writeNetAddr(dst *NetAddr, addr []byte) { + if len(addr) == 4 { + copy(dst.Bytes(ipprotoIP), addr) + } else { + copy(dst.Bytes(ipprotoIPv6), addr) + } +} + +// writeNet writes a filterpb.IPNet (addr + mask) to a shared-memory Net union. +// The Net union is 32 bytes: for IPv4 it uses Net4 (addr[4] + mask[4]), +// for IPv6 it uses Net6 (addr[16] + mask[16]). +// Src must be IPv4 or IPv6. +// +// The address is pre-masked (addr[i] &= mask[i]) before writing to satisfy +// the dataplane invariant documented in real.h: the tunnel code relies on +// addr having zero bits in every position where mask is zero. +func writeNet(dst *Net, src *filterpb.IPNet) { + addr := src.Addr + mask := src.Mask + proto := ipprotoIP + if len(mask) == 16 { + proto = ipprotoIPv6 + } + dstAddr := dst.AddrBytes(proto) + for i := range len(addr) { + dstAddr[i] = addr[i] & mask[i] + } + copy(dst.MaskBytes(proto), mask) +} + +// transportProtoToC converts a protobuf TransportProto to the C constant. +func transportProtoToC(proto balancerpb.TransportProto) uint8 { + if proto == balancerpb.TransportProto_TCP { + return ipprotoTCP + } + return ipprotoUDP +} diff --git a/modules/balancer/controlplane/real.go b/modules/balancer/controlplane/real.go new file mode 100644 index 000000000..be7c4ef84 --- /dev/null +++ b/modules/balancer/controlplane/real.go @@ -0,0 +1,200 @@ +package balancer + +import ( + "net" + "time" + + "github.com/yanet-platform/yanet2/common/commonpb" + "github.com/yanet-platform/yanet2/common/go/relptr" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// realKey is a hashable identifier for a real server, usable as a map key. +type realKey struct { + ip [16]byte + ipLen uint8 +} + +func makeRealKey(id *balancerpb.RelativeRealIdentifier) realKey { + var k realKey + k.ipLen = uint8(len(id.Ip)) + copy(k.ip[:], id.Ip) + return k +} + +func (r *Real) key() realKey { + var k realKey + proto := ipprotoIP + addrLen := 4 + if r.isIPv6() { + proto = ipprotoIPv6 + addrLen = 16 + } + k.ipLen = uint8(addrLen) + copy(k.ip[:], r.Addr.Bytes(proto)) + return k +} + +// populate fills a Real from a protobuf definition and optional previous state. +// Precondition: if inheritEffectiveWeight is true, prevReal must be non-nil +// (the effective weight is read from it). +func (r *Real) populate( + pb *balancerpb.Real, + stableIdx uint64, + prevReal *Real, + inheritEffectiveWeight bool, +) { + r.Stable_idx = stableIdx + r.Flags = 0 + r.Weight = uint32(pb.Weight) + writeNetAddr(&r.Addr, pb.Id.Ip) + writeNet(&r.Src, pb.Src) + if len(pb.Id.Ip) == 16 { + r.Flags |= RealFlagIPv6 + } + if prevReal != nil && prevReal.isEnabled() { + r.Flags |= RealFlagEnabled + } + if inheritEffectiveWeight { + r.Effective_weight = uint32(prevReal.Effective_weight) + } else { + r.Effective_weight = uint32(pb.Weight) + } + if prevReal != nil { + relptr.Equate(&r.Tracker_shards, &prevReal.Tracker_shards) + } +} + +func (r *Real) isRemoved() bool { + return r.Flags&RealFlagRemoved != 0 +} + +func (r *Real) isEnabled() bool { + return r.Flags&RealFlagEnabled != 0 +} + +func (r *Real) epoch() uint32 { + return epochOf(r.Stable_idx) +} + +func (r *Real) id() *balancerpb.RelativeRealIdentifier { + proto := ipprotoIP + if r.isIPv6() { + proto = ipprotoIPv6 + } + return &balancerpb.RelativeRealIdentifier{ + Ip: r.Addr.Bytes(proto), + Port: 0, + } +} + +func (r *Real) state(workers uint32, now time.Time) *balancerpb.RealState { + activeSessions, lastPacketTimestamp := r.sessions(workers, now) + return &balancerpb.RealState{ + Id: r.id(), + Weight: uint64(r.Weight), + EffectiveWeight: uint64(r.Effective_weight), + Enabled: r.isEnabled(), + ActiveSessions: activeSessions, + LastPacketTimestamp: timestamppb.New(lastPacketTimestamp), + } +} + +// placeExistingReals places reals that exist in both previous and new configs into their +// original slot positions in targetReals. Each placed real is deleted from pbRealIndex, +// so after this call pbRealIndex contains only genuinely new reals for placeNewReals. +func placeExistingReals( + pbReals []*balancerpb.Real, + targetReals []Real, + prevReals []Real, + pbRealIndex map[realKey]int, + inheritEffectiveWeights bool, +) (realsUnchanged bool) { + realsUnchanged = true + for idx := range prevReals { + prevReal := &prevReals[idx] + if prevReal.isRemoved() { + continue + } + + k := prevReal.key() + if _, ok := pbRealIndex[k]; !ok { + realsUnchanged = false + continue + } + pbIdx := pbRealIndex[k] + delete(pbRealIndex, k) + + pbReal := pbReals[pbIdx] + if pbReal.Weight != uint32(prevReal.Weight) { + realsUnchanged = false + } + targetReals[idx].populate(pbReal, prevReal.Stable_idx, prevReal, inheritEffectiveWeights) + } + return realsUnchanged +} + +// placeNewReals places genuinely new reals (remaining in pbRealIndex after placeExistingReals) +// into removed (empty) slots in targetReals. +// Invariant: same slot-availability guarantee as placeNewVS — see its comment. +func placeNewReals( + pbReals []*balancerpb.Real, + targetReals []Real, + prevReals []Real, + pbRealIndex map[realKey]int, +) (noNewReals bool) { + noNewReals = true + + nextRemoved := 0 + + for idx := range pbReals { + k := makeRealKey(pbReals[idx].Id) + if _, ok := pbRealIndex[k]; !ok { + continue + } + noNewReals = false + + for !targetReals[nextRemoved].isRemoved() { + nextRemoved++ + } + + epoch := uint32(0) + if nextRemoved < len(prevReals) { + epoch = prevReals[nextRemoved].epoch() + 1 + } + + stableIdx := makeStableIdx(epoch, uint32(nextRemoved)) + targetReals[nextRemoved].populate(pbReals[idx], stableIdx, nil, false) + } + + return noNewReals +} + +func formatRealAddr(addr []byte) string { + return net.IP(addr).String() +} + +func realIDToString(id *balancerpb.RelativeRealIdentifier) string { + return formatRealAddr(id.Ip) +} + +func (r *Real) isIPv6() bool { + return r.Flags&RealFlagIPv6 != 0 +} + +func (r *Real) ip() []byte { + proto := ipprotoIP + if r.isIPv6() { + proto = ipprotoIPv6 + } + return r.Addr.Bytes(proto) +} + +func (r *Real) String() string { + return formatRealAddr(r.ip()) +} + +func (r *Real) labels() []*commonpb.Label { + return []*commonpb.Label{{Name: "real_ip", Value: formatRealAddr(r.ip())}} +} diff --git a/modules/balancer/controlplane/refresh.go b/modules/balancer/controlplane/refresh.go new file mode 100644 index 000000000..bff7a6fab --- /dev/null +++ b/modules/balancer/controlplane/refresh.go @@ -0,0 +1,203 @@ +package balancer + +import ( + "context" + "math" + "sync" + "time" + + "github.com/yanet-platform/yanet2/common/go/relptr" +) + +type Refresher struct { + balancer *Balancer + mu *sync.Mutex + parentCtx context.Context + cancel context.CancelFunc + done chan struct{} + refreshPeriod time.Duration +} + +func NewRefresher(balancer *Balancer, mu *sync.Mutex) *Refresher { + return &Refresher{ + balancer: balancer, + mu: mu, + refreshPeriod: balancer.config.State.RefreshPeriod.AsDuration(), + } +} + +func (r *Refresher) Run(ctx context.Context) { + if r.refreshPeriod == 0 { + return + } + r.parentCtx = ctx + derived, cancel := context.WithCancel(ctx) + r.cancel = cancel + r.done = make(chan struct{}) + go func() { + defer close(r.done) + ticker := time.NewTicker(r.refreshPeriod) + defer ticker.Stop() + for { + select { + case <-derived.Done(): + return + case <-ticker.C: + r.mu.Lock() + r.refresh() + r.mu.Unlock() + } + } + }() +} + +func (r *Refresher) Stop() { + if r.cancel == nil { + return + } + r.cancel() + <-r.done + r.cancel = nil +} + +func (r *Refresher) UpdateRefreshPeriod(period time.Duration) { + if period == r.refreshPeriod { + return + } + parentCtx := r.parentCtx + r.Stop() + r.refreshPeriod = period + if parentCtx == nil { + parentCtx = context.Background() + } + r.Run(parentCtx) +} + +func (r *Refresher) refresh() { + b := r.balancer + ph := b.handler + st := relptr.Deref(&ph.Session_table) + if st == nil { + return + } + workers := st.Workers + now := time.Now() + + services := relptr.Slice(&ph.Vs, ph.Vs_count) + + var totalActiveSessions uint64 + for vsIdx := range services { + vs := &services[vsIdx] + if vs.isRemoved() { + continue + } + active := vsCalcEffectiveWeights( + vs, + workers, + now, + float64(ph.Wlc_power), + uint32(ph.Wlc_max_weight), + ) + totalActiveSessions += active + if vs.isWLC() { + if err := vs.updateRealSelector(&ph.Rcu, b.agent); err != nil { + b.log.Errorw("failed to update real selector", "vs", vs, "error", err) + } + } + } + + for { + capacity := uint64(st.capacity()) + if capacity > 0 && + float64( + totalActiveSessions, + ) > float64( + capacity, + )*float64( + ph.Session_table_max_load_factor, + ) { + newSize := int(capacity * 2) + if err := b.handler.resizeSessionTable(st, newSize, now); err != nil { + b.log.Errorw("failed to resize session table", "error", err) + break + } + } else { + break + } + } +} + +// vsCalcEffectiveWeights computes active session count for the VS and, when +// the VS has WLC enabled, updates effective_weight on each enabled real using +// the following formula: +// +// ratio = (real_sessions * total_weight) / (total_sessions * real_weight) +// wlc_factor = max(1.0, power * (1.0 - ratio)) +// eff_weight = min(real_weight * wlc_factor, max_weight) +// +// It returns the total number of active sessions across all non-removed reals. +func vsCalcEffectiveWeights( + vs *VS, + workers uint32, + now time.Time, + power float64, + maxWeight uint32, +) uint64 { + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + + // First pass: gather session counts and weight sums. + type realInfo struct { + sessions uint64 + } + infos := make([]realInfo, len(reals)) + + var totalSessions uint64 + var totalWeight uint64 + for i := range reals { + r := &reals[i] + if r.isRemoved() { + continue + } + active, _ := r.sessions(workers, now) + infos[i].sessions = active + totalSessions += active + if r.isEnabled() { + totalWeight += uint64(r.Weight) + } + } + + if !vs.isWLC() { + return totalSessions + } + + // Second pass: update effective weights. + for i := range reals { + r := &reals[i] + if r.isRemoved() || !r.isEnabled() { + continue + } + if r.Weight == 0 { + r.Effective_weight = 0 + continue + } + + var eff float64 + if totalSessions == 0 || totalWeight == 0 { + // No sessions yet: treat as fully unloaded. + eff = float64(r.Weight) + } else { + ratio := float64(infos[i].sessions) * float64(totalWeight) / + (float64(totalSessions) * float64(r.Weight)) + ratio = math.Min(ratio, 1.0) + wlcFactor := math.Max(1.0, power*(1.0-ratio)) + eff = float64(r.Weight) * wlcFactor + } + + if maxWeight > 0 { + eff = math.Min(eff, float64(maxWeight)) + } + r.Effective_weight = uint32(eff) + } + + return totalSessions +} diff --git a/modules/balancer/controlplane/restore.go b/modules/balancer/controlplane/restore.go new file mode 100644 index 000000000..332849686 --- /dev/null +++ b/modules/balancer/controlplane/restore.go @@ -0,0 +1,238 @@ +package balancer + +import ( + "time" + + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/common/go/relptr" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/durationpb" +) + +// restoreBalancerFromPacketHandler reconstructs a Balancer Go object from an +// existing PacketHandler in shared memory. Called on control-plane restart to +// re-attach to already-running packet handlers without re-allocating resources. +// +// The returned Balancer shares the same packet handler memory; no new shared +// memory is allocated. +func restoreBalancerFromPacketHandler( + agent *Agent, + ph *PacketHandler, + log *zap.SugaredLogger, +) *Balancer { + b := &Balancer{ + handler: ph, + agent: agent, + config: restoreConfigFromPacketHandler(ph), + log: log, + } + b.buildIndexes() + return b +} + +func restoreConfigFromPacketHandler(ph *PacketHandler) *balancerpb.BalancerConfig { + return &balancerpb.BalancerConfig{ + PacketHandler: restorePacketHandlerConfig(ph), + State: restoreStateConfig(ph), + } +} + +func restorePacketHandlerConfig(ph *PacketHandler) *balancerpb.PacketHandlerConfig { + return &balancerpb.PacketHandlerConfig{ + SourceAddressV4: append([]byte(nil), ph.Source_v4.Bytes[:]...), + SourceAddressV6: append([]byte(nil), ph.Source_v6.Bytes[:]...), + DecapAddresses: restoreDecapAddrs(ph), + SessionsTimeouts: restoreSessionTimeouts(&ph.Session_timeouts), + Vs: restoreVirtualServices(ph), + } +} + +func restoreSessionTimeouts(t *SessionTimeouts) *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{ + TcpSynAck: uint32(t.Syn_ack), + TcpSyn: uint32(t.Syn), + TcpFin: uint32(t.Fin), + Tcp: uint32(t.Tcp), + Udp: uint32(t.Udp), + } +} + +func restoreDecapAddrs(ph *PacketHandler) [][]byte { + total := int(ph.Decap_v4_count) + int(ph.Decap_v6_count) + addrs := make([][]byte, 0, total) + for _, a := range relptr.Slice(&ph.Decap_v4, ph.Decap_v4_count) { + addr := make([]byte, 4) + copy(addr, a.Bytes[:]) + addrs = append(addrs, addr) + } + for _, a := range relptr.Slice(&ph.Decap_v6, ph.Decap_v6_count) { + addr := make([]byte, 16) + copy(addr, a.Bytes[:]) + addrs = append(addrs, addr) + } + return addrs +} + +func restoreVirtualServices(ph *PacketHandler) []*balancerpb.VirtualService { + vsSlice := relptr.Slice(&ph.Vs, ph.Vs_count) + result := make([]*balancerpb.VirtualService, 0) + for i := range vsSlice { + if vsSlice[i].isRemoved() { + continue + } + result = append(result, restoreVS(&vsSlice[i])) + } + return result +} + +func restoreVS(vs *VS) *balancerpb.VirtualService { + isV6 := vs.Ip_proto == ipprotoIPv6 + + addr := append([]byte(nil), vs.Addr.Bytes(int(vs.Ip_proto))...) + + proto := balancerpb.TransportProto_UDP + if vs.Transport_proto == ipprotoTCP { + proto = balancerpb.TransportProto_TCP + } + + return &balancerpb.VirtualService{ + Id: &balancerpb.VsIdentifier{ + Addr: addr, + Port: uint32(vs.Port), + Proto: proto, + }, + Flags: vs.flags(), + Reals: restoreReals(vs), + AllowedSrcs: restoreAllowedSources(vs, isV6), + Peers: restorePeers(vs), + Scheduler: vs.scheduler(), + } +} + +func restoreReals(vs *VS) []*balancerpb.Real { + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + result := make([]*balancerpb.Real, 0) + for i := range reals { + if reals[i].isRemoved() { + continue + } + result = append(result, restoreReal(&reals[i])) + } + return result +} + +func restoreReal(r *Real) *balancerpb.Real { + ipProto := ipprotoIP + if r.isIPv6() { + ipProto = ipprotoIPv6 + } + + ip := append([]byte(nil), r.Addr.Bytes(ipProto)...) + + return &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{Ip: ip}, + Weight: r.Weight, + Src: restoreIPNet(&r.Src, r.isIPv6()), + } +} + +// restoreIPNet reads a Net union back to a filterpb.IPNet. +func restoreIPNet(net *Net, isV6 bool) *filterpb.IPNet { + proto := ipprotoIP + size := 4 + if isV6 { + proto = ipprotoIPv6 + size = 16 + } + addr := make([]byte, size) + mask := make([]byte, size) + copy(addr, net.AddrBytes(proto)) + copy(mask, net.MaskBytes(proto)) + return &filterpb.IPNet{Addr: addr, Mask: mask} +} + +func restoreAllowedSources(vs *VS, isV6 bool) []*balancerpb.AllowedSources { + srcs := relptr.Slice(&vs.Allowed_sources, vs.Allowed_sources_count) + result := make([]*balancerpb.AllowedSources, len(srcs)) + for i := range srcs { + result[i] = restoreAllowedSource(&srcs[i], isV6) + } + return result +} + +func restoreAllowedSource(src *AllowedSource, isV6 bool) *balancerpb.AllowedSources { + rawNets := relptr.Slice(&src.Nets, src.Nets_count) + nets := make([]*filterpb.IPNet, len(rawNets)) + for i := range rawNets { + nets[i] = restoreIPNet(&rawNets[i], isV6) + } + + rawPr := relptr.Slice(&src.Port_ranges, src.Port_ranges_count) + pr := make([]*filterpb.PortRange, len(rawPr)) + for i, p := range rawPr { + pr[i] = &filterpb.PortRange{ + From: uint32(p.From), + To: uint32(p.To), + } + } + + result := &balancerpb.AllowedSources{ + Nets: nets, + Ports: pr, + } + + tag := restoreACLTag(src.Tag[:]) + if len(tag) > 0 { + result.Tag = &tag + } + + return result +} + +func restoreACLTag(b []int8) string { + s := make([]byte, 0, len(b)) + for _, c := range b { + if c == 0 { + break + } + s = append(s, byte(c)) + } + return string(s) +} + +func restorePeers(vs *VS) [][]byte { + total := int(vs.Peers_v4_count) + int(vs.Peers_v6_count) + peers := make([][]byte, 0, total) + for _, p := range relptr.Slice(&vs.Peers_v4, vs.Peers_v4_count) { + addr := make([]byte, 4) + copy(addr, p.Bytes[:]) + peers = append(peers, addr) + } + for _, p := range relptr.Slice(&vs.Peers_v6, vs.Peers_v6_count) { + addr := make([]byte, 16) + copy(addr, p.Bytes[:]) + peers = append(peers, addr) + } + return peers +} + +func restoreStateConfig(ph *PacketHandler) *balancerpb.StateConfig { + capacity := uint64(relptr.Deref(&ph.Session_table).capacity()) + maxLoadFactor := ph.Session_table_max_load_factor + wlcPower := uint64(ph.Wlc_power) + wlcMaxWeight := uint32(ph.Wlc_max_weight) + refreshPeriod := durationpb.New( + time.Duration(ph.Refresh_period_ms) * time.Millisecond, + ) + + return &balancerpb.StateConfig{ + SessionTableCapacity: &capacity, + SessionTableMaxLoadFactor: &maxLoadFactor, + Wlc: &balancerpb.WlcConfig{ + Power: &wlcPower, + MaxWeight: &wlcMaxWeight, + }, + RefreshPeriod: refreshPeriod, + } +} diff --git a/modules/balancer/controlplane/service.go b/modules/balancer/controlplane/service.go new file mode 100644 index 000000000..3a1abbef6 --- /dev/null +++ b/modules/balancer/controlplane/service.go @@ -0,0 +1,399 @@ +package balancer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/c2h5oh/datasize" + "github.com/yanet-platform/yanet2/common/commonpb" + "github.com/yanet-platform/yanet2/common/go/metrics" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +type Service struct { + balancerpb.UnimplementedBalancerServer + + agent *Agent + mu sync.Mutex + log *zap.SugaredLogger + + metrics methodMetrics +} + +func NewService( + shm *yanet.SharedMemory, + instanceIdx uint32, + size datasize.ByteSize, + log *zap.SugaredLogger, +) (*Service, error) { + log.Info("initializing balancer service") + + agent, err := ReattachAgent(shm, instanceIdx, size, log) + if err != nil { + log.Errorw("failed to reattach balancer agent", "error", err) + return nil, fmt.Errorf("failed to reattach balancer agent: %w", err) + } + + s := &Service{ + agent: agent, + log: log, + metrics: newMethodMetrics(), + } + + for _, balancer := range agent.Balancers() { + balancer.startRefreshing(&s.mu) + } + + return s, nil +} + +// getBalancerWithAutoSelection retrieves a balancer by name. +// If name is nil or empty, attempts to auto-select when exactly one balancer exists. +// The caller must hold s.mu. +func (s *Service) getBalancerWithAutoSelection( + name *string, +) (*Balancer, string, error) { + if name != nil { + b, ok := s.agent.GetBalancer(*name) + if !ok { + return nil, "", CodedErrorf(codes.NotFound, "balancer %q not found", *name) + } + return b, *name, nil + } + + names := s.agent.BalancerNames() + + if len(names) == 0 { + return nil, "", CodedErrorf(codes.NotFound, "no balancers found") + } + + if len(names) > 1 { + return nil, "", CodedErrorf( + codes.InvalidArgument, + "multiple balancers found (%d), please specify name explicitly", len(names), + ) + } + + selected := names[0] + s.log.Debugw("auto-selected balancer", "name", selected) + + b, _ := s.agent.GetBalancer(selected) + return b, selected, nil +} + +func (s *Service) SetConfig( + _ context.Context, + req *balancerpb.SetConfigRequest, +) (*balancerpb.SetConfigResponse, error) { + tracker := newMetricsTracker( + "set_config", s.metrics, defaultLatencyBoundsMS, + metrics.Labels{"config": req.GetName()}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + name := req.GetName() + if name == "" { + return nil, CodedErrorf(codes.InvalidArgument, "name is required") + } + + b, exists := s.agent.GetBalancer(name) + if exists { + s.log.Infow("updating balancer config", "name", name) + + now := time.Now() + reuseReport, err := b.Update(req.Config, &now) + if err != nil { + s.log.Errorw("failed to update balancer", "name", name, "error", err) + return nil, err + } + + s.log.Infow("balancer config updated", "name", name) + + return &balancerpb.SetConfigResponse{ + Name: name, + Reuse: reuseReport, + SessionTableCapacity: b.SessionTableCapacity(), + }, nil + } + + s.log.Infow("creating new balancer", "name", name) + + b, err := NewBalancer(s.agent, name, req.Config, s.log) + if err != nil { + s.log.Errorw("failed to create balancer", "name", name, "error", err) + return nil, err + } + s.agent.PutBalancer(name, b) + + b.startRefreshing(&s.mu) + + s.log.Infow("balancer created", "name", name) + + return &balancerpb.SetConfigResponse{ + Name: name, + SessionTableCapacity: b.SessionTableCapacity(), + }, nil +} + +func (s *Service) ListBalancers( + _ context.Context, + _ *balancerpb.ListBalancersRequest, +) (*balancerpb.ListBalancersResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return &balancerpb.ListBalancersResponse{ + Names: s.agent.BalancerNames(), + }, nil +} + +func (s *Service) GetConfig( + _ context.Context, + req *balancerpb.GetConfigRequest, +) (*balancerpb.GetConfigResponse, error) { + tracker := newMetricsTracker( + "get_config", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, name, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return nil, fmt.Errorf("failed to auto-select balancer: %w", err) + } + + return &balancerpb.GetConfigResponse{ + Name: name, + Config: b.Config(), + BufferedRealUpdates: b.BufferedRealUpdates(), + }, nil +} + +func (s *Service) GetState( + _ context.Context, + req *balancerpb.GetStateRequest, +) (*balancerpb.GetStateResponse, error) { + tracker := newMetricsTracker( + "get_state", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + var balancers map[string]*Balancer + if req.Name != nil { + b, ok := s.agent.GetBalancer(*req.Name) + if !ok { + return nil, CodedErrorf(codes.NotFound, "balancer %q not found", *req.Name) + } + balancers = map[string]*Balancer{*req.Name: b} + } else { + balancers = s.agent.Balancers() + } + + var allStates []*balancerpb.BalancerState + for _, b := range balancers { + states, err := b.GetState(req.PacketHandlerRef, req.Filter, req.IncludeCounters, time.Now()) + if err != nil { + return nil, err + } + allStates = append(allStates, states...) + } + + return &balancerpb.GetStateResponse{ + State: allStates, + }, nil +} + +func (s *Service) ListSessions( + req *balancerpb.ListSessionsRequest, + stream grpc.ServerStreamingServer[balancerpb.Session], +) error { + tracker := newMetricsTracker( + "list_sessions", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, _, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return fmt.Errorf("failed to auto-select balancer: %w", err) + } + + return b.ListSessions(req.Filter, time.Now(), func(session *balancerpb.Session) error { + return stream.Send(session) + }) +} + +func (s *Service) UpdateReals( + _ context.Context, + req *balancerpb.UpdateRealsRequest, +) (*balancerpb.UpdateRealsResponse, error) { + tracker := newMetricsTracker( + "update_reals", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, name, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return nil, fmt.Errorf("failed to auto-select balancer: %w", err) + } + + count, err := b.UpdateReals(req.Updates, req.Buffer) + if err != nil { + s.log.Errorw("failed to update reals", "name", name, "error", err) + return nil, err + } + + resp := &balancerpb.UpdateRealsResponse{Name: name} + if req.Buffer { + s.log.Infow("real updates buffered", "name", name, "count", count) + } else { + s.log.Infow("real updates applied", "name", name, "count", count) + } + + resp.UpdatesApplied = uint32(count) + + return resp, nil +} + +func (s *Service) FlushReals( + _ context.Context, + req *balancerpb.FlushRealsRequest, +) (*balancerpb.FlushRealsResponse, error) { + tracker := newMetricsTracker( + "flush_reals", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, name, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return nil, fmt.Errorf("failed to auto-select balancer: %w", err) + } + + count, err := b.FlushRealUpdates() + if err != nil { + s.log.Errorw("failed to flush real updates", "name", name, "error", err) + return nil, err + } + + s.log.Infow("real updates flushed", "name", name, "count", count) + + return &balancerpb.FlushRealsResponse{ + Name: name, + UpdatesFlushed: uint64(count), + }, nil +} + +func (s *Service) UpdateVS( + _ context.Context, + req *balancerpb.UpdateVSRequest, +) (*balancerpb.UpdateVSResponse, error) { + tracker := newMetricsTracker( + "update_vs", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, name, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return nil, fmt.Errorf("failed to auto-select balancer: %w", err) + } + + s.log.Infow("updating virtual services", "name", name, "vs_count", len(req.Services)) + + reuseReport, err := b.UpdateVS(req.Services) + if err != nil { + s.log.Errorw("failed to update virtual services", "name", name, "error", err) + return nil, err + } + + s.log.Infow("virtual services updated", "name", name, "vs_count", len(req.Services)) + + return &balancerpb.UpdateVSResponse{ + Name: name, + Reuse: reuseReport, + }, nil +} + +func (s *Service) DeleteVS( + _ context.Context, + req *balancerpb.DeleteVSRequest, +) (*balancerpb.DeleteVSResponse, error) { + tracker := newMetricsTracker( + "delete_vs", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + b, name, err := s.getBalancerWithAutoSelection(req.Name) + if err != nil { + return nil, fmt.Errorf("failed to auto-select balancer: %w", err) + } + + s.log.Infow("deleting virtual services", "name", name, "vs_count", len(req.Services)) + + reuseReport, err := b.DeleteVS(req.Services) + if err != nil { + s.log.Errorw("failed to delete virtual services", "name", name, "error", err) + return nil, err + } + + s.log.Infow("virtual services deleted", "name", name, "vs_count", len(req.Services)) + + return &balancerpb.DeleteVSResponse{ + Name: name, + Reuse: reuseReport, + }, nil +} + +func (s *Service) GetMetrics( + _ context.Context, + _ *balancerpb.GetMetricsRequest, +) (*balancerpb.GetMetricsResponse, error) { + tracker := newMetricsTracker( + "get_metrics", s.metrics, defaultLatencyBoundsMS, metrics.Labels{}, + ) + defer tracker.Fix() + + s.mu.Lock() + defer s.mu.Unlock() + + result := make([]*commonpb.Metric, 0) + + for _, b := range s.agent.Balancers() { + bMetrics, err := b.Metrics(time.Now()) + if err != nil { + return nil, err + } + result = append(result, bMetrics...) + } + + result = append(result, s.metrics.collect()...) + + return &balancerpb.GetMetricsResponse{Metrics: result}, nil +} diff --git a/modules/balancer/controlplane/sessions.go b/modules/balancer/controlplane/sessions.go new file mode 100644 index 000000000..beefa4ce1 --- /dev/null +++ b/modules/balancer/controlplane/sessions.go @@ -0,0 +1,118 @@ +package balancer + +import ( + "encoding/binary" + "time" + "unsafe" + + "github.com/yanet-platform/yanet2/common/go/relptr" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// ListSessions iterates the session table bucket-by-bucket and calls yield +// for each session that matches the filter. Iteration stops early if yield +// returns a non-nil error (e.g. gRPC stream cancelled). +func (b *Balancer) ListSessions( + filter *balancerpb.Filter, + now time.Time, + yield func(*balancerpb.Session) error, +) error { + if err := validateFilter(filter); err != nil { + return err + } + + matcher := newFilterMatcher(filter) + st := relptr.Deref(&b.handler.Session_table) + services := relptr.Slice(&b.handler.Vs, b.handler.Vs_count) + + iter := st.newSessionIter() + unixNow := uint32(now.Unix()) + + var buf [bucketMaxEntries]SessionEntry + + for { + count := iter.nextBucket(unixNow, buf[:]) + if count < 0 { + break + } + for i := range count { + entry := &buf[i] + + session, ok := resolveSession(entry, services, &matcher) + if !ok { + continue + } + + if err := yield(session); err != nil { + return err + } + } + } + + return nil +} + +// resolveSession converts a raw session entry into a protobuf Session, +// applying filter checks. Returns false if the session is stale or filtered out. +func resolveSession( + entry *SessionEntry, + services []VS, + matcher *filterMatcher, +) (*balancerpb.Session, bool) { + vsStableIdx := entry.Id.Vs_stable_idx + vsConfigIdx := configIndexOf(vsStableIdx) + + // No bounds check needed here: + // - vsConfigIdx is guaranteed to be within bounds because VS array never shrinks. + vs := &services[vsConfigIdx] + if vs.isRemoved() || vs.Stable_idx != vsStableIdx { + return nil, false + } + + vsID := vs.id() + if matcher.hasVsFilter && !matcher.matchVsID(vsID) { + return nil, false + } + + realStableIdx := entry.State.Real_stable_idx + realConfigIdx := configIndexOf(realStableIdx) + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + + r := &reals[realConfigIdx] + if r.isRemoved() || r.Stable_idx != realStableIdx { + return nil, false + } + + realRelID := r.id() + if matcher.hasRealFilter && !matcher.matchRealID(realRelID) { + return nil, false + } + + clientAddrLen := 16 + if vs.Ip_proto == ipprotoIP { + clientAddrLen = 4 + } + clientAddr := make([]byte, clientAddrLen) + copy(clientAddr, entry.Id.Client_ip[:clientAddrLen]) + + // client_port is stored in network byte order (copied directly from the + // TCP/UDP header by the dataplane). Convert to host byte order. + portBytes := (*[2]byte)(unsafe.Pointer(&entry.Id.Client_port)) + clientPort := binary.BigEndian.Uint16(portBytes[:]) + + createTimestamp := time.Unix(int64(entry.State.Create_timestamp), 0) + lastPacketTimestamp := time.Unix(int64(entry.State.Last_packet_timestamp), 0) + timeout := time.Duration(entry.State.Timeout) * time.Second + + return &balancerpb.Session{ + ClientAddr: clientAddr, + ClientPort: uint32(clientPort), + VsId: vsID, + RealId: realRelID, + CreateTimestamp: timestamppb.New(createTimestamp), + LastPacketTimestamp: timestamppb.New(lastPacketTimestamp), + Timeout: durationpb.New(timeout), + }, true +} diff --git a/modules/balancer/controlplane/state/active_sessions.c b/modules/balancer/controlplane/state/active_sessions.c deleted file mode 100644 index b2f70c461..000000000 --- a/modules/balancer/controlplane/state/active_sessions.c +++ /dev/null @@ -1,36 +0,0 @@ -#include "active_sessions.h" -#include "interval_counter.h" -#include "modules/balancer/dataplane/active_sessions.h" - -static inline size_t -tracker_size(size_t shards) { - return sizeof(struct active_sessions_tracker_shard) * shards; -} - -struct active_sessions_tracker_shard * -active_sessions_tracker_create( - struct memory_context *mctx, size_t shards, uint32_t now -) { - size_t size = tracker_size(shards); - struct active_sessions_tracker_shard *tracker_shards = - memory_balloc(mctx, size); - if (tracker_shards != NULL) { - for (size_t shard = 0; shard < shards; ++shard) { - rt_interval_counter_init( - &tracker_shards[shard].counter, now - ); - tracker_shards[shard].count = 0; - } - } - return tracker_shards; -} - -void -active_sessions_tracker_destroy( - struct active_sessions_tracker_shard *tracker_shards, - size_t shards, - struct memory_context *mctx -) { - size_t size = tracker_size(shards); - memory_bfree(mctx, tracker_shards, size); -} \ No newline at end of file diff --git a/modules/balancer/controlplane/state/active_sessions.h b/modules/balancer/controlplane/state/active_sessions.h deleted file mode 100644 index 5d8a148bb..000000000 --- a/modules/balancer/controlplane/state/active_sessions.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -/* Control-plane helpers for allocating per-worker active-session trackers. */ -#include "modules/balancer/dataplane/active_sessions.h" - -#include "common/memory.h" -#include - -/* Allocate and initialize tracker shards. */ -struct active_sessions_tracker_shard * -active_sessions_tracker_create( - struct memory_context *mctx, size_t shards, uint32_t now -); - -/* Release tracker shards created by `active_sessions_tracker_create()`. */ -void -active_sessions_tracker_destroy( - struct active_sessions_tracker_shard *tracker_shards, - size_t shards, - struct memory_context *mctx -); diff --git a/modules/balancer/controlplane/state/interval_counter.c b/modules/balancer/controlplane/state/interval_counter.c deleted file mode 100644 index 97ea72b00..000000000 --- a/modules/balancer/controlplane/state/interval_counter.c +++ /dev/null @@ -1,9 +0,0 @@ -#include "interval_counter.h" -#include "modules/balancer/dataplane/interval_counter.h" -#include - -void -rt_interval_counter_init(struct rt_interval_counter *counter, uint32_t now) { - memset(counter->diff, 0, sizeof(counter->diff)); - counter->last_timestamp = now; -} diff --git a/modules/balancer/controlplane/state/interval_counter.h b/modules/balancer/controlplane/state/interval_counter.h deleted file mode 100644 index a4ee64cb3..000000000 --- a/modules/balancer/controlplane/state/interval_counter.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -/** - * @file rt_interval_counter.h - * - * Control-plane initialization helpers for [`struct - * rt_interval_counter`](modules/balancer/dataplane/interval_counter.h:19). - */ - -#include "modules/balancer/dataplane/interval_counter.h" - -/* Initialize a counter so it can start tracking intervals at `now`. */ -void -rt_interval_counter_init(struct rt_interval_counter *counter, uint32_t now); diff --git a/modules/balancer/controlplane/state/meson.build b/modules/balancer/controlplane/state/meson.build deleted file mode 100644 index 330f3862f..000000000 --- a/modules/balancer/controlplane/state/meson.build +++ /dev/null @@ -1,27 +0,0 @@ -dependencies = [ - lib_common_dep, - lib_agent_cp_dep, -] - -includes = include_directories('.', '../') - -sources = files( - 'session_table.c', - 'state.c', - 'interval_counter.c', - 'active_sessions.c', -) - -lib_balancer_state = static_library( - 'balancer_state', - sources, - c_args: yanet_c_args, - link_args: yanet_link_args, - dependencies: dependencies, - include_directories: includes, - install: false, -) - -lib_balancer_state_dep = declare_dependency( - link_with: [lib_balancer_state], -) diff --git a/modules/balancer/controlplane/state/session.h b/modules/balancer/controlplane/state/session.h deleted file mode 100644 index 355434362..000000000 --- a/modules/balancer/controlplane/state/session.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include "common/network.h" -#include "common/ttlmap/detail/lock.h" - -typedef ttlmap_lock_t session_lock_t; - -/** - * Key that identifies a session in the state layer. - */ -struct session_id { - struct net_addr client_ip; // Client source IP (IPv4/IPv6) - uint16_t client_port; // Client source port (network byte order) - size_t vs_id; // Target virtual service -}; - -/** - * Stored session metadata in the state layer. - */ -struct session_state { - size_t real_id; // Global stable real ID for this session - uint32_t timeout; // Current timeout applied (seconds) - uint32_t last_packet_timestamp; // Last packet timestamp (monotonic) - uint32_t create_timestamp; // Creation timestamp (monotonic) -}; \ No newline at end of file diff --git a/modules/balancer/controlplane/state/session_table.c b/modules/balancer/controlplane/state/session_table.c deleted file mode 100644 index d4e20bb02..000000000 --- a/modules/balancer/controlplane/state/session_table.c +++ /dev/null @@ -1,216 +0,0 @@ -#include "session_table.h" -#include "common/memory.h" -#include "common/rcu.h" -#include "common/ttlmap/detail/ttlmap.h" -#include "common/ttlmap/ttlmap.h" - -#include "lib/controlplane/diag/diag.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/ttlmap/ttlmap.h" - -#include "common/ttlmap/ttlmap.h" - -#include "api/session.h" - -#include "session.h" - -#include "state.h" - -//////////////////////////////////////////////////////////////////////////////// - -int -session_table_init( - struct session_table *table, struct memory_context *mctx, size_t size -) { - memory_context_init_from(&table->mctx, mctx, "session_table"); - - int res = TTLMAP_INIT( - &table->maps[0], - &table->mctx, - struct session_id, - struct session_state, - size - ); - if (res != 0) { - return -1; - } - - ttlmap_init_empty(&table->maps[1]); - - // Init generation count - // (guarded with rcu) - rcu_init(&table->rcu); - table->current_gen = 0; - - return 0; -} - -void -session_table_free(struct session_table *table) { - for (size_t i = 0; i < 2; ++i) { - TTLMAP_FREE(&table->maps[i]); - } -} - -//////////////////////////////////////////////////////////////////////////////// - -size_t -session_table_capacity(struct session_table *table) { - struct ttlmap *ttlmap = - session_table_map(table, session_table_current_gen(table)); - return ttlmap_capacity(ttlmap); -} - -//////////////////////////////////////////////////////////////////////////////// - -struct fill_sessions_context { - struct balancer_state *state; - struct named_session_info *info; - size_t count; - size_t size; - bool only_count; - uint32_t now; -}; - -//////////////////////////////////////////////////////////////////////////////// - -struct move_sessions_context { - struct ttlmap *next_map; - uint32_t now; -}; - -static int -move_sessions_callback( - struct session_id *id, - struct session_state *state, - struct move_sessions_context *ctx -) { - if (state->last_packet_timestamp + state->timeout <= ctx->now) { - return 0; - } - - session_lock_t *lock; - struct session_state *found; - int res = TTLMAP_GET( - ctx->next_map, - id, - &found, - &lock, - state->last_packet_timestamp, - state->timeout - ); - - int status = TTLMAP_STATUS(res); - if (status == TTLMAP_INSERTED || status == TTLMAP_REPLACED) { - memcpy(found, state, sizeof(struct session_state)); - ttlmap_release_lock(lock); - } else if (status == TTLMAP_FOUND) { - ttlmap_release_lock(lock); - } else { // status == TTLMAP_FAILED - // critical: misses some session, session table grows too fast - } - - return 0; -} - -static inline void -set_gen(struct session_table *table, uint32_t gen) { - rcu_update(&table->rcu, &table->current_gen, gen); -} - -static inline uint64_t -get_gen(struct session_table *table) { - return rcu_load(&table->rcu, &table->current_gen); -} - -int -session_table_resize( - struct session_table *table, size_t new_size, uint32_t now -) { - uint32_t current_gen = get_gen(table); - - struct ttlmap *next_map = session_table_prev_map(table, current_gen); - struct memory_context *mctx = &table->mctx; - - int init_result = TTLMAP_INIT( - next_map, - mctx, - struct session_id, - struct session_state, - new_size - ); - if (init_result != 0) { - NEW_ERROR("failed to init new table"); - // no memory - return -1; - } - - // Update current gen, so all workers use primary `next_map` - // and fallbacks to the `current_map` - struct ttlmap *current_map = session_table_map(table, current_gen); - ++current_gen; - set_gen(table, current_gen); - - // Now, workers can not update `current_map`. - // They insert only into the `next_map`. - - // After that, we should move all sessions from the current_map to the - // next_map - - struct move_sessions_context ctx = { - .next_map = next_map, - .now = now, - }; - TTLMAP_ITER( - current_map, - struct session_id, - struct session_state, - now, - move_sessions_callback, - &ctx - ); - - // Sessions are moved, so workers dont need to use previous map - ++current_gen; - set_gen(table, current_gen); - - // After that, workers will not use previous map - - // So we can free current_map - TTLMAP_FREE(current_map); - - return 0; -} - -int -session_table_iter( - struct session_table *table, - uint32_t now, - session_table_iter_callback cb, - void *userdata -) { - return TTLMAP_ITER( - session_table_map(table, session_table_current_gen(table)), - struct session_id, - struct session_state, - now, - cb, - userdata - ); -} - -size_t -session_table_memory_usage(struct session_table *table) { - return table->mctx.balloc_size - table->mctx.bfree_size + - ttlmap_memory_usage(&table->maps[0]) + - ttlmap_memory_usage(&table->maps[1]); -} \ No newline at end of file diff --git a/modules/balancer/controlplane/state/session_table.h b/modules/balancer/controlplane/state/session_table.h deleted file mode 100644 index 966bb42c2..000000000 --- a/modules/balancer/controlplane/state/session_table.h +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include "common/memory.h" -#include "common/rcu.h" -#include "common/ttlmap/detail/ttlmap.h" - -#include "state/session.h" - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -#define SESSION_FOUND TTLMAP_FOUND -#define SESSION_CREATED (TTLMAP_INSERTED | TTLMAP_REPLACED) -#define SESSION_TABLE_OVERFLOW TTLMAP_FAILED - -/** - * Lock-free session table with RCU-protected generation swapping. - */ -struct session_table { - struct ttlmap maps[2]; // Active and previous maps - - rcu_t rcu; // RCU guard for map swaps - _Atomic uint64_t current_gen; // Workers read, control-plane updates - - struct memory_context mctx; // Allocation context -}; - -static inline int -session_table_map_idx(uint32_t gen) { - return ((gen + 1) & 0b11) >> 1; -} - -static inline struct ttlmap * -session_table_map(struct session_table *table, uint32_t gen) { - return &table->maps[session_table_map_idx(gen)]; -} - -static inline struct ttlmap * -session_table_prev_map(struct session_table *table, uint32_t gen) { - return &table->maps[session_table_map_idx(gen) ^ 1]; -} - -static inline uint32_t -session_table_current_gen(struct session_table *table) { - return atomic_load_explicit(&table->current_gen, memory_order_acquire); -} - -/** - * Initialize session table. - * Returns 0 on success, -1 on error. - */ -int -session_table_init( - struct session_table *table, struct memory_context *mctx, size_t size -); - -/** - * Free resources held by the session table. - */ -void -session_table_free(struct session_table *table); - -/** - * Current capacity (number of buckets/entries). - */ -size_t -session_table_capacity(struct session_table *table); - -/** - * Try to resize session table. - * Returns 0 on success, -1 on error (e.g., out of memory). - */ -int -session_table_resize( - struct session_table *table, size_t new_size, uint32_t now -); - -//////////////////////////////////////////////////////////////////////////////// - -struct balancer_info; - -void -session_table_fill_balancer_info( - struct session_table *table, struct balancer_info *info, uint32_t now -); - -typedef int (*session_table_iter_callback)( - struct session_id *id, struct session_state *state, void *userdata -); - -int -session_table_iter( - struct session_table *table, - uint32_t now, - session_table_iter_callback cb, - void *userdata -); - -size_t -session_table_memory_usage(struct session_table *table); \ No newline at end of file diff --git a/modules/balancer/controlplane/state/state.c b/modules/balancer/controlplane/state/state.c deleted file mode 100644 index 4e9d0baed..000000000 --- a/modules/balancer/controlplane/state/state.c +++ /dev/null @@ -1,71 +0,0 @@ -#include "state.h" - -#include "common/memory.h" -#include "controlplane/diag/diag.h" -#include "session_table.h" -#include -#include -#include - -int -balancer_state_init( - struct balancer_state *state, - struct memory_context *mctx, - size_t workers, - size_t table_size -) { - assert((uintptr_t)state % alignof(struct balancer_state) == 0); - - // workers - state->workers = workers; - - // init session table - int res = session_table_init(&state->session_table, mctx, table_size); - if (res != 0) { - NEW_ERROR("failed to initialize session table"); - return -1; - } - - return 0; -} - -void -balancer_state_free(struct balancer_state *state) { - session_table_free(&state->session_table); -} - -//////////////////////////////////////////////////////////////////////////////// - -int -balancer_state_resize_session_table( - struct balancer_state *state, size_t new_size, uint32_t now -) { - return session_table_resize(&state->session_table, new_size, now); -} - -size_t -balancer_state_session_table_capacity(struct balancer_state *state) { - return session_table_capacity(&state->session_table); -} - -//////////////////////////////////////////////////////////////////////////////// - -int -balancer_state_iter_session_table( - struct balancer_state *state, - uint32_t now, - session_table_iter_callback cb, - void *userdata -) { - return session_table_iter(&state->session_table, now, cb, userdata); -} - -// TODO: docs -void -balancer_state_inspect( - struct balancer_state *state, struct state_inspect *inspect -) { - inspect->session_table_usage = - session_table_memory_usage(&state->session_table); - inspect->total_usage = inspect->session_table_usage; -} \ No newline at end of file diff --git a/modules/balancer/controlplane/state/state.h b/modules/balancer/controlplane/state/state.h deleted file mode 100644 index b2c57e68a..000000000 --- a/modules/balancer/controlplane/state/state.h +++ /dev/null @@ -1,66 +0,0 @@ -#pragma once - -#include "common/memory.h" - -#include "session_table.h" - -#include "api/inspect.h" - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Persistent balancer state. - * Holds registries (VS/reals), session table and per-worker stats. - */ -struct balancer_state { - // number of workers - size_t workers; - - // session table - struct session_table session_table; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Initialize balancer_state. - * Returns 0 on success, -1 on error. - */ -int -balancer_state_init( - struct balancer_state *state, - struct memory_context *mctx, - size_t workers, - size_t table_size -); - -/** - * Free resources held by balancer_state. - */ -void -balancer_state_free(struct balancer_state *state); - -//////////////////////////////////////////////////////////////////////////////// - -/** - * Resize session table; returns 0 on success, -1 on error. - */ -int -balancer_state_resize_session_table( - struct balancer_state *state, size_t new_size, uint32_t now -); - -// TODO: docs -int -balancer_state_iter_session_table( - struct balancer_state *state, - uint32_t now, - session_table_iter_callback cb, - void *userdata -); - -// TODO: docs -void -balancer_state_inspect( - struct balancer_state *state, struct state_inspect *inspect -); \ No newline at end of file diff --git a/modules/balancer/controlplane/state/worker.h b/modules/balancer/controlplane/state/worker.h deleted file mode 100644 index 8ee5d1acc..000000000 --- a/modules/balancer/controlplane/state/worker.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "common/rcu.h" - -#include - -#define MAX_WORKERS_NUM 8 - -static_assert(MAX_WORKERS_NUM <= RCU_WORKERS, "too many workers"); diff --git a/modules/balancer/controlplane/stats.go b/modules/balancer/controlplane/stats.go new file mode 100644 index 000000000..88e3cef5e --- /dev/null +++ b/modules/balancer/controlplane/stats.go @@ -0,0 +1,270 @@ +package balancer + +import ( + "strconv" + "strings" + "unsafe" + + "github.com/yanet-platform/yanet2/common/go/relptr" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +// Aggregate the worker local counters into the first counter. +func aggregateCounter(counter [][]uint64) { + for idx := 1; idx < len(counter); idx++ { + for i := range counter[idx] { + counter[0][i] += counter[idx][i] + } + } +} + +func commonStats(counter [][]uint64) *CommonStats { + aggregateCounter(counter) + return (*CommonStats)(unsafe.Pointer(&counter[0][0])) +} + +func (common *CommonStats) proto() *balancerpb.CommonStats { + return &balancerpb.CommonStats{ + IncomingPackets: common.Incoming_packets, + IncomingBytes: common.Incoming_bytes, + UnexpectedNetworkProto: common.Unexpected_network_proto, + DecapSuccessful: common.Decap_successful, + DecapFailed: common.Decap_failed, + OutgoingPackets: common.Outgoing_packets, + OutgoingBytes: common.Outgoing_bytes, + } +} + +func icmpStats(counter [][]uint64) *IcmpStats { + aggregateCounter(counter) + return (*IcmpStats)(unsafe.Pointer(&counter[0][0])) +} + +func (i *IcmpStats) proto() *balancerpb.IcmpStats { + return &balancerpb.IcmpStats{ + IncomingPackets: i.Incoming_packets, + SrcNotAllowed: i.Src_not_allowed, + EchoResponses: i.Echo_responses, + PayloadTooShortIp: i.Payload_too_short_ip, + UnmatchingSrcFromOriginal: i.Unmatching_src_from_original, + PayloadTooShortPort: i.Payload_too_short_port, + UnexpectedTransport: i.Unexpected_transport, + UnrecognizedVs: i.Unrecognized_vs, + ForwardedPackets: i.Forwarded_packets, + BroadcastedPackets: i.Broadcasted_packets, + PacketClonesSent: i.Packet_clones_sent, + PacketClonesReceived: i.Packet_clones_received, + PacketCloneFailures: i.Packet_clone_failures, + } +} + +func l4Stats(counter [][]uint64) *L4Stats { + aggregateCounter(counter) + return (*L4Stats)(unsafe.Pointer(&counter[0][0])) +} + +func (l4 *L4Stats) proto() *balancerpb.L4Stats { + return &balancerpb.L4Stats{ + IncomingPackets: l4.Incoming_packets, + SelectVsFailed: l4.Select_vs_failed, + InvalidPackets: l4.Invalid_packets, + SelectRealFailed: l4.Select_real_failed, + OutgoingPackets: l4.Outgoing_packets, + } +} + +func vsStats(counter [][]uint64) *VsStats { + aggregateCounter(counter) + return (*VsStats)(unsafe.Pointer(&counter[0][0])) +} + +func (vs *VsStats) proto() *balancerpb.VsStats { + return &balancerpb.VsStats{ + IncomingPackets: vs.Incoming_packets, + IncomingBytes: vs.Incoming_bytes, + PacketSrcNotAllowed: vs.Packet_src_not_allowed, + NoReals: vs.No_reals, + SessionTableOverflow: vs.Session_table_overflow, + EchoIcmpPackets: vs.Echo_icmp_packets, + ErrorIcmpPackets: vs.Error_icmp_packets, + RealIsDisabled: vs.Real_is_disabled, + RealIsRemoved: vs.Real_is_removed, + NotRescheduledPackets: vs.Not_rescheduled_packets, + BroadcastedIcmpPackets: vs.Broadcasted_icmp_packets, + CreatedSessions: vs.Created_sessions, + OutgoingPackets: vs.Outgoing_packets, + OutgoingBytes: vs.Outgoing_bytes, + } +} + +func realStats(counter [][]uint64) *RealStats { + aggregateCounter(counter) + return (*RealStats)(unsafe.Pointer(&counter[0][0])) +} + +func (rs *RealStats) proto() *balancerpb.RealStats { + return &balancerpb.RealStats{ + CreatedSessions: rs.Created_sessions, + Packets: rs.Packets, + Bytes: rs.Bytes, + PacketsRealDisabled: rs.Packets_real_disabled, + ErrorIcmpPackets: rs.Error_icmp_packets, + } +} + +func aggregateACLPasses(counter [][]uint64) uint64 { + aggregateCounter(counter) + return counter[0][0] +} + +func resolveVS(services []VS, vsStableIndex uint64) (*VS, bool) { + vsConfigIndex := configIndexOf(vsStableIndex) + vs := &services[vsConfigIndex] + if vs.isRemoved() || vs.Stable_idx != vsStableIndex { + return nil, false + } + return vs, true +} + +func resolveReal(vs *VS, realStableIndex uint64) (*Real, bool) { + realConfigIndex := configIndexOf(realStableIndex) + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + r := &reals[realConfigIndex] + if r.isRemoved() || r.Stable_idx != realStableIndex { + return nil, false + } + return r, true +} + +func vsIndexFromCounterName(name string) (uint64, bool) { + stableIndex, err := strconv.ParseUint(strings.TrimPrefix(name, "vs_"), 10, 64) + if err != nil { + return 0, false + } + return stableIndex, true +} + +func realIndexFromCounterName(name string) (uint64, uint64, bool) { + parts := strings.SplitN(strings.TrimPrefix(name, "rl_"), "_", 2) + if len(parts) != 2 { + return 0, 0, false + } + vsStableIdx, err := strconv.ParseUint(parts[0], 10, 64) + if err != nil { + return 0, 0, false + } + realStableIdx, err := strconv.ParseUint(parts[1], 10, 64) + if err != nil { + return 0, 0, false + } + return vsStableIdx, realStableIdx, true +} + +func aclTagFromCounterName(name string) (uint64, string, bool) { + parts := strings.SplitN(strings.TrimPrefix(name, "acl_"), "_", 2) + if len(parts) != 2 { + return 0, "", false + } + vsStableIdx, err := strconv.ParseUint(parts[0], 10, 64) + if err != nil { + return 0, "", false + } + return vsStableIdx, parts[1], true +} + +func applyCounter( + handler *PacketHandler, + state *balancerpb.BalancerState, + counter yanet.CounterInfo, +) { + name := counter.Name + services := relptr.Slice(&handler.Vs, handler.Vs_count) + + switch { + case strings.HasPrefix(name, "vs_"): + vsStableIndex, ok := vsIndexFromCounterName(name) + if !ok { + return + } + if _, ok = resolveVS(services, vsStableIndex); !ok { + return + } + vsState := state.VirtualServices[configIndexOf(vsStableIndex)] + if vsState == nil { + return + } + vsState.Stats = vsStats(counter.Values).proto() + + case strings.HasPrefix(name, "rl_"): + vsStableIndex, realStableIndex, ok := realIndexFromCounterName(name) + if !ok { + return + } + vs, ok := resolveVS(services, vsStableIndex) + if !ok { + return + } + if _, ok := resolveReal(vs, realStableIndex); !ok { + return + } + vsState := state.VirtualServices[configIndexOf(vsStableIndex)] + if vsState == nil { + return + } + realState := vsState.Reals[configIndexOf(realStableIndex)] + if realState == nil { + return + } + realState.RealStats = realStats(counter.Values).proto() + + case strings.HasPrefix(name, "acl_"): + vsStableIndex, tag, ok := aclTagFromCounterName(name) + if !ok { + return + } + if _, ok := resolveVS(services, vsStableIndex); !ok { + return + } + vsState := state.VirtualServices[configIndexOf(vsStableIndex)] + if vsState == nil { + return + } + vsState.AllowedSources = append(vsState.AllowedSources, &balancerpb.AllowedSourcesStats{ + Tag: tag, + Passes: aggregateACLPasses(counter.Values), + }) + + case name == "cmn": + state.CommonStats = commonStats(counter.Values).proto() + case name == "iv4": + state.IcmpIpv4Stats = icmpStats(counter.Values).proto() + case name == "iv6": + state.IcmpIpv6Stats = icmpStats(counter.Values).proto() + case name == "l4": + state.L4Stats = l4Stats(counter.Values).proto() + } +} + +func compactVsState(state *balancerpb.VsState) { + next := 0 + for idx := range state.Reals { + if state.Reals[idx] != nil { + state.Reals[next] = state.Reals[idx] + next++ + } + } + state.Reals = state.Reals[:next] +} + +func compactBalancerState(state *balancerpb.BalancerState) { + next := 0 + for idx := range state.VirtualServices { + if state.VirtualServices[idx] != nil { + state.VirtualServices[next] = state.VirtualServices[idx] + compactVsState(state.VirtualServices[next]) + next++ + } + } + state.VirtualServices = state.VirtualServices[:next] +} diff --git a/modules/balancer/controlplane/validate.go b/modules/balancer/controlplane/validate.go new file mode 100644 index 000000000..7c33e08ea --- /dev/null +++ b/modules/balancer/controlplane/validate.go @@ -0,0 +1,379 @@ +package balancer + +import ( + "bytes" + "cmp" + "fmt" + "slices" + + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/grpc/codes" +) + +// invalidArg is a validator-local shorthand for leaf errors that should carry +// codes.InvalidArgument through the RPC boundary. +func invalidArg(format string, args ...any) error { + return CodedErrorf(codes.InvalidArgument, format, args...) +} + +func compareIPNet(a, b *filterpb.IPNet) int { + if c := bytes.Compare(a.Addr, b.Addr); c != 0 { + return c + } + return bytes.Compare(a.Mask, b.Mask) +} + +func comparePortRange(a, b *filterpb.PortRange) int { + if c := cmp.Compare(a.From, b.From); c != 0 { + return c + } + return cmp.Compare(a.To, b.To) +} + +func compareAllowedSourcesPb(a, b *balancerpb.AllowedSources) int { + if c := cmp.Compare(len(a.Nets), len(b.Nets)); c != 0 { + return c + } + for i := range a.Nets { + if c := compareIPNet(a.Nets[i], b.Nets[i]); c != 0 { + return c + } + } + if c := cmp.Compare(len(a.Ports), len(b.Ports)); c != 0 { + return c + } + for i := range a.Ports { + if c := comparePortRange(a.Ports[i], b.Ports[i]); c != 0 { + return c + } + } + return 0 +} + +func validateWlcConfig(wlc *balancerpb.WlcConfig) error { + if wlc.Power == nil { + return invalidArg("power is nil") + } + if wlc.MaxWeight == nil { + return invalidArg("max_weight is nil") + } + return nil +} + +func validateStateConfig(state *balancerpb.StateConfig) error { + if state.SessionTableCapacity == nil { + return invalidArg("session_table_capacity is nil") + } + if *state.SessionTableCapacity == 0 { + return invalidArg("session_table_capacity must be greater than 0") + } + if state.RefreshPeriod == nil { + return invalidArg("refresh_period is nil") + } + if state.SessionTableMaxLoadFactor == nil { + return invalidArg("session_table_max_load_factor is nil") + } + if *state.SessionTableMaxLoadFactor <= 0 || *state.SessionTableMaxLoadFactor > 1 { + return invalidArg("session_table_max_load_factor must be between 0 and 1") + } + if state.Wlc == nil { + return invalidArg("wlc config is nil") + } + if err := validateWlcConfig(state.Wlc); err != nil { + return fmt.Errorf("wlc: %w", err) + } + return nil +} + +func validateSessionsTimeouts(timeouts *balancerpb.SessionsTimeouts) error { + if timeouts.TcpSynAck > MaxSessionTimeout { + return invalidArg("tcp_syn_ack must be less than or equal to %d", MaxSessionTimeout) + } + if timeouts.TcpSyn > MaxSessionTimeout { + return invalidArg("tcp_syn must be less than or equal to %d", MaxSessionTimeout) + } + if timeouts.TcpFin > MaxSessionTimeout { + return invalidArg("tcp_fin must be less than or equal to %d", MaxSessionTimeout) + } + if timeouts.Tcp > MaxSessionTimeout { + return invalidArg("tcp must be less than or equal to %d", MaxSessionTimeout) + } + if timeouts.Udp > MaxSessionTimeout { + return invalidArg("udp must be less than or equal to %d", MaxSessionTimeout) + } + return nil +} + +func validateMask4(mask []byte) error { + bits := uint32(mask[0])<<24 | uint32(mask[1])<<16 | uint32(mask[2])<<8 | uint32(mask[3]) + inverted := ^bits + if inverted&(inverted+1) != 0 { + return invalidArg("mask is not contiguous") + } + return nil +} + +func isContiguous8(mask []byte) bool { + bits := uint64(0) + for i := range 8 { + bits |= uint64(mask[i]) << ((7 - i) * 8) + } + inverted := ^bits + return inverted&(inverted+1) == 0 +} + +// Check if the mask halves are contiguous. +func validateMask6(mask []byte) error { + if !isContiguous8(mask[:8]) { + return invalidArg("high mask bits are not contiguous") + } + if !isContiguous8(mask[8:]) { + return invalidArg("low mask bits are not contiguous") + } + return nil +} + +func validateNet(net *filterpb.IPNet, isV6 bool) error { + requiredLen := 4 + if isV6 { + requiredLen = 16 + } + if len(net.Addr) != requiredLen { + return invalidArg("net.addr must be %d bytes", requiredLen) + } + if len(net.Mask) != requiredLen { + return invalidArg("net.mask must be %d bytes", requiredLen) + } + if isV6 { + if err := validateMask6(net.Mask); err != nil { + return fmt.Errorf("IPv6 net mask: %w", err) + } + } else { + if err := validateMask4(net.Mask); err != nil { + return fmt.Errorf("IPv4 net mask: %w", err) + } + } + return nil +} + +func validatePortRange(portRange *filterpb.PortRange) error { + if portRange.From > portRange.To { + return invalidArg("port_range.from must be less than or equal to port_range.to") + } + if portRange.To > 65535 { + return invalidArg("port_range.to must be less than or equal to 65535") + } + return nil +} + +func validateAllowedSrc( + allowedSrc *balancerpb.AllowedSources, + isIPv6 bool, +) error { + for i, net := range allowedSrc.Nets { + if err := validateNet(net, isIPv6); err != nil { + return fmt.Errorf("net %x/%x at index %d: %w", net.Addr, net.Mask, i, err) + } + } + slices.SortFunc(allowedSrc.Nets, compareIPNet) + allowedSrc.Nets = slices.CompactFunc(allowedSrc.Nets, func(a, b *filterpb.IPNet) bool { + return compareIPNet(a, b) == 0 + }) + for i, port := range allowedSrc.Ports { + if err := validatePortRange(port); err != nil { + return fmt.Errorf("port range [%d-%d] at index %d: %w", port.From, port.To, i, err) + } + } + slices.SortFunc(allowedSrc.Ports, comparePortRange) + allowedSrc.Ports = slices.CompactFunc(allowedSrc.Ports, func(a, b *filterpb.PortRange) bool { + return comparePortRange(a, b) == 0 + }) + if allowedSrc.Tag != nil && len(*allowedSrc.Tag) > int(AllowedSourceMaxTagLength) { + return invalidArg( + "tag %s must be less than or equal to %d characters", + *allowedSrc.Tag, + AllowedSourceMaxTagLength, + ) + } + return nil +} + +func validateReal(r *balancerpb.Real) error { + if r.Id == nil { + return invalidArg("id is nil") + } + id := r.Id + if len(id.Ip) != 4 && len(id.Ip) != 16 { + return invalidArg("id.ip must be 4 or 16 bytes long") + } + if id.Port != 0 { + return invalidArg("only zero ports is currently supported") + } + if r.Src == nil { + return invalidArg("src is nil") + } + if len(r.Src.Addr) != len(id.Ip) { + return invalidArg("src.addr must be the same length as id.ip") + } + if len(r.Src.Mask) != len(id.Ip) { + return invalidArg("src.mask must be the same length as id.ip") + } + return nil +} + +// validateAllowedSources validates, sorts, and deduplicates the allowed sources slice. +// Side effect: sorts and compacts allowedSources. canReuseACL depends on this sort order +// to compare allowed sources element-by-element between old and new configs. +func validateAllowedSources( + allowedSources []*balancerpb.AllowedSources, + isIPv6 bool, +) ([]*balancerpb.AllowedSources, error) { + for i, allowedSrc := range allowedSources { + if allowedSrc == nil { + return nil, invalidArg("allowed_src at index %d is nil", i) + } + if err := validateAllowedSrc(allowedSrc, isIPv6); err != nil { + return nil, fmt.Errorf("allowed_src at index %d: %w", i, err) + } + } + slices.SortFunc(allowedSources, compareAllowedSourcesPb) + allowedSources = slices.CompactFunc(allowedSources, func(a, b *balancerpb.AllowedSources) bool { + return compareAllowedSourcesPb(a, b) == 0 + }) + return allowedSources, nil +} + +func validateReals(reals []*balancerpb.Real) error { + realsMap := make(map[realKey]int, len(reals)) + for i, r := range reals { + if r == nil { + return invalidArg("real at index %d is nil", i) + } + if err := validateReal(r); err != nil { + return fmt.Errorf("real %s at index %d: %w", realIDToString(r.Id), i, err) + } + key := makeRealKey(r.Id) + if prevIdx, ok := realsMap[key]; ok { + return invalidArg( + "real %s at index %d: duplicate of real at index %d", + realIDToString(r.Id), + i, + prevIdx, + ) + } + realsMap[key] = i + } + return nil +} + +func validateVS(vs *balancerpb.VirtualService) error { + if vs.Id == nil { + return invalidArg("id is nil") + } + if len(vs.Id.Addr) != 4 && len(vs.Id.Addr) != 16 { + return invalidArg("id.addr must be 4 or 16 bytes") + } + if vs.Scheduler != balancerpb.VsScheduler_SH && + vs.Scheduler != balancerpb.VsScheduler_WRR && + vs.Scheduler != balancerpb.VsScheduler_WLC { + return invalidArg("scheduler must be SH/WRR/WLC") + } + if vs.Id.Proto != balancerpb.TransportProto_TCP && + vs.Id.Proto != balancerpb.TransportProto_UDP { + return invalidArg("id.proto must be TCP or UDP") + } + if vs.Flags == nil { + return invalidArg("flags is nil") + } + if vs.Flags.PureL3 && vs.Id.Port != 0 { + return invalidArg("pure_l3 flag is set but port is not 0") + } + for i, peer := range vs.Peers { + if len(peer) != 4 && len(peer) != 16 { + return invalidArg("peer %x at index %d: addr must be 4 or 16 bytes long", peer, i) + } + } + var err error + vs.AllowedSrcs, err = validateAllowedSources(vs.AllowedSrcs, len(vs.Id.Addr) == 16) + if err != nil { + return err + } + if err := validateReals(vs.Reals); err != nil { + return err + } + return nil +} + +func validatePacketHandlerConfig(config *balancerpb.PacketHandlerConfig) error { + if len(config.SourceAddressV4) != 4 { + return invalidArg("source_address_v4 %x must be 4 bytes", config.SourceAddressV4) + } + if len(config.SourceAddressV6) != 16 { + return invalidArg("source_address_v6 %x must be 16 bytes", config.SourceAddressV6) + } + if config.SessionsTimeouts == nil { + return invalidArg("sessions_timeouts is nil") + } + if err := validateSessionsTimeouts(config.SessionsTimeouts); err != nil { + return fmt.Errorf("sessions_timeouts: %w", err) + } + for idx, addr := range config.DecapAddresses { + if len(addr) != 4 && len(addr) != 16 { + return invalidArg("decap_addresses %x at index %d: must be 4 or 16 bytes", addr, idx) + } + } + + // Sort decap addresses by family (IPv4 first, then IPv6), then by value. + // decapFiltersReusable depends on this ordering to find the IPv4/IPv6 split point. + slices.SortFunc(config.DecapAddresses, func(a, b []byte) int { + if c := cmp.Compare(len(a), len(b)); c != 0 { + return c + } + return bytes.Compare(a, b) + }) + config.DecapAddresses = slices.CompactFunc(config.DecapAddresses, bytes.Equal) + + vsMap := make(map[vsKey]int, len(config.Vs)) + for i, vs := range config.Vs { + if vs == nil { + return invalidArg("vs at index %d is nil", i) + } + if err := validateVS(vs); err != nil { + return fmt.Errorf("vs %s at index %d: %w", vsIDToString(vs.Id), i, err) + } + key := makeVsKey(vs.Id) + if prevIdx, ok := vsMap[key]; ok { + return fmt.Errorf( + "vs %s at index %d: duplicated at index %d", + vsIDToString(vs.Id), + i, + prevIdx, + ) + } + vsMap[key] = i + } + + return nil +} + +// validateBalancerConfig checks that all required fields are present +// for creating a new balancer. +func validateBalancerConfig(config *balancerpb.BalancerConfig) error { + if config == nil { + return invalidArg("config is nil") + } + if config.PacketHandler == nil { + return invalidArg("packet_handler is nil") + } + if err := validatePacketHandlerConfig(config.PacketHandler); err != nil { + return fmt.Errorf("packet_handler: %w", err) + } + if config.State == nil { + return invalidArg("state is nil") + } + if err := validateStateConfig(config.State); err != nil { + return fmt.Errorf("state: %w", err) + } + return nil +} diff --git a/modules/balancer/controlplane/validate_test.go b/modules/balancer/controlplane/validate_test.go new file mode 100644 index 000000000..fd91440bc --- /dev/null +++ b/modules/balancer/controlplane/validate_test.go @@ -0,0 +1,842 @@ +package balancer + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/protobuf/types/known/durationpb" +) + +// --------------------------------------------------------------------------- +// Test helpers — minimal valid proto objects reusable across tests. +// --------------------------------------------------------------------------- + +func makeValidReal() *balancerpb.Real { + return &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{ + Ip: []byte{10, 0, 0, 1}, + Port: 0, + }, + Weight: 1, + Src: &filterpb.IPNet{ + Addr: []byte{10, 0, 0, 0}, + Mask: []byte{255, 255, 255, 0}, + }, + } +} + +func makeValidVS() *balancerpb.VirtualService { + return &balancerpb.VirtualService{ + Id: &balancerpb.VsIdentifier{ + Addr: []byte{1, 1, 1, 1}, + Port: 80, + Proto: balancerpb.TransportProto_TCP, + }, + Scheduler: balancerpb.VsScheduler_SH, + Flags: &balancerpb.VsFlags{}, + Reals: []*balancerpb.Real{makeValidReal()}, + } +} + +func makeValidPacketHandlerConfig() *balancerpb.PacketHandlerConfig { + return &balancerpb.PacketHandlerConfig{ + SourceAddressV4: []byte{5, 5, 5, 5}, + SourceAddressV6: make([]byte, 16), + SessionsTimeouts: &balancerpb.SessionsTimeouts{ + TcpSynAck: 10, + TcpSyn: 10, + TcpFin: 10, + Tcp: 10, + Udp: 10, + }, + Vs: []*balancerpb.VirtualService{makeValidVS()}, + } +} + +func makeValidStateConfig() *balancerpb.StateConfig { + capacity := uint64(1024) + lf := float32(0.7) + power := uint64(2) + maxWeight := uint32(100) + return &balancerpb.StateConfig{ + SessionTableCapacity: &capacity, + SessionTableMaxLoadFactor: &lf, + Wlc: &balancerpb.WlcConfig{ + Power: &power, + MaxWeight: &maxWeight, + }, + RefreshPeriod: durationpb.New(1000000000), // 1s + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestValidateMask4(t *testing.T) { + tests := []struct { + name string + mask []byte + wantErr bool + }{ + {"slash-16", []byte{0xFF, 0xFF, 0x00, 0x00}, false}, + {"slash-32", []byte{0xFF, 0xFF, 0xFF, 0xFF}, false}, + {"slash-0", []byte{0x00, 0x00, 0x00, 0x00}, false}, + {"slash-24", []byte{0xFF, 0xFF, 0xFF, 0x00}, false}, + {"non-contiguous", []byte{0xFF, 0x00, 0xFF, 0x00}, true}, + {"hole-at-start", []byte{0x00, 0xFF, 0x00, 0x00}, true}, + {"single-bit-gap", []byte{0xFF, 0xFE, 0x01, 0x00}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMask4(tt.mask) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestIsContiguous8(t *testing.T) { + tests := []struct { + name string + mask []byte + want bool + }{ + {"all-ones", []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, true}, + {"all-zeros", []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, true}, + {"half-ones", []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00}, true}, + {"47-bits", []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, 0x00, 0x00}, true}, + {"non-contiguous", []byte{0xFF, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00}, false}, + {"hole-in-low", []byte{0x00, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isContiguous8(tt.mask)) + }) + } +} + +func TestValidateMask6(t *testing.T) { + tests := []struct { + name string + mask []byte + wantErr bool + }{ + { + "slash-64", + []byte{ + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + false, + }, + { + "slash-128", + []byte{ + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + }, + false, + }, + { + "slash-0", + make([]byte, 16), + false, + }, + { + "non-contiguous-high", + []byte{ + 0xFF, + 0x00, + 0xFF, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + true, + }, + { + "non-contiguous-low", + []byte{ + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0x00, + 0xFF, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMask6(tt.mask) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateNet(t *testing.T) { + tests := []struct { + name string + net *filterpb.IPNet + isV6 bool + wantErr bool + }{ + { + "valid-ipv4", + &filterpb.IPNet{Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0xFF, 0xFF, 0x00}}, + false, false, + }, + { + "valid-ipv6", + &filterpb.IPNet{ + Addr: make([]byte, 16), + Mask: []byte{ + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0xFF, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + }, + true, false, + }, + { + "ipv4-wrong-addr-len", + &filterpb.IPNet{Addr: []byte{10, 0, 0}, Mask: []byte{0xFF, 0xFF, 0xFF, 0x00}}, + false, true, + }, + { + "ipv4-wrong-mask-len", + &filterpb.IPNet{Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0xFF}}, + false, true, + }, + { + "ipv6-wrong-addr-len", + &filterpb.IPNet{Addr: make([]byte, 15), Mask: make([]byte, 16)}, + true, true, + }, + { + "ipv6-wrong-addr-len-2", + &filterpb.IPNet{Addr: make([]byte, 4), Mask: make([]byte, 16)}, + true, true, + }, + { + "ipv6-wrong-mask-len", + &filterpb.IPNet{Addr: make([]byte, 16), Mask: make([]byte, 4)}, + true, true, + }, + { + "ipv4-non-contiguous-mask", + &filterpb.IPNet{Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0x00, 0xFF, 0x00}}, + false, true, + }, + { + "ipv6-non-contiguous-mask", + &filterpb.IPNet{ + Addr: make([]byte, 16), + Mask: []byte{ + 0xFF, + 0xFE, + 0xFF, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + }, + true, true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNet(tt.net, tt.isV6) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidatePortRange(t *testing.T) { + tests := []struct { + name string + pr *filterpb.PortRange + wantErr bool + }{ + {"full-range", &filterpb.PortRange{From: 0, To: 65535}, false}, + {"single-port", &filterpb.PortRange{From: 80, To: 80}, false}, + {"normal-range", &filterpb.PortRange{From: 1024, To: 2048}, false}, + {"from-greater-than-to", &filterpb.PortRange{From: 100, To: 50}, true}, + {"to-exceeds-max", &filterpb.PortRange{From: 0, To: 65536}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePortRange(tt.pr) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateSessionsTimeouts(t *testing.T) { + t.Run("all-at-max", func(t *testing.T) { + err := validateSessionsTimeouts(&balancerpb.SessionsTimeouts{ + TcpSynAck: MaxSessionTimeout, + TcpSyn: MaxSessionTimeout, + TcpFin: MaxSessionTimeout, + Tcp: MaxSessionTimeout, + Udp: MaxSessionTimeout, + }) + require.NoError(t, err) + }) + + fields := []struct { + name string + make func() *balancerpb.SessionsTimeouts + }{ + {"tcp_syn_ack", func() *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{TcpSynAck: MaxSessionTimeout + 1} + }}, + {"tcp_syn", func() *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{TcpSyn: MaxSessionTimeout + 1} + }}, + {"tcp_fin", func() *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{TcpFin: MaxSessionTimeout + 1} + }}, + {"tcp", func() *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{Tcp: MaxSessionTimeout + 1} + }}, + {"udp", func() *balancerpb.SessionsTimeouts { + return &balancerpb.SessionsTimeouts{Udp: MaxSessionTimeout + 1} + }}, + } + for _, f := range fields { + t.Run(f.name+"-exceeds-max", func(t *testing.T) { + require.Error(t, validateSessionsTimeouts(f.make())) + }) + } +} + +func TestValidateWlcConfig(t *testing.T) { + power := uint64(2) + maxWeight := uint32(100) + + t.Run("valid", func(t *testing.T) { + require.NoError(t, validateWlcConfig(&balancerpb.WlcConfig{ + Power: &power, MaxWeight: &maxWeight, + })) + }) + t.Run("nil-power", func(t *testing.T) { + require.Error(t, validateWlcConfig(&balancerpb.WlcConfig{ + MaxWeight: &maxWeight, + })) + }) + t.Run("nil-max-weight", func(t *testing.T) { + require.Error(t, validateWlcConfig(&balancerpb.WlcConfig{ + Power: &power, + })) + }) +} + +func TestValidateStateConfig(t *testing.T) { + t.Run("valid", func(t *testing.T) { + require.NoError(t, validateStateConfig(makeValidStateConfig())) + }) + + t.Run("nil-session-table-capacity", func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.SessionTableCapacity = nil + require.Error(t, validateStateConfig(cfg)) + }) + + t.Run("zero-session-table-capacity", func(t *testing.T) { + cfg := makeValidStateConfig() + zero := uint64(0) + cfg.SessionTableCapacity = &zero + require.Error(t, validateStateConfig(cfg)) + }) + + t.Run("nil-refresh-period", func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.RefreshPeriod = nil + require.Error(t, validateStateConfig(cfg)) + }) + + t.Run("nil-session-table-max-load-factor", func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.SessionTableMaxLoadFactor = nil + require.Error(t, validateStateConfig(cfg)) + }) + + loadFactorTests := []struct { + name string + lf float32 + wantErr bool + }{ + {"zero", 0, true}, + {"half", 0.5, false}, + {"one", 1.0, false}, + {"above-one", 1.1, true}, + {"negative", -0.1, true}, + } + for _, tt := range loadFactorTests { + t.Run("load-factor-"+tt.name, func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.SessionTableMaxLoadFactor = &tt.lf + err := validateStateConfig(cfg) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } + + t.Run("nil-wlc", func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.Wlc = nil + require.Error(t, validateStateConfig(cfg)) + }) + + t.Run("invalid-wlc", func(t *testing.T) { + cfg := makeValidStateConfig() + cfg.Wlc = &balancerpb.WlcConfig{} + require.Error(t, validateStateConfig(cfg)) + }) +} + +func TestValidateReal(t *testing.T) { + t.Run("valid-ipv4", func(t *testing.T) { + require.NoError(t, validateReal(makeValidReal())) + }) + + t.Run("valid-ipv6", func(t *testing.T) { + r := &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{ + Ip: make([]byte, 16), + Port: 0, + }, + Weight: 1, + Src: &filterpb.IPNet{ + Addr: make([]byte, 16), + Mask: make([]byte, 16), + }, + } + require.NoError(t, validateReal(r)) + }) + + t.Run("nil-id", func(t *testing.T) { + r := makeValidReal() + r.Id = nil + require.Error(t, validateReal(r)) + }) + + t.Run("wrong-ip-length", func(t *testing.T) { + r := makeValidReal() + r.Id.Ip = []byte{10, 0, 0} + require.Error(t, validateReal(r)) + }) + + t.Run("non-zero-port", func(t *testing.T) { + r := makeValidReal() + r.Id.Port = 8080 + require.Error(t, validateReal(r)) + }) + + t.Run("nil-src", func(t *testing.T) { + r := makeValidReal() + r.Src = nil + require.Error(t, validateReal(r)) + }) + + t.Run("mismatched-src-addr-length", func(t *testing.T) { + r := makeValidReal() + r.Src.Addr = make([]byte, 16) // IPv6 addr but IPv4 id + require.Error(t, validateReal(r)) + }) + + t.Run("mismatched-src-mask-length", func(t *testing.T) { + r := makeValidReal() + r.Src.Mask = make([]byte, 16) // IPv6 mask but IPv4 id + require.Error(t, validateReal(r)) + }) +} + +func TestValidateReals(t *testing.T) { + t.Run("valid", func(t *testing.T) { + r1 := makeValidReal() + r2 := makeValidReal() + r2.Id.Ip = []byte{10, 0, 0, 2} + require.NoError(t, validateReals([]*balancerpb.Real{r1, r2})) + }) + + t.Run("empty", func(t *testing.T) { + require.NoError(t, validateReals([]*balancerpb.Real{})) + }) + + t.Run("nil-entry", func(t *testing.T) { + err := validateReals([]*balancerpb.Real{nil}) + require.Error(t, err) + require.Contains(t, err.Error(), "index 0") + }) + + t.Run("duplicate", func(t *testing.T) { + r1 := makeValidReal() + r2 := makeValidReal() // same IP = duplicate + err := validateReals([]*balancerpb.Real{r1, r2}) + require.Error(t, err) + }) +} + +func TestValidateAllowedSrc(t *testing.T) { + t.Run("valid-with-nets-ports-tag", func(t *testing.T) { + tag := "test" + src := &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0xFF, 0x00, 0x00}}, + }, + Ports: []*filterpb.PortRange{ + {From: 1024, To: 2048}, + }, + Tag: &tag, + } + require.NoError(t, validateAllowedSrc(src, false)) + }) + + t.Run("valid-empty", func(t *testing.T) { + require.NoError(t, validateAllowedSrc(&balancerpb.AllowedSources{}, false)) + }) + + t.Run("invalid-net", func(t *testing.T) { + src := &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: []byte{10, 0}, Mask: []byte{0xFF, 0xFF}}, // wrong length for IPv4 + }, + } + require.Error(t, validateAllowedSrc(src, false)) + }) + + t.Run("invalid-port-range", func(t *testing.T) { + src := &balancerpb.AllowedSources{ + Ports: []*filterpb.PortRange{ + {From: 100, To: 50}, + }, + } + require.Error(t, validateAllowedSrc(src, false)) + }) + + t.Run("tag-too-long", func(t *testing.T) { + longTag := strings.Repeat("a", int(AllowedSourceMaxTagLength)+1) + src := &balancerpb.AllowedSources{ + Tag: &longTag, + } + require.Error(t, validateAllowedSrc(src, false)) + }) + + t.Run("tag-at-max-length", func(t *testing.T) { + maxTag := strings.Repeat("a", int(AllowedSourceMaxTagLength)) + src := &balancerpb.AllowedSources{ + Tag: &maxTag, + } + require.NoError(t, validateAllowedSrc(src, false)) + }) +} + +func TestValidateAllowedSources(t *testing.T) { + t.Run("squashes-duplicates", func(t *testing.T) { + src := &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0xFF, 0x00, 0x00}}, + }, + } + srcDup := &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: []byte{10, 0, 0, 0}, Mask: []byte{0xFF, 0xFF, 0x00, 0x00}}, + }, + } + result, err := validateAllowedSources([]*balancerpb.AllowedSources{src, srcDup}, false) + require.NoError(t, err) + require.Len(t, result, 1) + }) + + t.Run("nil-entry", func(t *testing.T) { + _, err := validateAllowedSources([]*balancerpb.AllowedSources{nil}, false) + require.Error(t, err) + require.Contains(t, err.Error(), "index 0") + }) + + t.Run("empty-ok", func(t *testing.T) { + result, err := validateAllowedSources(nil, false) + require.NoError(t, err) + require.Empty(t, result) + }) +} + +func TestValidateVS(t *testing.T) { + t.Run("valid-ipv4", func(t *testing.T) { + require.NoError(t, validateVS(makeValidVS())) + }) + + t.Run("valid-ipv6", func(t *testing.T) { + vs := makeValidVS() + vs.Id.Addr = make([]byte, 16) + // Fix reals to match IPv6. + vs.Reals[0].Id.Ip = make([]byte, 16) + vs.Reals[0].Src = &filterpb.IPNet{ + Addr: make([]byte, 16), + Mask: make([]byte, 16), + } + require.NoError(t, validateVS(vs)) + }) + + t.Run("nil-id", func(t *testing.T) { + vs := makeValidVS() + vs.Id = nil + require.Error(t, validateVS(vs)) + }) + + t.Run("wrong-addr-length", func(t *testing.T) { + vs := makeValidVS() + vs.Id.Addr = []byte{1, 1, 1} + require.Error(t, validateVS(vs)) + }) + + t.Run("invalid-proto", func(t *testing.T) { + vs := makeValidVS() + vs.Id.Proto = balancerpb.TransportProto(99) + require.Error(t, validateVS(vs)) + }) + + t.Run("nil-flags", func(t *testing.T) { + vs := makeValidVS() + vs.Flags = nil + require.Error(t, validateVS(vs)) + }) + + t.Run("pure-l3-with-port", func(t *testing.T) { + vs := makeValidVS() + vs.Flags.PureL3 = true + vs.Id.Port = 80 + require.Error(t, validateVS(vs)) + }) + + t.Run("pure-l3-port-zero", func(t *testing.T) { + vs := makeValidVS() + vs.Flags.PureL3 = true + vs.Id.Port = 0 + require.NoError(t, validateVS(vs)) + }) + + t.Run("invalid-scheduler", func(t *testing.T) { + vs := makeValidVS() + vs.Scheduler = balancerpb.VsScheduler(99) + require.Error(t, validateVS(vs)) + }) + + t.Run("round-robin-scheduler", func(t *testing.T) { + vs := makeValidVS() + vs.Scheduler = balancerpb.VsScheduler_WRR + require.NoError(t, validateVS(vs)) + }) + + t.Run("bad-peer-length", func(t *testing.T) { + vs := makeValidVS() + vs.Peers = [][]byte{{1, 2, 3}} // not 4 or 16 + require.Error(t, validateVS(vs)) + }) + + t.Run("valid-peers", func(t *testing.T) { + vs := makeValidVS() + vs.Peers = [][]byte{{1, 1, 1, 1}, make([]byte, 16)} + require.NoError(t, validateVS(vs)) + }) +} + +func TestValidatePacketHandlerConfig(t *testing.T) { + t.Run("valid", func(t *testing.T) { + require.NoError(t, validatePacketHandlerConfig(makeValidPacketHandlerConfig())) + }) + + t.Run("wrong-source-v4-length", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.SourceAddressV4 = []byte{1, 2, 3} + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("wrong-source-v6-length", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.SourceAddressV6 = []byte{1, 2, 3} + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("nil-sessions-timeouts", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.SessionsTimeouts = nil + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("invalid-sessions-timeouts", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.SessionsTimeouts.TcpSynAck = MaxSessionTimeout + 1 + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("invalid-decap-address-length", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.DecapAddresses = [][]byte{{1, 2, 3}} // not 4 or 16 + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("decap-addresses-sorted", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + v6 := make([]byte, 16) + v6[0] = 0xFE + v4 := []byte{10, 0, 0, 1} + cfg.DecapAddresses = [][]byte{v6, v4} // v6 first + require.NoError(t, validatePacketHandlerConfig(cfg)) + // After validation, v4 should sort before v6. + require.Len(t, cfg.DecapAddresses[0], 4) + require.Len(t, cfg.DecapAddresses[1], 16) + }) + + t.Run("duplicate-vs", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.Vs = append(cfg.Vs, makeValidVS()) // same VS id + require.Error(t, validatePacketHandlerConfig(cfg)) + }) + + t.Run("nil-vs-entry", func(t *testing.T) { + cfg := makeValidPacketHandlerConfig() + cfg.Vs = []*balancerpb.VirtualService{nil} + require.Error(t, validatePacketHandlerConfig(cfg)) + }) +} + +func TestValidateBalancerConfig(t *testing.T) { + t.Run("valid", func(t *testing.T) { + cfg := &balancerpb.BalancerConfig{ + PacketHandler: makeValidPacketHandlerConfig(), + State: makeValidStateConfig(), + } + require.NoError(t, validateBalancerConfig(cfg)) + }) + + t.Run("nil-config", func(t *testing.T) { + require.Error(t, validateBalancerConfig(nil)) + }) + + t.Run("nil-packet-handler", func(t *testing.T) { + cfg := &balancerpb.BalancerConfig{ + State: makeValidStateConfig(), + } + require.Error(t, validateBalancerConfig(cfg)) + }) + + t.Run("nil-state", func(t *testing.T) { + cfg := &balancerpb.BalancerConfig{ + PacketHandler: makeValidPacketHandlerConfig(), + } + require.Error(t, validateBalancerConfig(cfg)) + }) + + t.Run("invalid-packet-handler", func(t *testing.T) { + cfg := &balancerpb.BalancerConfig{ + PacketHandler: &balancerpb.PacketHandlerConfig{}, + State: makeValidStateConfig(), + } + require.Error(t, validateBalancerConfig(cfg)) + }) + + t.Run("invalid-state", func(t *testing.T) { + cfg := &balancerpb.BalancerConfig{ + PacketHandler: makeValidPacketHandlerConfig(), + State: &balancerpb.StateConfig{}, + } + require.Error(t, validateBalancerConfig(cfg)) + }) +} diff --git a/modules/balancer/controlplane/vs.go b/modules/balancer/controlplane/vs.go new file mode 100644 index 000000000..31332d78e --- /dev/null +++ b/modules/balancer/controlplane/vs.go @@ -0,0 +1,661 @@ +package balancer + +import ( + "bytes" + "fmt" + "net" + "strconv" + "time" + + "github.com/yanet-platform/yanet2/common/commonpb" + "github.com/yanet-platform/yanet2/common/go/relptr" + yanet "github.com/yanet-platform/yanet2/controlplane/ffi" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// vsKey is a hashable identifier for a virtual service, usable as a map key. +type vsKey struct { + addr [16]byte + addrLen uint8 + port uint16 + proto uint8 +} + +// vsSlot tracks the position of a virtual service and its reals in the +// shared-memory arrays, enabling O(1) lookups by identity. +type vsSlot struct { + // Position of this VS in the packet handler's VS array. + index int + // Maps real server identity to its position in this VS's reals array. + realSlots map[realKey]int +} + +func makeVsKey(id *balancerpb.VsIdentifier) vsKey { + var k vsKey + k.addrLen = uint8(len(id.Addr)) + copy(k.addr[:], id.Addr) + k.port = uint16(id.Port) + k.proto = transportProtoToC(id.Proto) + return k +} + +func (vs *VS) key() vsKey { + var k vsKey + k.addrLen = 4 + if vs.Ip_proto == ipprotoIPv6 { + k.addrLen = 16 + } + copy(k.addr[:], vs.Addr.Bytes(int(vs.Ip_proto))) + k.port = uint16(vs.Port) + k.proto = uint8(vs.Transport_proto) + return k +} + +func (vs *VS) free(agent *Agent) { + yanetAgent := agent.AsYanetAgent() + + vs.freeACL(agent) + vs.freeRealSelector(agent) + vs.freeSessionTrackers(agent) + + // Rule counter IDs. + ruleCounterIDs := relptr.Slice(&vs.Rule_counter_ids, vs.Allowed_sources_count) + yanet.FreeSlice(yanetAgent, ruleCounterIDs) + + // Allowed sources.t + allowedSources := relptr.Slice(&vs.Allowed_sources, vs.Allowed_sources_count) + freeAllowedSources(agent, allowedSources) + yanet.FreeSlice(yanetAgent, allowedSources) + + // Peers. + peersV4 := relptr.Slice(&vs.Peers_v4, vs.Peers_v4_count) + yanet.FreeSlice(yanetAgent, peersV4) + peersV6 := relptr.Slice(&vs.Peers_v6, vs.Peers_v6_count) + yanet.FreeSlice(yanetAgent, peersV6) + + // Reals. + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + yanet.FreeSlice(yanetAgent, reals) +} + +// populateAllowedSources allocates and fills the allowed sources array for a VS. +func (vs *VS) populateAllowedSources( + agent *Agent, + pbSources []*balancerpb.AllowedSources, +) error { + count := len(pbSources) + if count == 0 { + return nil + } + + srcs := yanet.AllocSlice[AllowedSource](agent.AsYanetAgent(), count) + if srcs == nil { + return errNoAgentMemory + } + + // Explicit zero-init for safe freeing. + for idx := range srcs { + srcs[idx] = AllowedSource{} + } + + for i, pb := range pbSources { + if err := srcs[i].populate(agent, pb); err != nil { + freeAllowedSources(agent, srcs) + yanet.FreeSlice(agent.AsYanetAgent(), srcs) + return err + } + } + + relptr.SetSlice(&vs.Allowed_sources, srcs) + vs.Allowed_sources_count = uint32(count) + + return nil +} + +func freeAllowedSources(agent *Agent, allowedSources []AllowedSource) { + for i := range allowedSources { + allowedSources[i].free(agent) + } +} + +func (as *AllowedSource) free(agent *Agent) { + yanetAgent := agent.AsYanetAgent() + + // Nets. + nets := relptr.Slice(&as.Nets, as.Nets_count) + yanet.FreeSlice(yanetAgent, nets) + + // Port ranges. + portRanges := relptr.Slice(&as.Port_ranges, as.Port_ranges_count) + yanet.FreeSlice(yanetAgent, portRanges) +} + +func (as *AllowedSource) populate(agent *Agent, pbSrc *balancerpb.AllowedSources) error { + // Nets. + nets := pbSrc.Nets + if len(nets) > 0 { + netSlice := yanet.AllocSlice[Net](agent.AsYanetAgent(), len(nets)) + if netSlice == nil { + return errNoAgentMemory + } + relptr.SetSlice(&as.Nets, netSlice) + for j, n := range nets { + writeNet(&netSlice[j], n) + } + as.Nets_count = uint32(len(nets)) + } + + // Port ranges. + ports := pbSrc.Ports + if len(ports) > 0 { + prSlice := yanet.AllocSlice[PortRange](agent.AsYanetAgent(), len(ports)) + if prSlice == nil { + return errNoAgentMemory + } + relptr.SetSlice(&as.Port_ranges, prSlice) + for j, p := range ports { + prSlice[j].From = uint16(p.From) + prSlice[j].To = uint16(p.To) + } + as.Port_ranges_count = uint32(len(ports)) + } + + // Tag (null-terminated C string in fixed-size buffer). + if pbSrc.Tag != nil { + tag := *pbSrc.Tag + for i := range tag { + as.Tag[i] = int8(tag[i]) + } + as.Tag[len(tag)] = 0 + } else { + as.Tag[0] = 0 + } + + return nil +} + +// populatePeers allocates and fills the peer address arrays for a VS. +func (vs *VS) populatePeers(agent *Agent, peers [][]byte) error { + v4Count := 0 + v6Count := 0 + for _, addr := range peers { + if len(addr) == 4 { + v4Count++ + } else { + v6Count++ + } + } + + // First, allocate slice. Then, set slice and count in the shared memory atomically. + if v4Count > 0 { + slice := yanet.AllocSlice[Net4Addr](agent.AsYanetAgent(), v4Count) + if slice == nil { + return errNoAgentMemory + } + j := 0 + for _, addr := range peers { + if len(addr) == 4 { + writeNet4Addr(&slice[j], addr) + j++ + } + } + relptr.SetSlice(&vs.Peers_v4, slice) + vs.Peers_v4_count = uint32(v4Count) + } + + if v6Count > 0 { + slice := yanet.AllocSlice[Net6Addr](agent.AsYanetAgent(), v6Count) + if slice == nil { + return errNoAgentMemory + } + j := 0 + for _, addr := range peers { + if len(addr) == 16 { + writeNet6Addr(&slice[j], addr) + j++ + } + } + relptr.SetSlice(&vs.Peers_v6, slice) + vs.Peers_v6_count = uint32(v6Count) + } + + return nil +} + +// placeExistingVS places virtual services that exist in both the previous and new configs +// into their original slot positions in targetVs. Each placed VS is deleted from vsMap, +// so after this call vsMap contains only genuinely new VSes for placeNewVS to handle. +func placeExistingVS( + ph *PacketHandler, + agent *Agent, + pbVS []*balancerpb.VirtualService, + targetVs []VS, + prevVs []VS, + vsMap map[vsKey]int, + reuseReport *balancerpb.ReuseReport, +) (oldIPv4VsMatches bool, oldIPv6VsMatches bool, err error) { + oldIPv4VsMatches = true + oldIPv6VsMatches = true + + for idx := range prevVs { + prev := &prevVs[idx] + if prev.isRemoved() { + continue + } + + k := prev.key() + pbIdx, ok := vsMap[k] + if !ok { + // This VS not present in new config + if k.addrLen == 4 { + oldIPv4VsMatches = false + } else { + oldIPv6VsMatches = false + } + continue + } + + delete(vsMap, k) + + target := &targetVs[idx] + target.Flags &^= uint16(VSFlagRemoved) + report, err := target.populate(agent, pbVS[pbIdx], prev.Stable_idx, prev, ph) + if err != nil { + return false, false, fmt.Errorf("vs %s: %w", prev, err) + } + + reuseReport.VsReuseReports = append(reuseReport.VsReuseReports, report) + } + + return oldIPv4VsMatches, oldIPv6VsMatches, nil +} + +// placeNewVS places virtual services that are new in the config (remaining in vsMap +// after placeExistingVS) into removed (empty) slots in targetVs. +// Invariant: there are always enough removed slots because targetVs has len(vsList) slots, +// placeExistingVS and placeNewVS together account for exactly len(vsList) entries, +// and each entry fills exactly one slot. +func placeNewVS( + ph *PacketHandler, + agent *Agent, + vsList []*balancerpb.VirtualService, + targetVs []VS, + prevVs []VS, + vsMap map[vsKey]int, + reuseReport *balancerpb.ReuseReport, +) (noNewIPv4Vs bool, noNewIPv6Vs bool, err error) { + noNewIPv4Vs = true + noNewIPv6Vs = true + + nextRemoved := 0 + for idx, vs := range vsList { + k := makeVsKey(vs.Id) + if _, ok := vsMap[k]; !ok { + continue + } + + if k.addrLen == 4 { + noNewIPv4Vs = false + } else { + noNewIPv6Vs = false + } + + for !targetVs[nextRemoved].isRemoved() { + nextRemoved++ + } + + epoch := uint32(0) + if nextRemoved < len(prevVs) { + epoch = prevVs[nextRemoved].epoch() + 1 + } + stableIdx := makeStableIdx(epoch, uint32(nextRemoved)) + + target := &targetVs[nextRemoved] + target.Flags &^= uint16(VSFlagRemoved) + report, err := target.populate(agent, vsList[idx], stableIdx, nil, ph) + if err != nil { + return false, false, fmt.Errorf("vs %s: %w", vsIDToString(vs.Id), err) + } + + nextRemoved++ + + reuseReport.VsReuseReports = append(reuseReport.VsReuseReports, report) + } + + return noNewIPv4Vs, noNewIPv6Vs, nil +} + +// protoVsFlagsToC converts protobuf VsFlags to the C bit field value. +func protoVsFlagsToC(f *balancerpb.VsFlags, s balancerpb.VsScheduler) uint16 { + var flags uint16 + if f.PureL3 { + flags |= VSFlagPureL3 + } + if f.FixMss { + flags |= VSFlagFixMSS + } + if f.Gre { + flags |= VSFlagGRE + } + if f.Ops { + flags |= VSFlagOPS + } + + switch s { + case balancerpb.VsScheduler_WLC: + flags |= VSFlagWLC + flags |= VSFlagRoundRobin + case balancerpb.VsScheduler_WRR: + flags |= VSFlagRoundRobin + } + + return flags +} + +func allowedSourcesEqual( + prevAllowedSrc *AllowedSource, + curAllowedSrc *balancerpb.AllowedSources, + ipproto int, +) bool { + prevNets := relptr.Slice(&prevAllowedSrc.Nets, prevAllowedSrc.Nets_count) + curNets := curAllowedSrc.Nets + if len(prevNets) != len(curNets) { + return false + } + + for i := range prevAllowedSrc.Nets_count { + prevNet := &prevNets[i] + curNet := curNets[i] + if !bytes.Equal(prevNet.AddrBytes(ipproto), curNet.Addr) { + return false + } + if !bytes.Equal(prevNet.MaskBytes(ipproto), curNet.Mask) { + return false + } + } + + prevPr := relptr.Slice(&prevAllowedSrc.Port_ranges, prevAllowedSrc.Port_ranges_count) + curPr := curAllowedSrc.Ports + if len(prevPr) != len(curPr) { + return false + } + for i := range prevAllowedSrc.Port_ranges_count { + prevPortRange := &prevPr[i] + curPortRange := curPr[i] + if prevPortRange.From != uint16(curPortRange.From) || + prevPortRange.To != uint16(curPortRange.To) { + return false + } + } + return true +} + +// canReuseACL checks whether the previous VS's compiled ACL can be reused. +// Precondition: allowed sources in both prev and new config must be in the same sort order. +// This is guaranteed by validatePacketHandlerConfig -> validateVS -> validateAllowedSources. +func canReuseACL(prevVs *VS, pbVs *balancerpb.VirtualService) bool { + if prevVs == nil { + return false + } + prevAllowedSrcs := relptr.Slice(&prevVs.Allowed_sources, prevVs.Allowed_sources_count) + curAllowedSrcs := pbVs.AllowedSrcs + if len(prevAllowedSrcs) != len(curAllowedSrcs) { + return false + } + for i := range prevVs.Allowed_sources_count { + prevAllowedSrc := &prevAllowedSrcs[i] + curAllowedSrc := pbVs.AllowedSrcs[i] + if !allowedSourcesEqual(prevAllowedSrc, curAllowedSrc, int(prevVs.Ip_proto)) { + return false + } + } + return true +} + +func (vs *VS) isWLC() bool { + return vs.Flags&VSFlagWLC != 0 +} + +func (vs *VS) populateReals( + agent *Agent, + pbReals []*balancerpb.Real, + prevVs *VS, +) (reuseSelector bool, err error) { + wlcChanged := prevVs != nil && vs.isWLC() != prevVs.isWLC() + wrrChanged := prevVs != nil && vs.isWRR() != prevVs.isWRR() + + var prevReals []Real + if prevVs != nil { + prevReals = relptr.Slice(&prevVs.Reals, prevVs.Reals_count) + } + + realSlotCount := max(len(pbReals), len(prevReals)) + newReals := yanet.AllocSlice[Real](agent.AsYanetAgent(), realSlotCount) + if newReals == nil { + return false, errNoAgentMemory + } + for idx := range newReals { + stableIdx := uint64(0) + if idx < len(prevReals) { + stableIdx = prevReals[idx].Stable_idx + } + newReals[idx] = Real{ + Flags: RealFlagRemoved, + Stable_idx: stableIdx, + } + } + + pbRealIndex := make(map[realKey]int, len(pbReals)) + for idx := range pbReals { + k := makeRealKey(pbReals[idx].Id) + pbRealIndex[k] = idx + } + + inheritEffectiveWeights := !wlcChanged + inheritWRR := !wrrChanged + + prevRealsUnchanged := placeExistingReals( + pbReals, + newReals, + prevReals, + pbRealIndex, + inheritEffectiveWeights, + ) + noNewReals := placeNewReals( + pbReals, + newReals, + prevReals, + pbRealIndex, + ) + + vs.Reals_count = uint32(len(newReals)) + relptr.SetSlice(&vs.Reals, newReals) + + // The real selector can be reused only when all four conditions hold: + // 1. prevRealsUnchanged: old reals were not changed. + // 2. newRealsUnchanged: no new reals were placed. + // 3. inheritEffectiveWeights: WLC was not just changed, so effective weights + // were inherited from the previous config. + // 4. inheritWRR: WRR flag was not just changed, so selector logic inherited. + return prevRealsUnchanged && noNewReals && inheritEffectiveWeights && inheritWRR, nil +} + +func (vs *VS) populate( + agent *Agent, + pb *balancerpb.VirtualService, + stableIdx uint64, + prevVs *VS, + handler *PacketHandler, +) (*balancerpb.VsReuseReport, error) { + vs.Stable_idx = stableIdx + + // Set identifier fields. The caller must have cleared VSFlagRemoved before calling; + // vs.Flags is expected to have no other flags set at this point. + writeNetAddr(&vs.Addr, pb.Id.Addr) + vs.Port = uint16(pb.Id.Port) + vs.Transport_proto = transportProtoToC(pb.Id.Proto) + vs.Ip_proto = ipprotoIP + if len(pb.Id.Addr) == 16 { + vs.Ip_proto = ipprotoIPv6 + } + vs.Flags |= protoVsFlagsToC(pb.Flags, pb.Scheduler) + if pb.Scheduler == balancerpb.VsScheduler_WRR { + vs.Flags |= VSFlagRoundRobin + } + + if err := vs.populateAllowedSources(agent, pb.AllowedSrcs); err != nil { + return nil, err + } + + if err := vs.populatePeers(agent, pb.Peers); err != nil { + return nil, err + } + + reuseSelector, err := vs.populateReals(agent, pb.Reals, prevVs) + if err != nil { + return nil, err + } + + if err := vs.setSessionsTrackers(agent); err != nil { + return nil, err + } + + if reuseSelector { + relptr.Equate(&vs.Selector, &prevVs.Selector) + } else if err := vs.updateRealSelector(&handler.Rcu, agent); err != nil { + return nil, err + } + + reuseACL := canReuseACL(prevVs, pb) + if reuseACL { + relptr.Equate(&vs.Acl, &prevVs.Acl) + } else if err := vs.setACL(agent); err != nil { + return nil, err + } + + return &balancerpb.VsReuseReport{ + VsIdentifier: pb.Id, + AclReused: reuseACL, + SelectorReused: reuseSelector, + }, nil +} + +func (vs *VS) epoch() uint32 { + return epochOf(vs.Stable_idx) +} + +func (vs *VS) isRemoved() bool { + return vs.Flags&VSFlagRemoved != 0 +} + +func (vs *VS) id() *balancerpb.VsIdentifier { + return &balancerpb.VsIdentifier{ + Addr: vs.Addr.Bytes(int(vs.Ip_proto)), + Port: uint32(vs.Port), + Proto: transportProtoToPB(vs.Transport_proto), + } +} + +func transportProtoToPB(proto uint8) balancerpb.TransportProto { + protoPB := balancerpb.TransportProto_TCP + if proto == ipprotoUDP { + protoPB = balancerpb.TransportProto_UDP + } + return protoPB +} + +func (vs *VS) flags() *balancerpb.VsFlags { + return &balancerpb.VsFlags{ + PureL3: vs.Flags&VSFlagPureL3 != 0, + FixMss: vs.Flags&VSFlagFixMSS != 0, + Gre: vs.Flags&VSFlagGRE != 0, + Ops: vs.Flags&VSFlagOPS != 0, + } +} + +func (vs *VS) isWRR() bool { + return vs.Flags&VSFlagRoundRobin != 0 +} + +func (vs *VS) scheduler() balancerpb.VsScheduler { + if vs.isWLC() { + return balancerpb.VsScheduler_WLC + } + if vs.isWRR() { + return balancerpb.VsScheduler_WRR + } + return balancerpb.VsScheduler_SH +} + +func (vs *VS) state(workers uint32, now time.Time) *balancerpb.VsState { + reals := relptr.Slice(&vs.Reals, vs.Reals_count) + activeSessions := uint64(0) + lastPacketTimestamp := time.Unix(0, 0) + realsState := make([]*balancerpb.RealState, len(reals)) + for realIdx := range reals { + if reals[realIdx].isRemoved() { + continue + } + r := reals[realIdx].state(workers, now) + if r.LastPacketTimestamp.AsTime().After(lastPacketTimestamp) { + lastPacketTimestamp = r.LastPacketTimestamp.AsTime() + } + activeSessions += r.ActiveSessions + realsState[realIdx] = r + } + isV6 := vs.Ip_proto == ipprotoIPv6 + vsState := &balancerpb.VsState{ + Id: vs.id(), + Flags: vs.flags(), + Scheduler: vs.scheduler(), + Reals: realsState, + ActiveSessions: activeSessions, + LastPacketTimestamp: timestamppb.New(lastPacketTimestamp), + AllowedSrcsConfig: restoreAllowedSources(vs, isV6), + Peers: restorePeers(vs), + } + return vsState +} + +func formatVS(proto balancerpb.TransportProto, addr []byte, port uint32) string { + protoStr := "TCP" + if proto == balancerpb.TransportProto_UDP { + protoStr = "UDP" + } + addrStr := net.IP(addr).String() + if len(addr) == 16 { + addrStr = fmt.Sprintf("[%s]", addrStr) + } + return fmt.Sprintf("%s:%d/%s", addrStr, port, protoStr) +} + +func vsIDToString(id *balancerpb.VsIdentifier) string { + return formatVS(id.Proto, id.Addr, id.Port) +} + +func (vs *VS) String() string { + return formatVS( + transportProtoToPB(vs.Transport_proto), + vs.Addr.Bytes(int(vs.Ip_proto)), + uint32(vs.Port), + ) +} + +func (vs *VS) labels() []*commonpb.Label { + labels := make([]*commonpb.Label, 0, 3) + + vip := vs.Addr.Bytes(int(vs.Ip_proto)) + labels = append(labels, &commonpb.Label{Name: "vip", Value: net.IP(vip).String()}) + + port := vs.Port + labels = append(labels, &commonpb.Label{Name: "vs_port", Value: strconv.Itoa(int(port))}) + + proto := "UDP" + if vs.Transport_proto == ipprotoTCP { + proto = "TCP" + } + labels = append(labels, &commonpb.Label{Name: "proto", Value: proto}) + + return labels +} diff --git a/modules/balancer/dataplane/active_sessions.h b/modules/balancer/dataplane/active_sessions.h deleted file mode 100644 index 51a3f46cc..000000000 --- a/modules/balancer/dataplane/active_sessions.h +++ /dev/null @@ -1,83 +0,0 @@ -#pragma once - -#include "interval_counter.h" -#include - -/* - * The underlying rt_interval_counter ring has size 8, so the tick - * distance (until_tick - now_tick) must be < 8. With precision=16 - * this means the session timeout must satisfy: - * (ts + timeout + 15)/16 - ts/16 < 8 - * For ts=0: (timeout + 15)/16 < 8 => timeout < 113. - * Safe timeouts: 16, 32, 48, 64, 80, 96, 112. - */ - -#define ACTIVE_SESSIONS_TRACKER_MAX_TIMEOUT 100 -#define ACTIVE_SESSIONS_TRACKER_PRECISION 16 - -/* - * Per-worker active-session tracker. - * - * Session lifetimes are rounded to `ACTIVE_SESSIONS_TRACKER_PRECISION` - * ticks and accumulated through [`struct - * rt_interval_counter`](modules/balancer/dataplane/interval_counter.h:19). - */ -struct active_sessions_tracker_shard { - struct rt_interval_counter counter; - uint32_t count; - uint32_t last_packet_timestamp; -} __attribute__((aligned(64))); - -/* Convert a packet timestamp to the current tracker tick. */ -static inline uint32_t -active_sessions_tracker_now(uint32_t timestamp) { - return timestamp / ACTIVE_SESSIONS_TRACKER_PRECISION; -} - -/* Round a packet timestamp up to the tick where the session expires. */ -static inline uint32_t -active_sessions_tracker_until(uint32_t timestamp) { - return (timestamp + ACTIVE_SESSIONS_TRACKER_PRECISION - 1) / - ACTIVE_SESSIONS_TRACKER_PRECISION; -} - -/* Account for a newly created session on the selected worker shard. */ -static inline void -active_sessions_tracker_new_session( - struct active_sessions_tracker_shard *tracker_shards, - uint32_t worker_idx, - uint32_t now, - uint32_t timeout -) { - struct active_sessions_tracker_shard *shard = - &tracker_shards[worker_idx]; - shard->count += rt_interval_counter_make( - &shard->counter, - active_sessions_tracker_now(now), - active_sessions_tracker_until(now + timeout) - ); - shard->last_packet_timestamp = now; -} - -/* Extend an existing session and move its scheduled expiration. */ -static inline void -active_sessions_tracker_prolong_session( - struct active_sessions_tracker_shard *tracker_shards, - uint32_t worker_idx, - uint32_t last_packet_timestamp, - uint32_t prev_timeout, - uint32_t now, - uint32_t new_timeout -) { - struct active_sessions_tracker_shard *shard = - &tracker_shards[worker_idx]; - shard->count += rt_interval_counter_prolong( - &shard->counter, - active_sessions_tracker_now(now), - active_sessions_tracker_until( - last_packet_timestamp + prev_timeout - ), - active_sessions_tracker_until(now + new_timeout) - ); - shard->last_packet_timestamp = now; -} diff --git a/modules/balancer/dataplane/checksum.h b/modules/balancer/dataplane/checksum.h deleted file mode 100644 index d4a03c756..000000000 --- a/modules/balancer/dataplane/checksum.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -//////////////////////////////////////////////////////////////////////////////// - -static inline uint16_t -csum_plus(uint16_t val0, uint16_t val1) { - uint16_t sum = val0 + val1; - - if (sum < val0) { - ++sum; - } - - return sum; -} - -static inline uint16_t -csum_minus(uint16_t val0, uint16_t val1) { - uint16_t sum = val0 - val1; - - if (sum > val0) { - --sum; - } - - return sum; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/context.h b/modules/balancer/dataplane/context.h new file mode 100644 index 000000000..a804684b6 --- /dev/null +++ b/modules/balancer/dataplane/context.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +struct packet_front; +struct balancer_packet_handler; +struct counter_storage; + +struct worker_context { + struct packet_front *packet_front; + struct balancer_packet_handler *packet_handler; + struct counter_storage *counter_storage; + uint32_t worker_idx; + + struct balancer_common_stats *common_stats; + struct balancer_l4_stats *l4_stats; + struct balancer_icmp_stats *icmp_v4_stats; + struct balancer_icmp_stats *icmp_v6_stats; + + /* Current time in seconds. */ + uint32_t now; +}; \ No newline at end of file diff --git a/modules/balancer/dataplane/dataplane.c b/modules/balancer/dataplane/dataplane.c index 49efa6767..47e4fee97 100644 --- a/modules/balancer/dataplane/dataplane.c +++ b/modules/balancer/dataplane/dataplane.c @@ -1,114 +1,148 @@ #include #include -#include -#include -#include -#include - -#include "flow/setup.h" -#include "flow/stats.h" +#include +#include #include "common/memory_address.h" +#include "lib/counters/counters.h" #include "lib/dataplane/config/zone.h" +#include "lib/dataplane/module/packet_front.h" +#include "lib/dataplane/pipeline/econtext.h" + +#include "types/stats.h" +#include "context.h" #include "dataplane.h" -#include "decap.h" -#include "handler/handler.h" #include "icmp/handle.h" #include "l4/handle.h" -#include "worker.h" +#define MAX_BATCH_SIZE 64 -//////////////////////////////////////////////////////////////////////////////// +typedef void (*batch_handler)( + struct worker_context *context, + struct packet **packets, + size_t packets_count +); -struct balancer_module { - struct module module; +struct packet_batcher { + struct packet *packets[MAX_BATCH_SIZE]; + size_t count; + batch_handler handler; }; -static inline int -packet_ctx_try_decap(struct packet_ctx *ctx) { - return try_decap(ctx); -} - -void -handle_batch(size_t packets_count) { - assert(packets_count <= batch_size); - - if (unlikely(packets_count == 0)) { - return; +static void +batcher_add( + struct packet_batcher *batcher, + struct worker_context *context, + struct packet *packet +) { + batcher->packets[batcher->count++] = packet; + if (batcher->count == MAX_BATCH_SIZE) { + batcher->handler(context, batcher->packets, batcher->count); + batcher->count = 0; } +} - // first, handle icmp packets - for (size_t i = 0; i < packets_count; ++i) { - struct packet_ctx *ctx = &packet_ctxs[i]; - uint16_t packet_type = ctx->packet->transport_header.type; - if (!ctx->processed && (packet_type == IPPROTO_ICMP || - packet_type == IPPROTO_ICMPV6)) { - handle_icmp_packet(ctx); - } +static void +batcher_flush(struct packet_batcher *batcher, struct worker_context *context) { + if (batcher->count > 0) { + batcher->handler(context, batcher->packets, batcher->count); + batcher->count = 0; } +} - // handle TCP and UDP packets - handle_l4_packets(packet_ctxs, packets_count); +static uint64_t * +context_get_counter(struct worker_context *context, uint64_t counter_id) { + return counter_get_address( + counter_id, context->worker_idx, context->counter_storage + ); } -void -balancer_handle_packets( +static void +build_context( + struct worker_context *ctx, struct dp_worker *dp_worker, struct module_ectx *module_ectx, struct packet_front *packet_front ) { - // Get balancer module config as container of provided cp_module. - struct packet_handler *handler = container_of( + struct balancer_packet_handler *packet_handler = container_of( ADDR_OF(&module_ectx->cp_module), - struct packet_handler, + struct balancer_packet_handler, cp_module ); + ctx->packet_handler = packet_handler; + ctx->packet_front = packet_front; + ctx->counter_storage = ADDR_OF(&module_ectx->counter_storage); + ctx->worker_idx = dp_worker->idx; + ctx->now = dp_worker->current_time / (1000 * 1000 * 1000); /* ns -> s */ + + ctx->common_stats = (struct balancer_common_stats *)context_get_counter( + ctx, packet_handler->common_counter_id + ); + ctx->icmp_v4_stats = (struct balancer_icmp_stats *)context_get_counter( + ctx, packet_handler->icmp_v4_counter_id + ); + ctx->icmp_v6_stats = (struct balancer_icmp_stats *)context_get_counter( + ctx, packet_handler->icmp_v6_counter_id + ); + ctx->l4_stats = (struct balancer_l4_stats *)context_get_counter( + ctx, packet_handler->l4_counter_id + ); +} - // Get current time in seconds. - uint32_t now = dp_worker->current_time / (1000 * 1000 * 1000); +void +balancer_handle_packets( + struct dp_worker *dp_worker, + struct module_ectx *module_ectx, + struct packet_front *packet_front +) { + struct worker_context context; + build_context(&context, dp_worker, module_ectx, packet_front); + + /* + * Classify incoming packets by protocol (L4/ICMP) and IP version + * (IPv4/IPv6), accumulating them into per-category batches. + * + * Batched processing allows the filter engine to evaluate multiple + * packets at once, which is significantly faster than one-at-a-time + * lookups due to memory prefetching and reduced per-packet overhead. + * + * Batches are flushed when they reach MAX_BATCH_SIZE or when all + * input packets have been classified. + */ + enum { l4_ipv4, l4_ipv6, icmp_ipv4, icmp_ipv6, batcher_count }; + struct packet_batcher batchers[batcher_count] = { + [l4_ipv4] = {.handler = balancer_handle_l4_ipv4}, + [l4_ipv6] = {.handler = balancer_handle_l4_ipv6}, + [icmp_ipv4] = {.handler = balancer_handle_icmp_ipv4}, + [icmp_ipv6] = {.handler = balancer_handle_icmp_ipv6}, + }; + + context.common_stats->incoming_packets += packet_front->input.count; - // setup packet ctxs and try to decap packets - // handle packets by batches - size_t packets_count = 0; struct packet *packet; while ((packet = packet_list_pop(&packet_front->input)) != NULL) { - struct packet_ctx *ctx = &packet_ctxs[packets_count++]; - packet_ctx_setup( - ctx, now, dp_worker, module_ectx, handler, packet_front - ); - - // Set incoming packet - packet_ctx_set_packet(ctx, packet); - - // Update module common stats - packet_ctx_update_common_stats_on_incoming_packet(ctx); - - // Try decap packet if its destination - // is from the balancer decap list. - // - // If packet dst is from the destination list - // and decap failed, drop packet. - if (packet_ctx_try_decap(ctx) != 0) { - packet_ctx_drop_packet(ctx); - continue; - } - - // batch is full - if (packets_count == batch_size) { - // handle batch of packets - handle_batch(packets_count); - packets_count = 0; - } + context.common_stats->incoming_bytes += packet->mbuf->pkt_len; + + int is_ipv6 = packet->network_header.type == + rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6); + int is_icmp = packet->transport_header.type == IPPROTO_ICMP || + packet->transport_header.type == IPPROTO_ICMPV6; + int idx = is_icmp * 2 + is_ipv6; + batcher_add(&batchers[idx], &context, packet); } - // if there are some unhandled packets - // in the last batch - handle_batch(packets_count); + for (int i = 0; i < batcher_count; ++i) { + batcher_flush(&batchers[i], &context); + } } +struct balancer_module { + struct module module; +}; + struct module * new_module_balancer() { struct balancer_module *module = diff --git a/modules/balancer/dataplane/dataplane.h b/modules/balancer/dataplane/dataplane.h index 180e03bf1..fb85ad342 100644 --- a/modules/balancer/dataplane/dataplane.h +++ b/modules/balancer/dataplane/dataplane.h @@ -1,6 +1,70 @@ #pragma once -#include "dataplane/module/module.h" +#include + +#include + +#include "lib/controlplane/config/cp_module.h" +#include "lib/dataplane/module/module.h" + +#include "common/network.h" + +#include "types/session.h" + +struct balancer_session_table; +struct balancer_vs; struct module * -new_module_balancer(); \ No newline at end of file +new_module_balancer(); + +struct balancer_packet_handler { + struct cp_module cp_module; + + uint64_t common_counter_id; + uint64_t icmp_v4_counter_id; + uint64_t icmp_v6_counter_id; + uint64_t l4_counter_id; + + struct filter *decap_ipv4_filter; + struct filter *decap_ipv6_filter; + + struct balancer_session_table *session_table; + + struct filter *ipv4_vs_matcher; + struct filter *ipv6_vs_matcher; + + struct balancer_vs *vs; + uint32_t vs_count; + + struct balancer_session_timeouts session_timeouts; + + struct net4_addr source_v4; + struct net6_addr source_v6; + + /* + * RCU guard for the inner atomic changes on the packet handler. + * It includes changes on reals ring of virtual services. + */ + rcu_t rcu; + + /* ---Controlplane data --- */ + /* No padding needed, as rcu has 64 bytes alignment */ + + struct net4_addr *decap_v4; + uint32_t decap_v4_count; + + struct net6_addr *decap_v6; + uint32_t decap_v6_count; + + uint32_t wlc_power; + uint32_t wlc_max_weight; + uint32_t refresh_period_ms; + float session_table_max_load_factor; +}; + +void +balancer_handle_packets( + struct dp_worker *dp_worker, + struct module_ectx *module_ectx, + struct packet_front *packet_front +); diff --git a/modules/balancer/dataplane/decap.h b/modules/balancer/dataplane/decap.h deleted file mode 100644 index 36e92dce7..000000000 --- a/modules/balancer/dataplane/decap.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include "common/lpm.h" -#include "common/network.h" - -#include -#include -#include - -#include "flow/context.h" - -#include "flow/helpers.h" -#include "lib/dataplane/packet/decap.h" - -#include "handler/handler.h" -#include "rte_ether.h" - -//////////////////////////////////////////////////////////////////////////////// - -static inline int -decap_ipv4(struct packet *packet, struct packet_handler *handler) { - struct rte_ipv4_hdr *ipv4 = rte_pktmbuf_mtod_offset( - packet->mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - if (lpm_lookup( - &handler->decap_ipv4, - NET4_LEN, - (const uint8_t *)&ipv4->dst_addr - ) != LPM_VALUE_INVALID) { - return 1; - } else { - return 0; - } -} - -static inline int -decap_ipv6(struct packet *packet, struct packet_handler *handler) { - struct rte_ipv6_hdr *ipv6 = rte_pktmbuf_mtod_offset( - packet->mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - if (lpm_lookup(&handler->decap_ipv6, NET6_LEN, ipv6->dst_addr) != - LPM_VALUE_INVALID) { - return 1; - } else { - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// Try to decapsulate packet if its destination address is from the allowed -// list. If decap failed, just pass packet further. Returns -1 only if packet -// network proto is invalid. -static inline int -try_decap(struct packet_ctx *ctx) { - ctx->decap_flag = false; - - struct packet *packet = ctx->packet; - struct packet_handler *handler = ctx->handler; - - uint16_t network_protocol = packet->network_header.type; - - // check if decap is allowed. - // decap is allowed if destination address - // of the packet is in the decap list of the balancer. - int decap_is_allowed; - if (network_protocol == rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - decap_is_allowed = decap_ipv4(packet, handler); - } else if (network_protocol == rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6)) { - decap_is_allowed = decap_ipv6(packet, handler); - } else { - COMMON_STATS_INC(unexpected_network_proto, ctx); - return -1; - } - - // check if decap is allowed - if (decap_is_allowed) { - // if decap is allowed, make decap - // and check result - int decap_result = packet_decap(packet); - if (decap_result != 0) { - // decap failed, but it is ok - COMMON_STATS_INC(decap_failed, ctx); - } else { - // successfully made decap - COMMON_STATS_INC(decap_successful, ctx); - ctx->decap_flag = true; - } - } - - return 0; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/flow/common.h b/modules/balancer/dataplane/flow/common.h deleted file mode 100644 index 9fef1693b..000000000 --- a/modules/balancer/dataplane/flow/common.h +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include "../vs.h" -#include "context.h" -#include "lib/dataplane/module/module.h" -#include "lib/dataplane/module/packet_front.h" -#include "real.h" - -#include - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_set_vs(struct packet_ctx *ctx, struct vs *vs) { - ctx->vs.ptr = vs; - ctx->vs.stats = vs_counter(vs, ctx->worker_idx, ctx->stats.storage); -} - -static inline void -packet_ctx_set_real(struct packet_ctx *ctx, struct real *real) { - ctx->real.ptr = real; - ctx->real.stats = - real_counter(real, ctx->worker_idx, ctx->stats.storage); -} - -static inline void -packet_ctx_unset_real(struct packet_ctx *ctx) { - memset(&ctx->real, 0, sizeof(ctx->real)); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_send_packet(struct packet_ctx *ctx) { - ctx->processed = true; - packet_front_output(ctx->packet_front, ctx->packet); -} - -static inline void -packet_ctx_drop_packet(struct packet_ctx *ctx) { - ctx->processed = true; - packet_front_drop(ctx->packet_front, ctx->packet); -} diff --git a/modules/balancer/dataplane/flow/context.h b/modules/balancer/dataplane/flow/context.h deleted file mode 100644 index 488bc84ff..000000000 --- a/modules/balancer/dataplane/flow/context.h +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once - -#include "handler/handler.h" - -#include "lib/controlplane/config/econtext.h" -#include "lib/counters/counters.h" -#include "lib/dataplane/config/zone.h" -#include "lib/dataplane/module/module.h" -#include "lib/dataplane/packet/packet.h" -#include "state/session.h" - -//////////////////////////////////////////////////////////////////////////////// - -// Context of the packet flow. -struct packet_ctx { - struct session_id session; - uint32_t session_timeout; - - uint8_t transport_proto; - uint16_t tcp_flags; - - // packet context belongs to - struct packet *packet; - - // packet front which is used to - // send or drop packets - struct packet_front *packet_front; - - // worker which process current packet - struct dp_worker *worker; - uint32_t worker_idx; - - // packet handler - struct packet_handler *handler; - - // state of the balancer - struct balancer_state *balancer_state; - - // current time in seconds - uint32_t now; - - // module counters - struct { - struct balancer_common_stats *common; - struct balancer_icmp_stats *icmp_v4; - struct balancer_icmp_stats *icmp_v6; - struct balancer_l4_stats *l4; - - // counters storage - struct counter_storage *storage; - } stats; - - // selected virtual service - struct { - struct vs_stats *stats; - struct vs *ptr; - } vs; - - // selected real - struct { - struct real_stats *stats; - struct real *ptr; - } real; - - // if packet was decapsulated - bool decap_flag; - bool processed; -}; \ No newline at end of file diff --git a/modules/balancer/dataplane/flow/helpers.h b/modules/balancer/dataplane/flow/helpers.h deleted file mode 100644 index e49e54610..000000000 --- a/modules/balancer/dataplane/flow/helpers.h +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include "api/stats.h" -#include "context.h" - -//////////////////////////////////////////////////////////////////////////////// -// Config stats -//////////////////////////////////////////////////////////////////////////////// - -static inline struct balancer_icmp_stats * -packet_ctx_icmp_v4_config_stats(struct packet_ctx *ctx) { - return ctx->stats.icmp_v4; -} - -static inline struct balancer_icmp_stats * -packet_ctx_icmp_v6_config_stats(struct packet_ctx *ctx) { - return ctx->stats.icmp_v6; -} - -static inline struct balancer_common_stats * -packet_ctx_common_config_stats(struct packet_ctx *ctx) { - return ctx->stats.common; -} - -static inline struct balancer_l4_stats * -packet_ctx_l4_config_stats(struct packet_ctx *ctx) { - return ctx->stats.l4; -} - -//////////////////////////////////////////////////////////////////////////////// -// Module macros -//////////////////////////////////////////////////////////////////////////////// - -#define L4_STATS_INC(name, ctx) \ - do { \ - packet_ctx_l4_config_stats(ctx)->name += 1; \ - } while (0) - -#define COMMON_STATS_INC(name, ctx) \ - do { \ - packet_ctx_common_config_stats(ctx)->name += 1; \ - } while (0) - -#define COMMON_STATS_ADD(name, ctx, value) \ - do { \ - packet_ctx_common_config_stats(ctx)->name += value; \ - } while (0) - -#define ICMP_V4_STATS_INC(name, ctx) \ - do { \ - packet_ctx_icmp_v4_config_stats(ctx)->name += 1; \ - } while (0) - -#define ICMP_V6_STATS_INC(name, ctx) \ - do { \ - packet_ctx_icmp_v6_config_stats(ctx)->name += 1; \ - } while (0) - -#define ICMP_STATS_INC(name, header_type, ctx) \ - do { \ - if ((header_type) == IPPROTO_ICMP) { \ - ICMP_V4_STATS_INC(name, ctx); \ - } else { \ - ICMP_V6_STATS_INC(name, ctx); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////// -// Vs Stats and Info -//////////////////////////////////////////////////////////////////////////////// - -static inline struct vs_stats * -packet_ctx_vs_stats(struct packet_ctx *ctx) { - return ctx->vs.stats; -} - -//////////////////////////////////////////////////////////////////////////////// -// Vs macros -//////////////////////////////////////////////////////////////////////////////// - -#define VS_STATS_INC(name, ctx) \ - do { \ - packet_ctx_vs_stats(ctx)->name += 1; \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////// -// Real Stats and Info -//////////////////////////////////////////////////////////////////////////////// - -static inline struct real_stats * -packet_ctx_real_stats(struct packet_ctx *ctx) { - return ctx->real.stats; -} - -//////////////////////////////////////////////////////////////////////////////// -// Real macros -//////////////////////////////////////////////////////////////////////////////// - -#define REAL_STATS_INC(name, ctx) \ - do { \ - packet_ctx_real_stats(ctx)->name += 1; \ - } while (0) diff --git a/modules/balancer/dataplane/flow/setup.h b/modules/balancer/dataplane/flow/setup.h deleted file mode 100644 index f03fb2d74..000000000 --- a/modules/balancer/dataplane/flow/setup.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include "context.h" -#include "handler.h" - -#include "handler/handler.h" - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_setup( - struct packet_ctx *ctx, - uint32_t now, - struct dp_worker *worker, - struct module_ectx *ectx, - struct packet_handler *handler, - struct packet_front *packet_front -) { - memset(ctx, 0, sizeof(struct packet_ctx)); - ctx->packet = NULL; - ctx->handler = handler; - ctx->now = now; - ctx->stats.storage = ADDR_OF(&ectx->counter_storage); - ctx->worker = worker; - ctx->worker_idx = worker->idx; - ctx->stats.common = common_handler_counter( - handler, worker->idx, ctx->stats.storage - ); - ctx->stats.icmp_v4 = icmp_v4_handler_counter( - handler, worker->idx, ctx->stats.storage - ); - ctx->stats.icmp_v6 = icmp_v4_handler_counter( - handler, worker->idx, ctx->stats.storage - ); - ctx->stats.l4 = - l4_handler_counter(handler, worker->idx, ctx->stats.storage); - ctx->packet_front = packet_front; - ctx->balancer_state = ADDR_OF(&handler->state); - ctx->decap_flag = false; - ctx->processed = false; -} - -static inline void -packet_ctx_set_packet(struct packet_ctx *ctx, struct packet *packet) { - ctx->packet = packet; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/flow/stats.h b/modules/balancer/dataplane/flow/stats.h deleted file mode 100644 index d63a43295..000000000 --- a/modules/balancer/dataplane/flow/stats.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include "api/stats.h" -#include "context.h" - -#include "rte_mbuf_core.h" - -//////////////////////////////////////////////////////////////////////////////// -// Common module stats -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -// FWD - -static inline void -packet_ctx_update_vs_stats_on_outgoing_packet(struct packet_ctx *ctx); - -static inline void -packet_ctx_update_real_stats_on_packet(struct packet_ctx *ctx); - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_update_common_stats_on_outgoing_packet(struct packet_ctx *ctx) { - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - - ctx->stats.common->outgoing_packets += 1; - ctx->stats.common->outgoing_bytes += pkt_len; - - if (ctx->vs.ptr != NULL) { - packet_ctx_update_vs_stats_on_outgoing_packet(ctx); - } - - if (ctx->real.ptr != NULL) { - packet_ctx_update_real_stats_on_packet(ctx); - } -} - -static inline void -packet_ctx_update_common_stats_on_incoming_packet(struct packet_ctx *ctx) { - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - - ctx->stats.common->incoming_packets += 1; - ctx->stats.common->incoming_bytes += pkt_len; -} - -//////////////////////////////////////////////////////////////////////////////// -// Virtual service -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_update_vs_stats_on_outgoing_packet(struct packet_ctx *ctx) { - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - - ctx->vs.stats->outgoing_packets += 1; - ctx->vs.stats->outgoing_bytes += pkt_len; -} - -static inline void -packet_ctx_update_vs_stats_on_incoming_packet(struct packet_ctx *ctx) { - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - - ctx->vs.stats->incoming_packets += 1; - ctx->vs.stats->incoming_bytes += pkt_len; -} - -//////////////////////////////////////////////////////////////////////////////// -// Real -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_ctx_update_real_stats_on_packet(struct packet_ctx *ctx) { - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - - ctx->real.stats->packets += 1; - ctx->real.stats->bytes += pkt_len; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/handler.h b/modules/balancer/dataplane/handler.h deleted file mode 100644 index 4a00136fe..000000000 --- a/modules/balancer/dataplane/handler.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -#include "handler/handler.h" - -static inline struct balancer_common_stats * -common_handler_counter( - struct packet_handler *handler, - size_t worker, - struct counter_storage *storage -) { - uint64_t *counter = - counter_get_address(handler->counter.common, worker, storage); - return (struct balancer_common_stats *)counter; -} - -static inline struct balancer_icmp_stats * -icmp_v4_handler_counter( - struct packet_handler *handler, - size_t worker, - struct counter_storage *storage -) { - uint64_t *counter = - counter_get_address(handler->counter.icmp_v4, worker, storage); - return (struct balancer_icmp_stats *)counter; -} - -static inline struct balancer_icmp_stats * -icmp_v6_handler_counter( - struct packet_handler *handler, - size_t worker, - struct counter_storage *storage -) { - uint64_t *counter = - counter_get_address(handler->counter.icmp_v6, worker, storage); - return (struct balancer_icmp_stats *)counter; -} - -static inline struct balancer_l4_stats * -l4_handler_counter( - struct packet_handler *handler, - size_t worker, - struct counter_storage *storage -) { - uint64_t *counter = - counter_get_address(handler->counter.l4, worker, storage); - return (struct balancer_l4_stats *)counter; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/echo/handle.h b/modules/balancer/dataplane/icmp/echo/handle.h deleted file mode 100644 index 12c402815..000000000 --- a/modules/balancer/dataplane/icmp/echo/handle.h +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include "common/network.h" -#include "flow/helpers.h" -#include "lib/dataplane/packet/packet.h" - -#include "../../checksum.h" - -#include "../../flow/common.h" -#include "../../flow/context.h" -#include "../../flow/stats.h" - -#include -#include -#include - -#include "lookup.h" -#include "rte_icmp.h" -#include "rte_ip.h" -#include "rte_mbuf_core.h" - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -setup_icmp_header_on_echo_request(struct rte_icmp_hdr *icmp, int type) { - icmp->icmp_type = type; - icmp->icmp_code = 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -send_packet(struct packet_ctx *ctx) { - // update counters - ICMP_V4_STATS_INC(echo_responses, ctx); - packet_ctx_update_common_stats_on_outgoing_packet(ctx); - - // send packet to the next module - packet_ctx_send_packet(ctx); -} - -static inline void -handle_icmp_echo_ipv4(struct packet_ctx *ctx) { - // update stats - ICMP_V4_STATS_INC(incoming_packets, ctx); - - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - // validate virtual service - if (!vs_v4_announced(ctx)) { - ICMP_V4_STATS_INC(unrecognized_vs, ctx); - packet_ctx_drop_packet(ctx); - return; - } - - // setup icmp header (type and code) - struct rte_icmp_hdr *icmp = rte_pktmbuf_mtod_offset( - mbuf, struct rte_icmp_hdr *, packet->transport_header.offset - ); - setup_icmp_header_on_echo_request(icmp, ICMP_ECHOREPLY); - - // get ip header - struct rte_ipv4_hdr *ip = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv4_hdr *, packet->network_header.offset - ); - - // swap src and dst address to echo reply - uint32_t tmp = ip->src_addr; - ip->src_addr = ip->dst_addr; - ip->dst_addr = tmp; - - // setup ttl, as it is reply - ip->time_to_live = 64; - - // recalculate ip check sum - ip->hdr_checksum = 0; - ip->hdr_checksum = rte_ipv4_cksum(ip); - - // recalculate icmp checksum - uint16_t icmp_checksum = ~icmp->icmp_cksum; - icmp_checksum = csum_minus(icmp_checksum, ICMP_ECHO); - icmp_checksum = csum_plus(icmp_checksum, ICMP_ECHOREPLY); - icmp->icmp_cksum = ~icmp_checksum; - - // update counters and pass packet - ctx->stats.icmp_v4->echo_responses += 1; - send_packet(ctx); // updates common counters under the hood. -} - -static inline void -handle_icmp_echo_ipv6(struct packet_ctx *ctx) { - // update stats - ICMP_V6_STATS_INC(incoming_packets, ctx); - - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - // validate virtual service - if (!vs_v6_announced(ctx)) { - ICMP_V6_STATS_INC(unrecognized_vs, ctx); - packet_ctx_drop_packet(ctx); - return; - } - - struct rte_icmp_hdr *icmp = rte_pktmbuf_mtod_offset( - mbuf, struct rte_icmp_hdr *, packet->transport_header.offset - ); - setup_icmp_header_on_echo_request(icmp, ICMP6_ECHO_REPLY); - - struct rte_ipv6_hdr *ip = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, packet->network_header.offset - ); - uint8_t tmp[NET6_LEN]; - memcpy(tmp, ip->src_addr, NET6_LEN); - memcpy(ip->src_addr, ip->dst_addr, NET6_LEN); - memcpy(ip->dst_addr, tmp, NET6_LEN); - - ip->hop_limits = 64; - - uint16_t checksum = ~icmp->icmp_cksum; - checksum = csum_minus(checksum, ICMP6_ECHO_REQUEST); - checksum = csum_plus(checksum, ICMP6_ECHO_REPLY); - icmp->icmp_cksum = ~checksum; - - // update counter and pass packet - ctx->stats.icmp_v6->echo_responses += 1; - send_packet(ctx); // updates common counters under the hood. -} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/error/broadcast.h b/modules/balancer/dataplane/icmp/error/broadcast.h deleted file mode 100644 index faf435c17..000000000 --- a/modules/balancer/dataplane/icmp/error/broadcast.h +++ /dev/null @@ -1,180 +0,0 @@ -#pragma once - -#include "flow/common.h" -#include "flow/context.h" -#include "flow/helpers.h" - -#include "common/memory_address.h" -#include "common/network.h" - -#include "lib/dataplane/module/module.h" -#include "lib/dataplane/packet/packet.h" -#include "lib/dataplane/worker/worker.h" - -#include "handler/vs.h" - -#include "tunnel.h" -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -static inline struct packet * -clone_packet(struct dp_worker *worker, struct packet *packet) { - return worker_clone_packet(worker, packet); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -send_cloned_packet(struct packet_ctx *ctx, struct packet *packet) { - // update common module counters - uint64_t pkt_len = ctx->packet->mbuf->pkt_len; - COMMON_STATS_ADD(outgoing_bytes, ctx, pkt_len); - COMMON_STATS_INC(outgoing_packets, ctx); - - // update icmp module counters - ICMP_STATS_INC(packet_clones_sent, packet->transport_header.type, ctx); - - // we send cloned packets to other balancer, - // so we dont update vs or real counters here. - - // send packet to the next module - packet_front_output(ctx->packet_front, packet); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -update_counters_on_packet_clone_failed(struct packet_ctx *ctx) { - struct packet *packet = ctx->packet; - uint16_t type = packet->transport_header.type; - ICMP_STATS_INC(packet_clone_failures, type, ctx); -} - -//////////////////////////////////////////////////////////////////////////////// - -// ICMP error message header structure -// For error messages, the format is: -// [type:1][code:1][checksum:2][unused:4][original packet...] We use the first 2 -// bytes of the unused field to store our broadcast marker -struct icmp_error_hdr { - uint8_t type; - uint8_t code; - rte_be16_t checksum; - rte_be16_t unused_marker; // We use this for ICMP_BROADCAST_IDENT - rte_be16_t unused_rest; -} __rte_packed; - -static inline struct icmp_error_hdr * -icmp_error_hdr(struct packet *packet) { - struct icmp_error_hdr *icmp = rte_pktmbuf_mtod_offset( - packet->mbuf, - struct icmp_error_hdr *, - packet->transport_header.offset - ); - return icmp; -} - -//////////////////////////////////////////////////////////////////////////////// - -#define ICMP_BROADCAST_IDENT 0xBDC - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -set_cloned_mark(struct packet *packet) { - icmp_error_hdr(packet)->unused_marker = - rte_cpu_to_be_16(ICMP_BROADCAST_IDENT); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -broadcast_icmp_packet(struct packet_ctx *ctx) { - // If packet was decapsulated (came from another balancer), drop it - // to prevent broadcast loops. - // - // Otherwise, if virtual service for the packet was found, - // we iterate over virtual service peers and broadcast packet - // to them. - - // Check if packet came from another balancer (was decapsulated) - if (ctx->decap_flag) { - // Packet was decapsulated, meaning it came from another - // balancer Don't broadcast to prevent loops - uint16_t header_type = ctx->packet->transport_header.type; - ICMP_STATS_INC(packet_clones_received, header_type, ctx); - packet_ctx_drop_packet(ctx); - return; - } - - struct vs *vs = ctx->vs.ptr; - assert(vs != NULL); - - // here virtual service can not be null - - // Update counters - - // Update virtual service counters - VS_STATS_INC(broadcasted_icmp_packets, ctx); - - // Update module counters - if (ctx->packet->transport_header.type == IPPROTO_ICMP) { - ICMP_V4_STATS_INC(broadcasted_packets, ctx); - } else if (ctx->packet->transport_header.type == IPPROTO_ICMPV6) { - ICMP_V6_STATS_INC(broadcasted_packets, ctx); - } else { - // impossible - assert(false); - } - - // Broadcast packet to v4 peers. - uint8_t *balancer_src_v4 = ctx->handler->source_ipv4.bytes; - struct net4_addr *peers_v4 = ADDR_OF(&vs->peers_v4); - for (size_t i = 0; i < vs->peers_v4_count; ++i) { - struct packet *clone = clone_packet(ctx->worker, ctx->packet); - if (clone == NULL) { - update_counters_on_packet_clone_failed(ctx); - continue; - } - - // set mark that the packet is cloned - set_cloned_mark(clone); - - // tunnel packet to peer - struct net4_addr *peer = &peers_v4[i]; - tunnel_v4(clone, balancer_src_v4, peer->bytes); - - // send packet - send_cloned_packet(ctx, clone); - } - - // Broadcast packet to v6 peers. - uint8_t *balancer_src_v6 = ctx->handler->source_ipv6.bytes; - struct net6_addr *peers_v6 = ADDR_OF(&vs->peers_v6); - for (size_t i = 0; i < vs->peers_v6_count; ++i) { - struct packet *clone = clone_packet(ctx->worker, ctx->packet); - if (clone == NULL) { - update_counters_on_packet_clone_failed(ctx); - continue; - } - - // set mark that the packet is cloned - set_cloned_mark(clone); - - // tunnel packet to peer - struct net6_addr *peer = &peers_v6[i]; - tunnel_v6(clone, balancer_src_v6, peer->bytes); - - // send packet - send_cloned_packet(ctx, clone); - } - - // Drop the initial packet - packet_ctx_drop_packet(ctx); -} - -//////////////////////////////////////////////////////////////////////////////// - -#undef ICMP_BROADCAST_IDENT \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/error/handle.h b/modules/balancer/dataplane/icmp/error/handle.h deleted file mode 100644 index abf6f93e0..000000000 --- a/modules/balancer/dataplane/icmp/error/handle.h +++ /dev/null @@ -1,92 +0,0 @@ -#pragma once - -#include "broadcast.h" -#include "validate.h" - -#include -#include -#include - -#include "flow/helpers.h" -#include "flow/stats.h" - -#include "../../tunnel.h" - -//////////////////////////////////////////////////////////////////////////////// - -void -handle_icmp_error_packet(struct packet_ctx *ctx) { - // If session with goal real is present on the balancer, - // forward packet to this real. - // - // Else, if packet was not decapsulated (didn't come from another - // balancer), clone it and broadcast to other balancers. - - // update stats - ICMP_STATS_INC( - incoming_packets, ctx->packet->transport_header.type, ctx - ); - - // First, validate and parse packet. - // On errors, update corresponding counters. - enum validate_packet_result validate_result = - validate_and_parse_packet(ctx); - - switch (validate_result) { - - // If packet is invalid, drop it. - case validate_packet_error: - // counters already updated - packet_ctx_drop_packet(ctx); - break; - - case validate_packet_vs_not_found: - // virtual service not found, - // so we can not broadcast packet nor - // forward it. - // counters already updated. - packet_ctx_drop_packet(ctx); - break; - - // If session with real not found on the balancer, - // try to broadcast packet to other balancers. - case validate_packet_session_not_found: - broadcast_icmp_packet(ctx); - break; - - // If session with real found on the balancer, - // tunnel packet to real. - case validate_packet_session_found: - // send packet to real - tunnel_packet( - ctx->vs.ptr, - ctx->real.ptr, - ctx->packet - ); // added tunneling for packet - - // send packet to the next module - packet_ctx_send_packet(ctx); - - // update stats - - // update module stats - - // update icmp stats - if (ctx->packet->transport_header.type == IPPROTO_ICMP) { - ICMP_V4_STATS_INC(forwarded_packets, ctx); - } else { - ICMP_V6_STATS_INC(forwarded_packets, ctx); - } - - // update common module stats - packet_ctx_update_common_stats_on_outgoing_packet(ctx); - - // update vs counter - VS_STATS_INC(error_icmp_packets, ctx); - - // update real counter - REAL_STATS_INC(error_icmp_packets, ctx); - - break; - } -} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/error/info.h b/modules/balancer/dataplane/icmp/error/info.h deleted file mode 100644 index 2191c88cb..000000000 --- a/modules/balancer/dataplane/icmp/error/info.h +++ /dev/null @@ -1,158 +0,0 @@ -#pragma once - -#include "dataplane/packet/packet.h" -#include "rte_byteorder.h" -#include "rte_ether.h" -#include "rte_ip.h" - -//////////////////////////////////////////////////////////////////////////////// - -struct icmp_packet_info { - // ICMP packet layout: - // [NETWORK | ICMP | /* inner */ NETWORK | /* inner */ TRANSPORT] - struct network_header network; - struct transport_header transport; -}; - -//////////////////////////////////////////////////////////////////////////////// - -#define PACKET_INFO_UNKNOWN ((uint16_t)-1) -#define PACKET_INFO_EXTENSIONS_MAX ((uint32_t)32) -#define PACKET_INFO_EXTENSION_SIZE_MAX ((uint32_t)16) - -//////////////////////////////////////////////////////////////////////////////// - -static inline int -fill_icmp_packet_info_ipv4( - struct rte_mbuf *mbuf, struct icmp_packet_info *info -) { - const struct rte_ipv4_hdr *ipv4_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv4_hdr *, info->network.offset - ); - - /// @todo: check version - - // check the entire ip packet encapsulated - // in the icmp packet. - if (rte_pktmbuf_pkt_len(mbuf) < - (uint32_t)info->network.offset + - rte_be_to_cpu_16(ipv4_hdr->total_length)) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - if ((ipv4_hdr->version_ihl & 0x0F) < 0x05) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } else { - - info->transport.type = ipv4_hdr->next_proto_id; - info->transport.offset = info->network.offset + - 4 * (ipv4_hdr->version_ihl & 0x0F); - } - - if (rte_be_to_cpu_16(ipv4_hdr->total_length) < - 4 * (ipv4_hdr->version_ihl & 0x0F)) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - return 0; -} - -static inline int -fill_icmp_packet_info_ipv6( - struct rte_mbuf *mbuf, struct icmp_packet_info *info -) { - const struct rte_ipv6_hdr *ipv6_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, info->network.offset - ); - - /// @todo: check version - - if (rte_pktmbuf_pkt_len(mbuf) < - (uint32_t)info->network.offset + sizeof(struct rte_ipv6_hdr) + - rte_be_to_cpu_16(ipv6_hdr->payload_len)) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - uint8_t transport_hdr_type = ipv6_hdr->proto; - uint16_t transport_hdr_offset = - info->network.offset + sizeof(struct rte_ipv6_hdr); - - unsigned int extension_i = 0; - for (extension_i = 0; extension_i < PACKET_INFO_EXTENSIONS_MAX + 1; - extension_i++) { - if (transport_hdr_type == IPPROTO_HOPOPTS || - transport_hdr_type == IPPROTO_ROUTING || - transport_hdr_type == IPPROTO_DSTOPTS) { - const struct yanet_ipv6_ext_2byte *extension = - rte_pktmbuf_mtod_offset( - mbuf, - struct yanet_ipv6_ext_2byte *, - transport_hdr_offset - ); - - if (extension->extension_length > - PACKET_INFO_EXTENSION_SIZE_MAX) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - transport_hdr_type = extension->next_header; - transport_hdr_offset += - 8 + extension->extension_length * 8; - } else if (transport_hdr_type == IPPROTO_FRAGMENT) { - const struct yanet_ipv6_ext_fragment *extension = - rte_pktmbuf_mtod_offset( - mbuf, - struct yanet_ipv6_ext_fragment *, - transport_hdr_offset - ); - - transport_hdr_type = extension->next_header; - transport_hdr_offset += 8; - - /** @todo: last extension? - info->transport.type = transport_headerType; - info->transport.offset = transport_headerOffset; - - break; - */ - } else { - info->transport.type = transport_hdr_type; - info->transport.offset = transport_hdr_offset; - - break; - } - } - if (extension_i == PACKET_INFO_EXTENSIONS_MAX + 1) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - if (rte_be_to_cpu_16(ipv6_hdr->payload_len) < - info->transport.offset - info->network.offset - - sizeof(struct rte_ipv6_hdr)) { - info->network.type = PACKET_INFO_UNKNOWN; - return -1; - } - - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline int -fill_icmp_packet_info(struct rte_mbuf *mbuf, struct icmp_packet_info *info) { - if (info->network.type == rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - return fill_icmp_packet_info_ipv4(mbuf, info); - } else if (info->network.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6)) { - return fill_icmp_packet_info_ipv6(mbuf, info); - } else { - // unknown - return -1; - } -} diff --git a/modules/balancer/dataplane/icmp/error/tunnel.h b/modules/balancer/dataplane/icmp/error/tunnel.h deleted file mode 100644 index 7d201f654..000000000 --- a/modules/balancer/dataplane/icmp/error/tunnel.h +++ /dev/null @@ -1,127 +0,0 @@ -#pragma once - -#include "rte_ether.h" -#include "rte_ip.h" - -#include "common/network.h" - -#include -#include - -#include "lib/dataplane/packet/packet.h" - -//////////////////////////////////////////////////////////////////////////////// - -// Tunnel packet from this balancer (src address) to another (dst address) - -static inline void -fix_ether_header(struct rte_mbuf *mbuf, uint16_t ether_type) { - struct rte_ether_hdr *ether_header = - rte_pktmbuf_mtod(mbuf, struct rte_ether_hdr *); - - // setup ether type - if (ether_header->ether_type == rte_cpu_to_be_16(RTE_ETHER_TYPE_VLAN)) { - struct rte_vlan_hdr *vlan_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_vlan_hdr *, - sizeof(struct rte_ether_hdr) - ); - vlan_header->eth_proto = rte_cpu_to_be_16(ether_type); - } else { - ether_header->ether_type = rte_cpu_to_be_16(ether_type); - } -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -tunnel_v4(struct packet *packet, uint8_t *src, uint8_t *dst) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - // insert ipv4 header - rte_pktmbuf_prepend(mbuf, sizeof(struct rte_ipv4_hdr)); - memmove(rte_pktmbuf_mtod(mbuf, char *), - rte_pktmbuf_mtod_offset( - mbuf, char *, sizeof(struct rte_ipv4_hdr) - ), - packet->network_header.offset); - - struct rte_ipv4_hdr *outer_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv4_hdr *, packet->network_header.offset - ); - - memcpy(&outer_ip_hdr->src_addr, src, NET4_LEN); - memcpy(&outer_ip_hdr->dst_addr, dst, NET4_LEN); - - outer_ip_hdr->version_ihl = 0x45; - outer_ip_hdr->type_of_service = 0x00; - outer_ip_hdr->packet_id = rte_cpu_to_be_16(0x01); - outer_ip_hdr->fragment_offset = 0; - outer_ip_hdr->time_to_live = 64; - - outer_ip_hdr->total_length = rte_cpu_to_be_16( - (uint16_t)(mbuf->pkt_len - packet->network_header.offset) - ); - - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - outer_ip_hdr->next_proto_id = IPPROTO_IPIP; - } else { - outer_ip_hdr->next_proto_id = IPPROTO_IPV6; - } - - outer_ip_hdr->hdr_checksum = 0; - outer_ip_hdr->hdr_checksum = rte_ipv4_cksum(outer_ip_hdr); ///< @todo - - // might need to change next protocol type in ethernet/vlan header in - // cloned packet - - fix_ether_header(mbuf, RTE_ETHER_TYPE_IPV4); - - // Update mbuf metadata for the new outer IP header - mbuf->l3_len = sizeof(struct rte_ipv4_hdr); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -tunnel_v6(struct packet *packet, uint8_t *src, uint8_t *dst) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - // insert ipv6 header - - rte_pktmbuf_prepend(mbuf, sizeof(struct rte_ipv6_hdr)); - memmove(rte_pktmbuf_mtod(mbuf, char *), - rte_pktmbuf_mtod_offset( - mbuf, char *, sizeof(struct rte_ipv6_hdr) - ), - packet->network_header.offset); - - struct rte_ipv6_hdr *outer_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, packet->network_header.offset - ); - - memcpy(&outer_ip_hdr->src_addr, src, NET6_LEN); - memcpy(&outer_ip_hdr->dst_addr, dst, NET6_LEN); - - // todo: randomize src address - - outer_ip_hdr->vtc_flow = rte_cpu_to_be_32((0x6 << 28)); - outer_ip_hdr->payload_len = - rte_cpu_to_be_16((uint16_t)(mbuf->pkt_len - - packet->network_header.offset - - sizeof(struct rte_ipv6_hdr))); - outer_ip_hdr->hop_limits = 64; - - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - outer_ip_hdr->proto = IPPROTO_IPIP; - } else { - outer_ip_hdr->proto = IPPROTO_IPV6; - } - - fix_ether_header(mbuf, RTE_ETHER_TYPE_IPV6); - - // Update mbuf metadata for the new outer IP header - mbuf->l3_len = sizeof(struct rte_ipv6_hdr); -} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/error/validate.h b/modules/balancer/dataplane/icmp/error/validate.h deleted file mode 100644 index 1cd4146ec..000000000 --- a/modules/balancer/dataplane/icmp/error/validate.h +++ /dev/null @@ -1,331 +0,0 @@ -#pragma once - -#include "common/network.h" -#include "flow/common.h" -#include "handler/real.h" -#include "icmp/error/info.h" -#include "lib/dataplane/packet/packet.h" - -#include "api/stats.h" -#include "lookup.h" -#include "meta.h" -#include "rte_byteorder.h" -#include "rte_icmp.h" -#include "session_table.h" -#include "state/state.h" - -#include - -//////////////////////////////////////////////////////////////////////////////// - -enum validate_packet_result { - // Packet is invalid - validate_packet_error = -1, - - // Not found session with the real on the current balancer - validate_packet_session_not_found = 0, - - // Virtual service not recognized - validate_packet_vs_not_found = 1, - - // Found session with real on the current balancer - validate_packet_session_found = 2 -}; - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -packet_swap_headers( - struct packet *packet, - struct network_header *network, - struct transport_header *transport -) { - // set network heder - { - struct network_header tmp = packet->network_header; - packet->network_header = *network; - *network = tmp; - } - - // set transport header - { - struct transport_header tmp = packet->transport_header; - packet->transport_header = *transport; - *transport = tmp; - } -} - -static inline void -packet_swap_src_dst(struct packet *packet) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - // Swap IP addresses - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - struct rte_ipv4_hdr *inner_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - uint32_t tmp = inner_ip_hdr->src_addr; - inner_ip_hdr->src_addr = inner_ip_hdr->dst_addr; - inner_ip_hdr->dst_addr = tmp; - } else { // ipv6 - struct rte_ipv6_hdr *inner_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - uint8_t tmp[16]; - memcpy(tmp, inner_ip_hdr->src_addr, NET6_LEN); - memcpy(inner_ip_hdr->src_addr, inner_ip_hdr->dst_addr, NET6_LEN - ); - memcpy(inner_ip_hdr->dst_addr, tmp, NET6_LEN); - } - - // Swap transport ports - if (packet->transport_header.type == IPPROTO_TCP) { - struct rte_tcp_hdr *tcp = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_tcp_hdr *, - packet->transport_header.offset - ); - uint16_t tmp_port = tcp->src_port; - tcp->src_port = tcp->dst_port; - tcp->dst_port = tmp_port; - } else if (packet->transport_header.type == IPPROTO_UDP) { - struct rte_udp_hdr *udp = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_udp_hdr *, - packet->transport_header.offset - ); - uint16_t tmp_port = udp->src_port; - udp->src_port = udp->dst_port; - udp->dst_port = tmp_port; - } -} - -static inline int -validate_packet_ipv4( - struct packet_ctx *ctx, - struct packet_metadata_copy *meta_copy, - struct vs **vs -) { - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - struct rte_ipv4_hdr *outer_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv4_hdr *, packet->network_header.offset - ); - - meta_copy->meta.network_proto = IPPROTO_IP; - struct icmp_packet_info info; - info.network.type = rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4); - - // ICMPv4 error messages use 8-byte header (type + code + checksum + - // 4-byte unused) This matches sizeof(struct rte_icmp_hdr) which is 8 - // bytes - info.network.offset = packet->transport_header.offset + 8; - - if (fill_icmp_packet_info_ipv4(mbuf, &info) != 0) { - ICMP_V4_STATS_INC(payload_too_short_ip, ctx); - return -1; - } - - struct rte_ipv4_hdr *inner_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->transport_header.offset + sizeof(struct rte_icmp_hdr) - ); - if (inner_ip_hdr->src_addr != outer_ip_hdr->dst_addr) { - ICMP_V4_STATS_INC(unmatching_src_from_original, ctx); - return -1; - } - - if (mbuf->pkt_len < info.transport.offset + 2 * sizeof(rte_be16_t)) { - ICMP_V4_STATS_INC(payload_too_short_port, ctx); - return -1; - } - - // swap source address and destination address - // on the inner packet. after that, destination address should be equal - // to the virtual service address. also, we need to swap transport - // proto source and destination. - packet_swap_headers(ctx->packet, &info.network, &info.transport); - packet_swap_src_dst(ctx->packet); - - // fill packet metadata - if (fill_packet_metadata_copy(packet, meta_copy)) { - ICMP_V4_STATS_INC(unexpected_transport, ctx); - packet_swap_src_dst(ctx->packet); - packet_swap_headers( - ctx->packet, &info.network, &info.transport - ); - return -1; - } - - // lookup virtual service - *vs = vs_v4_lookup(ctx); - if (*vs == NULL) { - ICMP_V4_STATS_INC(unrecognized_vs, ctx); - } - - // swap headers and src dst back - packet_swap_src_dst(ctx->packet); - packet_swap_headers(ctx->packet, &info.network, &info.transport); - - return 0; -} - -static inline int -validate_packet_ipv6( - struct packet_ctx *ctx, - struct packet_metadata_copy *meta_copy, - struct vs **vs -) { - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - struct rte_ipv6_hdr *outer_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, packet->network_header.offset - ); - - meta_copy->meta.network_proto = IPPROTO_IPV6; - struct icmp_packet_info info; - info.network.type = rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6); - - // ICMPv6 error messages use 8-byte header (type + code + checksum + - // 4-byte unused) This is different from rte_icmp_hdr which is for echo - // messages - info.network.offset = packet->transport_header.offset + 8; - - if (fill_icmp_packet_info_ipv6(mbuf, &info) != 0) { - ICMP_V6_STATS_INC(payload_too_short_ip, ctx); - return -1; - } - - struct rte_ipv6_hdr *inner_ip_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, info.network.offset - ); - - if (memcmp(inner_ip_hdr->src_addr, outer_ip_hdr->dst_addr, 16)) { - ICMP_V6_STATS_INC(unmatching_src_from_original, ctx); - return -1; - } - - if (mbuf->pkt_len < info.transport.offset + 2 * sizeof(rte_be16_t)) { - ICMP_V6_STATS_INC(payload_too_short_port, ctx); - return -1; - } - - // swap source address and destination address - // on the inner packet. after that, destination address should be equal - // to the virtual service address. also, we need to swap transport - // proto source and destination. - packet_swap_headers(ctx->packet, &info.network, &info.transport); - packet_swap_src_dst(ctx->packet); - - // fill packet metadata - if (fill_packet_metadata_copy(packet, meta_copy)) { - ICMP_V6_STATS_INC(unexpected_transport, ctx); - packet_swap_src_dst(ctx->packet); - packet_swap_headers( - ctx->packet, &info.network, &info.transport - ); - return -1; - } - - // lookup virtual service - *vs = vs_v6_lookup(ctx); - if (*vs == NULL) { - ICMP_V6_STATS_INC(unrecognized_vs, ctx); - } - - // swap headers and src dst back - packet_swap_src_dst(ctx->packet); - packet_swap_headers(ctx->packet, &info.network, &info.transport); - - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline int -validate_and_parse_packet(struct packet_ctx *ctx) { - // Fill packet metadata and find virtual service for which - // original packet is intended to. - // - // After that, try to find session with real - // in the current balancer state. - - struct packet_metadata_copy meta_copy; - struct vs *vs; - - // validate packet, set metadata and packet icmp info - // (in the packet context). - // if validation failed, update corresponding counters. - int validate_result; - switch (ctx->packet->transport_header.type) { - case IPPROTO_ICMP: { - validate_result = validate_packet_ipv4(ctx, &meta_copy, &vs); - break; - } - case IPPROTO_ICMPV6: { - validate_result = validate_packet_ipv6(ctx, &meta_copy, &vs); - break; - } - default: { - // impossible, because previously it was - // checked packet is icmp or icmpv6 - assert(false); - } - } - - // if failed to validate packet, return error. - if (validate_result) { - return validate_packet_error; - } - - // if virtual service not found, - // there can not be session with real on the current balancer. - // so, we return corresponding status. - if (vs == NULL) { - return validate_packet_vs_not_found; - } else { - packet_ctx_set_vs(ctx, vs); - } - - // try to find session by id - - // fill session id - struct session_id session_id; - fill_session_id(&session_id, &meta_copy.meta, vs); - - // begin critical section - uint64_t current_gen = session_table_begin_cs( - &ctx->balancer_state->session_table, ctx->worker->idx - ); - - // get real for the session - uint32_t real_id = get_session_real( - &ctx->balancer_state->session_table, - current_gen, - &session_id, - ctx->now - ); - - // end critical section - session_table_end_cs( - &ctx->balancer_state->session_table, ctx->worker->idx - ); - - if (real_id == (uint32_t)-1) { // real not found - // end critical section - return validate_packet_session_not_found; - } else { // real found - struct real *reals = ADDR_OF(&ctx->handler->reals); - struct real *real = &reals[real_id]; - packet_ctx_set_vs(ctx, vs); - packet_ctx_set_real(ctx, real); - return validate_packet_session_found; - } -} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/handle.c b/modules/balancer/dataplane/icmp/handle.c new file mode 100644 index 000000000..aaf57f21d --- /dev/null +++ b/modules/balancer/dataplane/icmp/handle.c @@ -0,0 +1,27 @@ +#include "handle.h" + +#include "lib/dataplane/module/packet_front.h" + +#include "context.h" + +void +balancer_handle_icmp_ipv4( + struct worker_context *context, + struct packet **packets, + size_t packets_count +) { + for (size_t i = 0; i < packets_count; i++) { + packet_front_drop(context->packet_front, packets[i]); + } +} + +void +balancer_handle_icmp_ipv6( + struct worker_context *context, + struct packet **packets, + size_t packets_count +) { + for (size_t i = 0; i < packets_count; i++) { + packet_front_drop(context->packet_front, packets[i]); + } +} \ No newline at end of file diff --git a/modules/balancer/dataplane/icmp/handle.h b/modules/balancer/dataplane/icmp/handle.h index 5777a89b0..833c6e504 100644 --- a/modules/balancer/dataplane/icmp/handle.h +++ b/modules/balancer/dataplane/icmp/handle.h @@ -1,38 +1,20 @@ #pragma once -#include "echo/handle.h" -#include "error/handle.h" +#include -#include "lib/dataplane/packet/packet.h" +struct packet; +struct worker_context; -//////////////////////////////////////////////////////////////////////////////// +void +balancer_handle_icmp_ipv4( + struct worker_context *context, + struct packet **packets, + size_t packets_count +); -static inline void -handle_icmp_packet(struct packet_ctx *ctx) { - // Separately handle echo request and error packets. - // On echo, just answer from the balancer and dont forward packet to - // real. On error, we need to determine of packet is for the real, which - // serves by balancer. If so, forward packet to the real. Else, - // broadcast packet to other balancers, which serves this virtual - // services. - - struct rte_mbuf *mbuf = packet_to_mbuf(ctx->packet); - struct rte_icmp_hdr *icmp = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_icmp_hdr *, - ctx->packet->transport_header.offset - ); - - // this functions send or drop packets under the hood - // (and update corresponding counters) - switch (icmp->icmp_type) { - case ICMP_ECHO: - handle_icmp_echo_ipv4(ctx); - break; - case ICMP6_ECHO_REQUEST: - handle_icmp_echo_ipv6(ctx); - break; - default: - handle_icmp_error_packet(ctx); - } -} \ No newline at end of file +void +balancer_handle_icmp_ipv6( + struct worker_context *context, + struct packet **packets, + size_t packets_count +); \ No newline at end of file diff --git a/modules/balancer/dataplane/interval_counter.h b/modules/balancer/dataplane/interval_counter.h deleted file mode 100644 index 79fb093b3..000000000 --- a/modules/balancer/dataplane/interval_counter.h +++ /dev/null @@ -1,107 +0,0 @@ -#pragma once - -#include "common/likely.h" -#include -#include -#include - -#define RT_INTERVAL_COUNTER_RING_SIZE_EXP 3u -#define RT_INTERVAL_COUNTER_RING_SIZE (1u << RT_INTERVAL_COUNTER_RING_SIZE_EXP) -#define RT_INTERVAL_COUNTER_RING_MASK (RT_INTERVAL_COUNTER_RING_SIZE - 1u) - -/* - * Ring-based interval counter that stores per-timestamp deltas. - * - * The caller keeps a running total and applies the returned change. - * `make` starts an interval at `now` and schedules its end at `until`. - * `prolong` moves a previously scheduled end further in time. - */ -struct rt_interval_counter { - int32_t diff[RT_INTERVAL_COUNTER_RING_SIZE]; - uint32_t last_timestamp; -}; - -/* Reset the whole ring when all slots are older than the current time. */ -static inline int64_t -rt_interval_counter_try_reset( - struct rt_interval_counter *counter, uint32_t now -) { - int64_t sum = 0; - if (unlikely( - now - counter->last_timestamp >= - RT_INTERVAL_COUNTER_RING_SIZE - )) { - /* - * The entire ring is stale. Sum all remaining deltas so - * the caller's running count stays consistent, then clear. - */ - for (size_t i = 0; i < RT_INTERVAL_COUNTER_RING_SIZE; ++i) { - sum += counter->diff[i]; - } - memset(counter->diff, - 0, - RT_INTERVAL_COUNTER_RING_SIZE * sizeof(int32_t)); - counter->last_timestamp = now; - } - return sum; -} - -/* Expire slots up to `now` and return the net change for the running total. */ -static inline int64_t -rt_interval_counter_advance(struct rt_interval_counter *counter, uint32_t now) { - int64_t change = 0; - - /* Sweep past slots: [last_timestamp, now) */ - while (unlikely(counter->last_timestamp < now)) { - uint32_t idx = - counter->last_timestamp & RT_INTERVAL_COUNTER_RING_MASK; - counter->last_timestamp++; - change += counter->diff[idx]; - counter->diff[idx] = 0; - } - - /* - * Consume the current slot (now). Any +1/-1 written by the - * caller for this timestamp is picked up here and the slot is - * cleared so subsequent calls at the same `now` start fresh. - */ - uint32_t idx = counter->last_timestamp & RT_INTERVAL_COUNTER_RING_MASK; - change += counter->diff[idx]; - counter->diff[idx] = 0; - return change; -} - -/* Start a new interval `[now, until)` and return the change visible at `now`. - */ -static inline int64_t -rt_interval_counter_make( - struct rt_interval_counter *counter, uint32_t now, uint32_t until -) { - assert(until - now < RT_INTERVAL_COUNTER_RING_SIZE); - - int64_t change = rt_interval_counter_try_reset(counter, now); - - counter->diff[now & RT_INTERVAL_COUNTER_RING_MASK] += 1; - counter->diff[until & RT_INTERVAL_COUNTER_RING_MASK] -= 1; - - return change + rt_interval_counter_advance(counter, now); -} - -/* Move an existing interval end from `prev_until` to `new_until`. */ -static inline int64_t -rt_interval_counter_prolong( - struct rt_interval_counter *counter, - uint32_t now, - uint32_t prev_until, - uint32_t new_until -) { - assert(prev_until >= now); - assert(new_until - now < RT_INTERVAL_COUNTER_RING_SIZE); - - int64_t change = rt_interval_counter_try_reset(counter, now); - - counter->diff[prev_until & RT_INTERVAL_COUNTER_RING_MASK] += 1; - counter->diff[new_until & RT_INTERVAL_COUNTER_RING_MASK] -= 1; - - return change + rt_interval_counter_advance(counter, now); -} \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/gre.c b/modules/balancer/dataplane/l4/gre.c new file mode 100644 index 000000000..f0bd8824f --- /dev/null +++ b/modules/balancer/dataplane/l4/gre.c @@ -0,0 +1,88 @@ +#include "gre.h" + +#include + +#include +#include +#include +#include + +#include "common/checksum.h" + +#include "lib/dataplane/packet/data.h" +#include "lib/dataplane/packet/packet.h" +#include "rte_branch_prediction.h" + +static void +adjust_outer_ipv6_for_gre(struct rte_mbuf *mbuf, uint16_t network_offset) { + struct rte_ipv6_hdr *hdr = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv6_hdr *, network_offset + ); + hdr->proto = IPPROTO_GRE; + hdr->payload_len = rte_cpu_to_be_16( + rte_be_to_cpu_16(hdr->payload_len) + sizeof(struct rte_gre_hdr) + ); +} + +static void +adjust_outer_ipv4_for_gre(struct rte_mbuf *mbuf, uint16_t network_offset) { + struct rte_ipv4_hdr *hdr = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv4_hdr *, network_offset + ); + + /* Save the 16-bit words that will change. */ + uint16_t old_total_length = hdr->total_length; + uint16_t old_ttl_proto; + memcpy(&old_ttl_proto, &hdr->time_to_live, sizeof(uint16_t)); + + hdr->next_proto_id = IPPROTO_GRE; + hdr->total_length = rte_cpu_to_be_16( + rte_be_to_cpu_16(hdr->total_length) + sizeof(struct rte_gre_hdr) + ); + + /* Incremental checksum: only total_length and ttl|proto changed. */ + uint16_t new_ttl_proto; + memcpy(&new_ttl_proto, &hdr->time_to_live, sizeof(uint16_t)); + + uint16_t cksum = ~hdr->hdr_checksum; + cksum = csum_minus(cksum, old_total_length); + cksum = csum_minus(cksum, old_ttl_proto); + cksum = csum_plus(cksum, hdr->total_length); + cksum = csum_plus(cksum, new_ttl_proto); + hdr->hdr_checksum = (cksum == 0xffff) ? cksum : ~cksum; +} + +void +insert_gre_header( + struct packet *packet, bool is_outer_ipv6, bool is_inner_ipv4 +) { + struct rte_mbuf *mbuf = packet_to_mbuf(packet); + const uint16_t gre_size = sizeof(struct rte_gre_hdr); + + if (unlikely(rte_pktmbuf_prepend(mbuf, gre_size) == NULL)) { + return; + } + + uint16_t outer_ip_size = is_outer_ipv6 ? sizeof(struct rte_ipv6_hdr) + : sizeof(struct rte_ipv4_hdr); + uint16_t prefix_len = packet->network_header.offset + outer_ip_size; + + memmove(rte_pktmbuf_mtod(mbuf, char *), + rte_pktmbuf_mtod_offset(mbuf, char *, gre_size), + prefix_len); + + if (is_outer_ipv6) { + adjust_outer_ipv6_for_gre(mbuf, packet->network_header.offset); + } else { + adjust_outer_ipv4_for_gre(mbuf, packet->network_header.offset); + } + + struct rte_gre_hdr *gre = + rte_pktmbuf_mtod_offset(mbuf, struct rte_gre_hdr *, prefix_len); + memset(gre, 0, sizeof(struct rte_gre_hdr)); + gre->proto = rte_cpu_to_be_16( + is_inner_ipv4 ? RTE_ETHER_TYPE_IPV4 : RTE_ETHER_TYPE_IPV6 + ); + + packet->transport_header.offset += gre_size; +} diff --git a/modules/balancer/dataplane/l4/gre.h b/modules/balancer/dataplane/l4/gre.h new file mode 100644 index 000000000..45863e695 --- /dev/null +++ b/modules/balancer/dataplane/l4/gre.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +struct packet; + +/* + * Insert a GRE header between the outer IP header and the inner payload. + * + * Shifts L2+outer-L3 forward by sizeof(rte_gre_hdr) bytes, places a + * zeroed GRE header in the gap, and adjusts the outer IP header's + * protocol/length fields. IPv4 checksum is recomputed. + * + * is_outer_ipv6 — whether the outer tunnel header is IPv6. + * is_inner_ipv4 — whether the inner (encapsulated) packet is IPv4. + */ +void +insert_gre_header( + struct packet *packet, bool is_outer_ipv6, bool is_inner_ipv4 +); diff --git a/modules/balancer/dataplane/l4/group.c b/modules/balancer/dataplane/l4/group.c new file mode 100644 index 000000000..de436f11f --- /dev/null +++ b/modules/balancer/dataplane/l4/group.c @@ -0,0 +1,60 @@ +#include "common/likely.h" +#include + +#include "group.h" + +/* + * Maximum number of distinct VS IDs we track per batch. + * Practically 1-5 in normal traffic; 16 covers even degenerate cases. + * If exceeded, grouping is skipped (packets still processed correctly, + * just not batched per-VS for ACL queries). + */ +#define MAX_GROUPS 16 + +void +group_by_id(uint32_t *vs_ids, uint8_t *order, size_t count) { + /* Pass 1: collect unique VS IDs and count per group. */ + uint32_t unique_vs[MAX_GROUPS]; + size_t counts[MAX_GROUPS]; + size_t ngroups = 0; + + for (size_t i = 0; i < count; ++i) { + for (size_t g = 0; g < ngroups; ++g) { + if (unique_vs[g] == vs_ids[i]) { + counts[g]++; + goto next; + } + } + if (unlikely(ngroups == MAX_GROUPS)) { + return; + } + unique_vs[ngroups] = vs_ids[i]; + counts[ngroups] = 1; + ngroups++; + next:; + } + + /* Compute scatter offsets from prefix sums. */ + size_t offsets[MAX_GROUPS]; + offsets[0] = 0; + for (size_t g = 1; g < ngroups; ++g) { + offsets[g] = offsets[g - 1] + counts[g - 1]; + } + + /* Pass 2: scatter into temporary arrays. */ + uint32_t tmp_ids[count]; + uint8_t tmp_order[count]; + for (size_t i = 0; i < count; ++i) { + for (size_t g = 0; g < ngroups; ++g) { + if (unique_vs[g] == vs_ids[i]) { + size_t pos = offsets[g]++; + tmp_ids[pos] = vs_ids[i]; + tmp_order[pos] = order[i]; + break; + } + } + } + + memcpy(vs_ids, tmp_ids, count * sizeof(uint32_t)); + memcpy(order, tmp_order, count * sizeof(uint8_t)); +} \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/group.h b/modules/balancer/dataplane/l4/group.h new file mode 100644 index 000000000..033ba0b38 --- /dev/null +++ b/modules/balancer/dataplane/l4/group.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +/* + * Rearrange vs_ids[] and order[] so that entries with the same vs_id + * are contiguous. Uses a two-pass counting sort with a small group + * table, giving O(n * k) time for k distinct VS IDs. + * + * If there is too many distinct VS, the grouping is skipped. + */ +void +group_by_id(uint32_t *ids, uint8_t *order, size_t count); \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/handle.c b/modules/balancer/dataplane/l4/handle.c new file mode 100644 index 000000000..2c2c30dcb --- /dev/null +++ b/modules/balancer/dataplane/l4/handle.c @@ -0,0 +1,490 @@ +#include + +#include "common/likely.h" +#include "common/memory_address.h" +#include "common/network.h" +#include "filter.h" +#include "lib/dataplane/module/packet_front.h" + +#include "filter/query.h" + +#include "context.h" +#include "dataplane.h" +#include "group.h" +#include "packet.h" +#include "real_helpers.h" +#include "resolve.h" +#include "session/table.h" +#include "tunnel.h" +#include "types/session.h" +#include "types/stats.h" +#include "vs_helpers.h" + +FILTER_QUERY_DECLARE( + ipv4_vs_matcher, net4_fast_dst, port_fast_dst, proto_range_fast +); +FILTER_QUERY_DECLARE(ipv4_vs_acl, net4_fast_src, port_fast_src); + +FILTER_QUERY_DECLARE( + ipv6_vs_matcher, net6_fast_dst, port_fast_dst, proto_range_fast +); +FILTER_QUERY_DECLARE(ipv6_vs_acl, net6_fast_src, port_fast_src); + +/* + * Batch ACL query for a group of packets sharing the same VS. + * Updates incoming and ACL stats, drops packets that fail. + * Appends passing packets to pkt_ctxs. + * Returns the number of packets that passed. + */ +static size_t +filter_vs_group( + struct worker_context *context, + struct balancer_vs *vs, + struct packet **group_pkts, + size_t group_size, + struct l4_packet_context *pkt_ctxs, + bool is_ipv6 +) { + struct balancer_vs_stats *vs_stats = + vs_get_stats(vs, context->worker_idx, context->counter_storage); + + uint32_t acl_results[group_size]; + if (is_ipv6) { + filter_query( + ADDR_OF(&vs->acl), + ipv6_vs_acl, + group_pkts, + acl_results, + group_size + ); + } else { + filter_query( + ADDR_OF(&vs->acl), + ipv4_vs_acl, + group_pkts, + acl_results, + group_size + ); + } + + size_t passed = 0; + for (size_t i = 0; i < group_size; ++i) { + uint32_t rule_idx = acl_results[i]; + if (unlikely(rule_idx == FILTER_RULE_INVALID)) { + vs_stats->packet_src_not_allowed += 1; + packet_front_drop(context->packet_front, group_pkts[i]); + continue; + } + + vs_stats->incoming_packets += 1; + vs_stats->incoming_bytes += group_pkts[i]->mbuf->pkt_len; + + uint64_t *rule_counter = vs_get_acl_stats( + vs, + context->worker_idx, + context->counter_storage, + rule_idx + ); + + /* Dont track stats for rules with empty tag. */ + if (rule_counter != NULL) { + *rule_counter += 1; + } + + pkt_ctxs[passed++] = (struct l4_packet_context){ + .packet = group_pkts[i], + .matched_vs = vs, + .matched_vs_stats = vs_stats, + }; + } + + return passed; +} + +/* + * Match packets to virtual services, then filter by ACL. + * + * 1. Batch VS lookup — classify all packets at once. + * 2. Group matched packets by VS. + * 3. Batch ACL per group — drop packets that fail. + * + * Returns the number of packets that passed both checks. + * is_ipv6 is a compile-time constant so the compiler + * eliminates the dead branches after inlining. + */ +static size_t +match_and_filter( + struct worker_context *context, + struct packet **packets, + size_t packets_count, + struct l4_packet_context *pkt_ctxs, + bool is_ipv6 +) { + /* Batch VS lookup. */ + uint32_t vs_results[packets_count]; + if (is_ipv6) { + filter_query( + ADDR_OF(&context->packet_handler->ipv6_vs_matcher), + ipv6_vs_matcher, + packets, + vs_results, + packets_count + ); + } else { + filter_query( + ADDR_OF(&context->packet_handler->ipv4_vs_matcher), + ipv4_vs_matcher, + packets, + vs_results, + packets_count + ); + } + + /* Drop unmatched, record indices and vs_ids of the rest. */ + uint8_t order[packets_count]; + uint32_t vs_ids[packets_count]; + size_t matched_count = 0; + + for (size_t i = 0; i < packets_count; ++i) { + uint32_t vs_id = vs_results[i]; + if (unlikely(vs_id == FILTER_RULE_INVALID)) { + context->l4_stats->select_vs_failed += 1; + packet_front_drop(context->packet_front, packets[i]); + continue; + } + + order[matched_count] = i; + vs_ids[matched_count] = vs_id; + matched_count++; + } + + if (unlikely(matched_count == 0)) { + return 0; + } + + /* + * Group by VS for batched ACL queries. If grouping fails + * (too many distinct VSes), the loop below still works + * correctly with ungrouped packets — just smaller batches. + */ + group_by_id(vs_ids, order, matched_count); + + struct balancer_vs *virtual_services = + ADDR_OF(&context->packet_handler->vs); + + /* Run ACL per VS group. */ + size_t total_passed = 0; + size_t pkt_idx = 0; + struct packet *group_pkts[matched_count]; + while (pkt_idx < matched_count) { + uint32_t cur_vs_id = vs_ids[pkt_idx]; + + /* Collect contiguous packets for this VS. */ + size_t group_size = 0; + while (pkt_idx < matched_count && vs_ids[pkt_idx] == cur_vs_id + ) { + group_pkts[group_size++] = packets[order[pkt_idx]]; + pkt_idx++; + } + + total_passed += filter_vs_group( + context, + virtual_services + cur_vs_id, + group_pkts, + group_size, + pkt_ctxs + total_passed, + is_ipv6 + ); + } + + return total_passed; +} + +static void +extract_ipv4_metadata(struct l4_packet_context *pkt_ctx) { + struct rte_ipv4_hdr *ipv4_hdr = rte_pktmbuf_mtod_offset( + pkt_ctx->packet->mbuf, + struct rte_ipv4_hdr *, + pkt_ctx->packet->network_header.offset + ); + __builtin_memcpy( + pkt_ctx->session_id.client_ip, + (uint8_t *)&ipv4_hdr->src_addr, + NET4_LEN + ); +} + +static void +extract_ipv6_metadata(struct l4_packet_context *pkt_ctx) { + struct rte_ipv6_hdr *ipv6_hdr = rte_pktmbuf_mtod_offset( + pkt_ctx->packet->mbuf, + struct rte_ipv6_hdr *, + pkt_ctx->packet->network_header.offset + ); + memcpy(pkt_ctx->session_id.client_ip, ipv6_hdr->src_addr, NET6_LEN); +} + +static void +extract_network_metadata(struct l4_packet_context *pkt_ctx, bool is_ipv6) { + if (is_ipv6) { + extract_ipv6_metadata(pkt_ctx); + } else { + extract_ipv4_metadata(pkt_ctx); + } +} + +static uint8_t +tcp_session_timeout( + struct balancer_session_timeouts *timeouts, uint16_t tcp_flags +) { + if ((tcp_flags & RTE_TCP_SYN_FLAG) == RTE_TCP_SYN_FLAG) { + if ((tcp_flags & RTE_TCP_ACK_FLAG) == RTE_TCP_ACK_FLAG) { + return timeouts->tcp_syn_ack; + } + return timeouts->tcp_syn; + } + if (tcp_flags & RTE_TCP_FIN_FLAG) { + return timeouts->tcp_fin; + } + return timeouts->tcp; +} + +static void +extract_tcp_metadata( + struct l4_packet_context *pkt_ctx, + struct balancer_session_timeouts *timeouts +) { + struct rte_tcp_hdr *hdr = rte_pktmbuf_mtod_offset( + pkt_ctx->packet->mbuf, + struct rte_tcp_hdr *, + pkt_ctx->packet->transport_header.offset + ); + + uint8_t flags = hdr->tcp_flags; + + pkt_ctx->session_id.client_port = hdr->src_port; + pkt_ctx->session_timeout = tcp_session_timeout(timeouts, flags); + pkt_ctx->can_reschedule = (flags & (RTE_TCP_SYN_FLAG | RTE_TCP_RST_FLAG) + ) == RTE_TCP_SYN_FLAG; +} + +static void +extract_udp_metadata( + struct l4_packet_context *pkt_ctx, + struct balancer_session_timeouts *timeouts +) { + struct rte_udp_hdr *hdr = rte_pktmbuf_mtod_offset( + pkt_ctx->packet->mbuf, + struct rte_udp_hdr *, + pkt_ctx->packet->transport_header.offset + ); + + pkt_ctx->session_id.client_port = hdr->src_port; + pkt_ctx->session_timeout = timeouts->udp; + pkt_ctx->can_reschedule = true; +} + +static void +extract_transport_metadata( + struct l4_packet_context *pkt_ctx, + struct balancer_session_timeouts *timeouts +) { + if (pkt_ctx->packet->transport_header.type == IPPROTO_TCP) { + extract_tcp_metadata(pkt_ctx, timeouts); + } else { + extract_udp_metadata(pkt_ctx, timeouts); + } +} + +static void +extract_session_metadata( + struct l4_packet_context *pkt_ctxs, + size_t pkt_ctx_count, + struct balancer_session_timeouts *timeouts, + bool is_ipv6 +) { + const size_t prefetch_distance = 2; + + for (size_t pkt_idx = 0; pkt_idx < pkt_ctx_count; ++pkt_idx) { + if (pkt_idx + prefetch_distance < pkt_ctx_count) { + rte_prefetch0(rte_pktmbuf_mtod(pkt_ctxs[pkt_idx + prefetch_distance].packet->mbuf, void *)); + } + + struct l4_packet_context *pkt_ctx = &pkt_ctxs[pkt_idx]; + pkt_ctx->session_id.vs_stable_idx = + pkt_ctx->matched_vs->stable_idx; + + /* + * Zero client_ip + padding so the session ID hashes + * deterministically. extract_network_metadata overwrites + * the relevant prefix (4 bytes for IPv4, 16 for IPv6). + */ + __builtin_memset( + pkt_ctx->session_id.client_ip + NET4_LEN, + 0, + NET6_LEN - NET4_LEN + BALANCER_SESSION_ID_PADDING + ); + + extract_network_metadata(pkt_ctx, is_ipv6); + extract_transport_metadata(pkt_ctx, timeouts); + } +} + +static void +select_reals( + struct worker_context *context, + struct l4_packet_context *pkt_ctxs, + size_t pkt_ctx_count, + struct balancer_session_table *session_table, + uint64_t current_table_gen +) { + const size_t prefetch_distance = 4; + + for (size_t pkt_idx = 0; pkt_idx < pkt_ctx_count; ++pkt_idx) { + if (pkt_idx + prefetch_distance < pkt_ctx_count) { + st_prefetch_session( + session_table, + current_table_gen, + &pkt_ctxs[pkt_idx + prefetch_distance] + .session_id + ); + } + + struct l4_packet_context *pkt_ctx = &pkt_ctxs[pkt_idx]; + + struct balancer_real *real = resolve_real( + context, pkt_ctx, session_table, current_table_gen + ); + + if (unlikely(real == NULL)) { + pkt_ctx->is_dropped = true; + packet_front_drop( + context->packet_front, pkt_ctx->packet + ); + continue; + } + + pkt_ctx->resolved_real = real; + pkt_ctx->resolved_real_stats = real_get_stats( + real, context->worker_idx, context->counter_storage + ); + } +} + +static void +tunnel_packets( + struct worker_context *context, + struct l4_packet_context *pkt_ctxs, + size_t pkt_ctx_count, + bool is_ipv6 +) { + struct packet_front *packet_front = context->packet_front; + for (size_t pkt_idx = 0; pkt_idx < pkt_ctx_count; ++pkt_idx) { + struct l4_packet_context *pkt_ctx = &pkt_ctxs[pkt_idx]; + if (unlikely(pkt_ctx->is_dropped)) { + continue; + } + + if (is_ipv6) { + tunnel_ipv6_packet(pkt_ctx); + } else { + tunnel_ipv4_packet(pkt_ctx); + } + + packet_front_output(packet_front, pkt_ctx->packet); + + context->l4_stats->outgoing_packets += 1; + context->common_stats->outgoing_packets += 1; + context->common_stats->outgoing_bytes += + pkt_ctx->packet->mbuf->pkt_len; + } +} + +/* + * L4 packet processing pipeline. + * + * 1. match_and_filter: VS lookup + ACL check (batched). + * Drops packets with no VS or failing ACL. + * + * 2. extract_session_metadata: parse IP/transport headers into + * session_id, timeout, and reschedule flag. + * + * 3. select_reals: look up or create a session, pick a + * real server. Runs inside two critical + * sections: + * - reals_selector_guard: RCU guard for real selector rings, + * held across both select and tunnel + * so the ring isn't freed mid-use. + * - session table CS: pins the current session map generation, + * released after all lookups are done. + * + * 4. tunnel_packets: encapsulate and forward to the chosen real. + */ +static void +balancer_handle_l4_packets( + struct worker_context *context, + struct packet **packets, + size_t packets_count, + bool is_ipv6 +) { + context->l4_stats->incoming_packets += packets_count; + + /* + * Not zero-initialized: filter_vs_group populates each used + * entry via compound literal, which zeros all unset fields + * (including is_dropped = false). + */ + struct l4_packet_context pkt_ctxs[packets_count]; + + size_t pkt_ctx_count = match_and_filter( + context, packets, packets_count, pkt_ctxs, is_ipv6 + ); + + extract_session_metadata( + pkt_ctxs, + pkt_ctx_count, + &context->packet_handler->session_timeouts, + is_ipv6 + ); + + struct balancer_session_table *session_table = + ADDR_OF(&context->packet_handler->session_table); + + rcu_t *reals_selector_guard = &context->packet_handler->rcu; + rcu_read_begin(reals_selector_guard, context->worker_idx); + uint64_t current_table_gen = + st_begin_cs(session_table, context->worker_idx); + + select_reals( + context, + pkt_ctxs, + pkt_ctx_count, + session_table, + current_table_gen + ); + + st_end_cs(session_table, context->worker_idx); + rcu_read_end(reals_selector_guard, context->worker_idx); + + tunnel_packets(context, pkt_ctxs, pkt_ctx_count, is_ipv6); +} + +void +balancer_handle_l4_ipv4( + struct worker_context *context, + struct packet **packets, + size_t packets_count +) { + const bool is_ipv6 = false; + balancer_handle_l4_packets(context, packets, packets_count, is_ipv6); +} + +void +balancer_handle_l4_ipv6( + struct worker_context *context, + struct packet **packets, + size_t packets_count +) { + const bool is_ipv6 = true; + balancer_handle_l4_packets(context, packets, packets_count, is_ipv6); +} diff --git a/modules/balancer/dataplane/l4/handle.h b/modules/balancer/dataplane/l4/handle.h index 9050b913e..713be6417 100644 --- a/modules/balancer/dataplane/l4/handle.h +++ b/modules/balancer/dataplane/l4/handle.h @@ -1,109 +1,21 @@ #pragma once -#include "api/session.h" -#include "common/likely.h" -#include "flow/helpers.h" -#include "flow/stats.h" -#include "meta.h" -#include "select.h" - -#include "../lookup.h" -#include "../tunnel.h" -#include "session_table.h" -#include "state/state.h" -#include - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -handle_l4_packets(struct packet_ctx *ctxs, size_t count) { - assert(count > 0); - - struct packet_handler *handler = ctxs[0].handler; - struct sessions_timeouts *timeouts = &handler->sessions_timeouts; - - // start critical section for the session table - struct session_table *table = &ctxs[0].balancer_state->session_table; - uint64_t current_table_gen = - session_table_begin_cs(table, ctxs[0].worker_idx); - - // select virtual service for each packet - for (size_t i = 0; i < count; ++i) { - struct packet_ctx *ctx = &ctxs[i]; - - // skip processed packets - if (unlikely(ctx->processed)) { - continue; - } - - // update stats - L4_STATS_INC(incoming_packets, ctx); - - // 1. Validate packet and set metadata - struct packet_metadata meta; - int res = fill_packet_metadata(ctx->packet, &meta); - if (unlikely(res != 0)) { // unexpected packet type - L4_STATS_INC(invalid_packets, ctx); - packet_ctx_drop_packet(ctx); - continue; - } - - // 2. Lookup virtual service for which packet is - // directed to - - struct vs *vs = vs_lookup_and_fw(ctx); - if (unlikely(vs == NULL)) { // not found virtual service - L4_STATS_INC(select_vs_failed, ctx); - packet_ctx_drop_packet(ctx); - continue; - } - - // update VS incoming stats - packet_ctx_update_vs_stats_on_incoming_packet(ctx); - - // fill session id and timeout - fill_session_id(&ctx->session, &meta, vs); - ctx->session_timeout = session_timeout(timeouts, &meta); - - ctx->transport_proto = meta.transport_proto; - ctx->tcp_flags = meta.tcp_flags; - - prefetch_session(table, current_table_gen, &ctx->session); - } - - for (size_t i = 0; i < count; ++i) { - struct packet_ctx *ctx = &ctxs[i]; - if (unlikely(ctx->processed)) { - continue; - } - struct real *selected_real = - select_real(ctx, ctx->vs.ptr, table, current_table_gen); - if (unlikely(selected_real == NULL)) { // failed to select real - // update stats - L4_STATS_INC(select_real_failed, ctx); - packet_ctx_drop_packet(ctx); - continue; - } - } - - session_table_end_cs(table, ctxs[0].worker_idx); - - for (size_t i = 0; i < count; ++i) { - struct packet_ctx *ctx = &ctxs[i]; - if (unlikely(ctx->processed)) { - continue; - } - - // 4. Add tunnel to the selected real for the packet - - tunnel_packet(ctx->vs.ptr, ctx->real.ptr, ctx->packet); - - // 5. Pass packet to the next module - - packet_ctx_send_packet(ctx); - - // update stats - L4_STATS_INC(outgoing_packets, ctx); - packet_ctx_update_common_stats_on_outgoing_packet(ctx); - } -} +#include +#include + +struct packet; +struct worker_context; + +void +balancer_handle_l4_ipv4( + struct worker_context *context, + struct packet **packets, + size_t packets_count +); + +void +balancer_handle_l4_ipv6( + struct worker_context *context, + struct packet **packets, + size_t packets_count +); \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/mss.c b/modules/balancer/dataplane/l4/mss.c new file mode 100644 index 000000000..7e25d8e8f --- /dev/null +++ b/modules/balancer/dataplane/l4/mss.c @@ -0,0 +1,217 @@ +#include +#include + +#include +#include +#include + +#include "common/checksum.h" + +#include "lib/dataplane/packet/packet.h" +#include "rte_branch_prediction.h" + +#include "mss.h" + +#define TCP_OPTION_KIND_EOL 0 +#define TCP_OPTION_KIND_NOP 1 +#define TCP_OPTION_KIND_MSS 2 +#define TCP_OPTION_MSS_LEN 4 + +/* MSS value to clamp down to (accounts for tunnel overhead on IPv6). */ +#define FIX_MSS_SIZE 1220 + +/* MSS value to insert when no MSS option is present. */ +#define DEFAULT_MSS_SIZE 536 + +/* Maximum TCP data offset in 32-bit words (4-bit field). */ +#define TCP_DATA_OFF_MAX 0x0F + +struct tcp_option { + uint8_t kind; + uint8_t len; + uint8_t data[]; +} __attribute__((__packed__)); + +/* + * Check whether the packet is a TCP SYN (without RST) that may + * carry MSS options. Returns the TCP header pointer, or NULL + * if the packet should be skipped. + */ +static struct rte_tcp_hdr * +get_syn_tcp_header(struct packet *packet) { + if (unlikely(packet->transport_header.type != IPPROTO_TCP)) { + return NULL; + } + + struct rte_tcp_hdr *tcp = rte_pktmbuf_mtod_offset( + packet->mbuf, + struct rte_tcp_hdr *, + packet->transport_header.offset + ); + + uint8_t flags = tcp->tcp_flags; + if ((flags & (RTE_TCP_SYN_FLAG | RTE_TCP_RST_FLAG)) != + RTE_TCP_SYN_FLAG) { + return NULL; + } + + return tcp; +} + +/* + * Validate the TCP data offset and return the options length in bytes. + * Returns 0 if the data offset is invalid or out of packet bounds. + */ +static uint16_t +tcp_data_offset(struct packet *packet, struct rte_tcp_hdr *tcp) { + uint16_t data_offset = (tcp->data_off >> 4) * 4; + + if (unlikely(data_offset < sizeof(struct rte_tcp_hdr))) { + return 0; + } + + uint16_t pkt_len = rte_pktmbuf_pkt_len(packet->mbuf); + if (unlikely(packet->transport_header.offset + data_offset > pkt_len)) { + return 0; + } + + return data_offset; +} + +/* + * Scan TCP options for an existing MSS option. + * If found and its value exceeds FIX_MSS_SIZE, clamp it and + * update the TCP checksum incrementally. + * + * Returns true if an MSS option was found (regardless of whether + * it was modified), false if no MSS option exists. + */ +static bool +try_clamp_existing_mss( + struct packet *packet, struct rte_tcp_hdr *tcp, uint16_t data_offset +) { + uint16_t offset = sizeof(struct rte_tcp_hdr); + + while (offset + TCP_OPTION_MSS_LEN <= data_offset) { + struct tcp_option *opt = rte_pktmbuf_mtod_offset( + packet->mbuf, + struct tcp_option *, + packet->transport_header.offset + offset + ); + + if (opt->kind == TCP_OPTION_KIND_MSS) { + uint16_t old_mss = + rte_be_to_cpu_16(*(uint16_t *)opt->data); + if (old_mss <= FIX_MSS_SIZE) { + return true; + } + + /* Clamp MSS and update checksum incrementally. */ + uint16_t cksum = ~tcp->cksum; + cksum = csum_minus(cksum, *(uint16_t *)opt->data); + *(uint16_t *)opt->data = rte_cpu_to_be_16(FIX_MSS_SIZE); + cksum = csum_plus(cksum, *(uint16_t *)opt->data); + tcp->cksum = (cksum == 0xffff) ? cksum : ~cksum; + return true; + } + + if (opt->kind == TCP_OPTION_KIND_EOL || + opt->kind == TCP_OPTION_KIND_NOP) { + offset++; + } else { + if (unlikely(opt->len == 0)) { + return false; /* malformed header */ + } + offset += opt->len; + } + } + + return false; +} + +/* + * Insert a new MSS option (DEFAULT_MSS_SIZE) right after the fixed + * TCP header. Shifts L2 + L3 + fixed TCP header backward by 4 bytes, + * then writes the option into the gap. + * + * Updates TCP data offset, TCP checksum, and IPv6 payload length. + */ +static void +insert_mss_option(struct packet *packet, struct rte_tcp_hdr *tcp) { + uint16_t data_offset = (tcp->data_off >> 4) * 4; + + /* Check if there is room for one more 32-bit option word. */ + if (unlikely( + data_offset > (TCP_DATA_OFF_MAX << 2) - TCP_OPTION_MSS_LEN + )) { + return; + } + + struct rte_mbuf *mbuf = packet->mbuf; + + /* Extend the packet at the front by TCP_OPTION_MSS_LEN bytes. */ + if (unlikely(rte_pktmbuf_prepend(mbuf, TCP_OPTION_MSS_LEN) == NULL)) { + return; + } + + /* Move everything before the TCP options backward. */ + uint16_t prefix_len = + packet->transport_header.offset + sizeof(struct rte_tcp_hdr); + memmove(rte_pktmbuf_mtod(mbuf, char *), + rte_pktmbuf_mtod_offset(mbuf, char *, TCP_OPTION_MSS_LEN), + prefix_len); + + /* Write the MSS option into the gap. */ + struct tcp_option *opt = + rte_pktmbuf_mtod_offset(mbuf, struct tcp_option *, prefix_len); + opt->kind = TCP_OPTION_KIND_MSS; + opt->len = TCP_OPTION_MSS_LEN; + *(uint16_t *)opt->data = rte_cpu_to_be_16(DEFAULT_MSS_SIZE); + + /* Re-fetch TCP header (it moved due to prepend). */ + tcp = rte_pktmbuf_mtod_offset( + mbuf, struct rte_tcp_hdr *, packet->transport_header.offset + ); + + /* Increment data offset by one 32-bit word. */ + tcp->data_off += 0x1 << 4; + + /* + * Update TCP checksum incrementally. + * + * data_off is the leading byte of its 16-bit aligned pair + * inside the TCP header, so no byte-swap is needed for the + * data_off delta. + */ + uint16_t cksum = ~tcp->cksum; + cksum = csum_plus(cksum, 0x1 << 4); + cksum = csum_plus(cksum, *(uint16_t *)opt); + cksum = csum_plus(cksum, *(uint16_t *)opt->data); + cksum = csum_plus(cksum, rte_cpu_to_be_16(TCP_OPTION_MSS_LEN)); + tcp->cksum = (cksum == 0xffff) ? cksum : ~cksum; + + /* Update IPv6 payload length. */ + struct rte_ipv6_hdr *ip6 = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv6_hdr *, packet->network_header.offset + ); + ip6->payload_len = rte_cpu_to_be_16( + rte_be_to_cpu_16(ip6->payload_len) + TCP_OPTION_MSS_LEN + ); +} + +void +fix_mss_ipv6(struct packet *packet) { + struct rte_tcp_hdr *tcp = get_syn_tcp_header(packet); + if (unlikely(tcp == NULL)) { + return; + } + + uint16_t data_offset = tcp_data_offset(packet, tcp); + if (unlikely(data_offset == 0)) { + return; + } + + if (!try_clamp_existing_mss(packet, tcp, data_offset)) { + insert_mss_option(packet, tcp); + } +} diff --git a/modules/balancer/dataplane/l4/mss.h b/modules/balancer/dataplane/l4/mss.h new file mode 100644 index 000000000..e2c78cc46 --- /dev/null +++ b/modules/balancer/dataplane/l4/mss.h @@ -0,0 +1,18 @@ +#pragma once + +struct packet; + +/* + * Clamp TCP MSS option for IPv6 packets to avoid fragmentation + * after tunnel encapsulation. + * + * Only processes TCP SYN packets (without RST). For such packets: + * - If an MSS option exists and exceeds FIX_MSS_SIZE (1220), + * it is clamped down and the TCP checksum is updated incrementally. + * - If no MSS option exists and there is room in the TCP header, + * a new MSS option with DEFAULT_MSS_SIZE (536) is inserted. + * + * Does nothing for non-TCP or non-SYN packets. + */ +void +fix_mss_ipv6(struct packet *packet); diff --git a/modules/balancer/dataplane/l4/resolve.c b/modules/balancer/dataplane/l4/resolve.c new file mode 100644 index 000000000..df83aefb6 --- /dev/null +++ b/modules/balancer/dataplane/l4/resolve.c @@ -0,0 +1,297 @@ +#include +#include + +#include +#include + +#include "common/big_array.h" +#include "common/memory_address.h" + +#include "common/ttlmap/detail/lock.h" +#include "lib/dataplane/packet/packet.h" + +#include "context.h" +#include "packet.h" +#include "real_helpers.h" +#include "rte_branch_prediction.h" +#include "session/table.h" +#include "session/tracker.h" +#include "types/selector.h" +#include "types/vs.h" + +#define INVALID_REAL_IDX ((uint32_t)-1) + +static uint32_t +ring_get(struct balancer_ring *ring, uint64_t index) { + if (ring->real_ids.size > 0) { + uint32_t pos = index % (ring->real_ids.size / sizeof(uint32_t)); + uint32_t val; + memcpy(&val, + big_array_get(&ring->real_ids, pos * sizeof(uint32_t)), + sizeof(val)); + return val; + } else { + return INVALID_REAL_IDX; + } +} + +static uint32_t +selector_select( + struct balancer_real_selector *selector, uint32_t worker, uint32_t hash +) { + size_t ring_id = + atomic_load_explicit(&selector->ring_id, memory_order_acquire); + struct balancer_ring *ring = &selector->rings[ring_id]; + + /* + * Here branch predictor works well, + * because use_rr is long-term variable. + * So, we dont use ternary operator here. + */ + uint64_t idx; + if (selector->use_rr) { + idx = selector->workers[worker].value++; + } else { + idx = hash; + } + + return ring_get(ring, idx); +} + +static void +update_session_state( + struct balancer_session_state *session_state, + struct balancer_real *real, + uint32_t worker_idx, + uint32_t now, + uint8_t session_timeout +) { + sessions_tracker_prolong_session( + ADDR_OF(&real->tracker_shards), + worker_idx, + session_state->last_packet_timestamp, + session_state->timeout, + now, + session_timeout + ); + session_state->last_packet_timestamp = now; + session_state->timeout = session_timeout; +} + +static void +create_session_state( + struct balancer_session_state *session_state, + struct balancer_real *real, + uint32_t worker_idx, + uint32_t now, + uint8_t session_timeout +) { + sessions_tracker_new_session( + ADDR_OF(&real->tracker_shards), worker_idx, now, session_timeout + ); + session_state->create_timestamp = now; + session_state->last_packet_timestamp = now; + session_state->real_stable_idx = real->stable_idx; + session_state->timeout = session_timeout; +} + +/* + * Try to reuse the real from an existing session. + * Returns the real if still valid and enabled, NULL otherwise. + */ +static struct balancer_real * +try_reuse_session_real( + struct worker_context *context, + struct l4_packet_context *pkt_ctx, + struct balancer_session_state *session_state +) { + struct balancer_vs *vs = pkt_ctx->matched_vs; + struct balancer_vs_stats *vs_stats = pkt_ctx->matched_vs_stats; + + uint32_t real_idx = + real_idx_from_stable_idx(session_state->real_stable_idx); + struct balancer_real *real = ADDR_OF(&vs->reals) + real_idx; + + /* + * Slot was reused for a different real since the session was created + * or the real is removed without reuse. + */ + if (unlikely( + session_state->real_stable_idx != real->stable_idx || + real_is_removed(real) + )) { + vs_stats->real_is_removed += 1; + return NULL; + } + + if (unlikely(!real_is_enabled(real))) { + struct balancer_real_stats *real_stats = real_get_stats( + real, context->worker_idx, context->counter_storage + ); + real_stats->packets_real_disabled += 1; + vs_stats->real_is_disabled += 1; + return NULL; + } + + return real; +} + +static struct balancer_real * +resolve_real_ops( + struct worker_context *context, + struct l4_packet_context *pkt_ctx, + struct balancer_vs *vs +) { + uint32_t real_idx = selector_select( + ADDR_OF(&vs->selector), + context->worker_idx, + pkt_ctx->packet->hash + ); + if (unlikely(real_idx == INVALID_REAL_IDX)) { + pkt_ctx->matched_vs_stats->no_reals += 1; + return NULL; + } + return ADDR_OF(&vs->reals) + real_idx; +} + +/* + * Select a new real from the ring and write it into the session slot. + * Returns the real, or NULL if the ring is empty. + */ +static struct balancer_real * +schedule_new_real( + struct worker_context *context, + struct l4_packet_context *pkt_ctx, + struct balancer_vs *vs, + struct balancer_session_state *session_state +) { + uint32_t real_id = selector_select( + ADDR_OF(&vs->selector), + context->worker_idx, + pkt_ctx->packet->hash + ); + if (unlikely(real_id == INVALID_REAL_IDX)) { + pkt_ctx->matched_vs_stats->no_reals += 1; + return NULL; + } + + struct balancer_real *real = ADDR_OF(&vs->reals) + real_id; + create_session_state( + session_state, + real, + context->worker_idx, + context->now, + pkt_ctx->session_timeout + ); + + struct balancer_real_stats *real_stats = real_get_stats( + real, context->worker_idx, context->counter_storage + ); + real_stats->created_sessions += 1; + pkt_ctx->matched_vs_stats->created_sessions += 1; + + return real; +} + +struct balancer_real * +resolve_real( + struct worker_context *context, + struct l4_packet_context *pkt_ctx, + struct balancer_session_table *st, + uint64_t current_table_gen +) { + struct balancer_vs *vs = pkt_ctx->matched_vs; + struct balancer_vs_stats *vs_stats = pkt_ctx->matched_vs_stats; + + /* OPS: stateless selection, no session involved. */ + if (vs->flags & balancer_vs_ops) { + return resolve_real_ops(context, pkt_ctx, vs); + } + + /* + * Acquire a session slot from the session table. + * + * st_get_or_create_session either finds an existing entry + * or allocates a new one. In both cases it returns a locked + * pointer to the session_state. The caller must eventually + * call st_unlock_session to release the lock. + * + * Possible results: + * - SESSION_FOUND: existing entry, session_state is populated. + * - SESSION_CREATED: new (or recycled) entry, session_state is blank. + * - SESSION_TABLE_OVERFLOW: table is full, no slot acquired, no lock + * held. + */ + struct balancer_session_state *session_state = NULL; + ttlmap_lock_t *session_lock; + int result = st_get_or_create_session( + st, + current_table_gen, + context->now, + pkt_ctx->session_timeout, + &pkt_ctx->session_id, + &session_state, + &session_lock + ); + + if (unlikely(result == SESSION_TABLE_OVERFLOW)) { + vs_stats->session_table_overflow += 1; + return NULL; + } + + /* + * From here on, session_lock is held and must be released + * before returning. session_state points to the locked slot. + */ + + /* Try to reuse the real from an existing session. */ + if (result == SESSION_FOUND) { + struct balancer_real *real = + try_reuse_session_real(context, pkt_ctx, session_state); + if (real != NULL) { + /* Real is still valid -- prolong the session. */ + update_session_state( + session_state, + real, + context->worker_idx, + context->now, + pkt_ctx->session_timeout + ); + st_unlock_session(session_lock); + return real; + } + /* + * Real is gone (disabled or removed). The session slot + * is still allocated and locked -- fall through to + * reschedule, which will overwrite it with a new real. + */ + } + + /* + * At this point we need to assign a new real to the session slot. + * This happens when: + * - SESSION_CREATED: no prior session existed (new flow). + * - SESSION_FOUND but real is stale (disabled/removed). + * + * Only connection-initiating packets may create or overwrite + * sessions. Non-initiating packets (TCP non-SYN) are dropped + * and the session slot is freed. + */ + if (unlikely(!pkt_ctx->can_reschedule)) { + vs_stats->not_rescheduled_packets += 1; + st_remove_session(session_state); + st_unlock_session(session_lock); + return NULL; + } + + /* Select a new real and write it into the session slot. */ + struct balancer_real *real = + schedule_new_real(context, pkt_ctx, vs, session_state); + if (unlikely(real == NULL)) { + st_remove_session(session_state); + } + + st_unlock_session(session_lock); + + return real; +} \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/resolve.h b/modules/balancer/dataplane/l4/resolve.h new file mode 100644 index 000000000..ca95f3d45 --- /dev/null +++ b/modules/balancer/dataplane/l4/resolve.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +struct worker_context; +struct l4_packet_context; +struct balancer_real; +struct balancer_session_table; + +/* + * Resolve the real server that should handle this packet. + * + * In OPS (one-packet-scheduling) mode, each packet is independently + * assigned to a real via the selector ring. No session is created + * or consulted. + * + * Otherwise, a session slot is acquired from the session table via + * st_get_or_create_session, which either finds an existing entry or + * allocates a new one. In both cases a locked pointer to the + * session_state is returned. The slot is used as follows: + * + * - Session found, real is valid and enabled: + * The session is prolonged (timestamps and timeout updated) + * and the same real is returned. + * + * - Session found, but the real was disabled or removed: + * The session is stale. If the packet is reschedulable + * (TCP SYN or any UDP), a new real is selected from the ring + * and the session slot is overwritten with the new real. + * Non-reschedulable packets (TCP non-SYN) are dropped and + * the session slot is freed. + * + * - No session found (new flow): + * A blank slot is allocated. If the packet is reschedulable, + * a new real is selected and written into the slot. + * Non-reschedulable packets are dropped and the slot is + * freed -- this handles stray TCP ACKs arriving after + * session timeout. + * + * The session lock is always released before returning. + * + * Returns the selected real, or NULL if the packet should be + * dropped. The caller is responsible for dropping NULL-result + * packets. + */ +struct balancer_real * +resolve_real( + struct worker_context *context, + struct l4_packet_context *pkt_ctx, + struct balancer_session_table *session_table, + uint64_t current_table_gen +); diff --git a/modules/balancer/dataplane/l4/select.h b/modules/balancer/dataplane/l4/select.h deleted file mode 100644 index 9dc8b209e..000000000 --- a/modules/balancer/dataplane/l4/select.h +++ /dev/null @@ -1,223 +0,0 @@ -#pragma once - -#include "active_sessions.h" -#include "common/memory_address.h" - -#include "handler/map.h" -#include "handler/vs.h" -#include "rte_tcp.h" -#include "selector.h" -#include "session_table.h" -#include -#include -#include -#include -#include - -#include "../flow/common.h" -#include "../flow/context.h" -#include "../flow/helpers.h" - -#include "state/session.h" -#include "state/session_table.h" - -#include "api/vs.h" - -//////////////////////////////////////////////////////////////////////////////// - -static inline bool -reschedule_real(uint8_t transport_proto, uint16_t tcp_flags) { - // True for UDP and TCP SYN packets - return (transport_proto == IPPROTO_UDP) || - (transport_proto == IPPROTO_TCP && - ((tcp_flags & (RTE_TCP_SYN_FLAG | RTE_TCP_RST_FLAG)) == - RTE_TCP_SYN_FLAG)); -} - -// Selects real and update real and virtual service stats. -static inline struct real * -select_real( - struct packet_ctx *ctx, - struct vs *vs, - struct session_table *table, - uint64_t current_table_gen -) { - struct packet_handler *handler = ctx->handler; - struct real *reals = ADDR_OF(&handler->reals); - - struct map *reals_index = &handler->reals_index; - - const size_t worker_idx = ctx->worker->idx; - const uint32_t now = ctx->now; - - // if `One Packet Scheduling` flag is set, - // we do not account for sessions - if (vs->flags & VS_OPS_FLAG) { - uint32_t local_real_id = selector_select( - &vs->selector, worker_idx, ctx->packet->hash - ); - if (local_real_id == SELECTOR_VALUE_INVALID) { - // discard packet because there are no enabled reals - - // update counter - VS_STATS_INC(no_reals, ctx); - - return NULL; - } - - uint32_t real_id = vs->first_real_idx + local_real_id; - - // select real - struct real *real = &reals[real_id]; - packet_ctx_set_real(ctx, real); - - // update stats - - // real stats - REAL_STATS_INC(ops_packets, ctx); - - // vs stats - VS_STATS_INC(ops_packets, ctx); - - return real; - } - - // get state for the session - struct session_state *session_state = NULL; - session_lock_t *session_lock; - int get_session_result = get_or_create_session( - table, - current_table_gen, - now, - ctx->session_timeout, - &ctx->session, - &session_state, - &session_lock - ); - - if (get_session_result == - SESSION_TABLE_OVERFLOW) { // session with such id is not present and - // there is no enough space in the session - // table to create new state, so error - // update virtual service stats - VS_STATS_INC(session_table_overflow, ctx); - - return NULL; - } - - if (get_session_result == SESSION_FOUND) { // session with such id found - // session_state->real_id contains the global registry index - size_t real_stable_idx = session_state->real_id; - uint64_t real_ph_idx; - int find_res = - map_find(reals_index, real_stable_idx, &real_ph_idx); - - if (find_res == -1) { - // session is for real which is not - // configured for the current packet handler. - - // increase stats, then try reschedule packet to the - // other real - VS_STATS_INC(real_is_removed, ctx); - } else if (!vs_real_enabled( - ctx->vs.ptr, real_ph_idx - )) { // check if real is - // disabled - // real is disabled - - struct real *real = &reals[real_ph_idx]; - - // select real to update its counters - packet_ctx_set_real(ctx, real); - - // increment stats - REAL_STATS_INC(packets_real_disabled, ctx); - VS_STATS_INC(real_is_disabled, ctx); - - // deselect real - packet_ctx_unset_real(ctx); - } else { - // real enabled and present in config, so we select it. - - struct real *real = &reals[real_ph_idx]; - - // set real in packet context - packet_ctx_set_real(ctx, real); - - // prolong session in the active sessions tracker - struct active_sessions_tracker_shard *tracker_shards = - ADDR_OF(&real->tracker_shards); - active_sessions_tracker_prolong_session( - tracker_shards, - worker_idx, - session_state->last_packet_timestamp, - session_state->timeout, - now, - ctx->session_timeout - ); - - // update session and unlock it - session_state->timeout = ctx->session_timeout; - session_state->last_packet_timestamp = now; - session_unlock(session_lock); - - // real is selected, just return it. - return real; - } - } - - // session not found or real is disabled - // but session inserted into table and - // we have pointer to session state with acquired lock. - - // now we need to select real for packet - - assert(session_state != NULL); - if (!reschedule_real( - ctx->transport_proto, ctx->tcp_flags - )) { // packet type not allows to create new session - VS_STATS_INC(not_rescheduled_packets, ctx); - session_remove(session_state); // free created state - session_unlock(session_lock); // unlock state - return NULL; - } - - // select new real for the session and remember it in session state - - uint32_t local_real_id = - selector_select(&vs->selector, worker_idx, ctx->packet->hash); - if (local_real_id == SELECTOR_VALUE_INVALID) { - VS_STATS_INC(no_reals, ctx); - session_remove(session_state); // free created state - session_unlock(session_lock); // unlock state - return NULL; - } - - uint32_t real_id = vs->first_real_idx + local_real_id; - - // real selected, new session is created - - // set real - struct real *real = &reals[real_id]; - packet_ctx_set_real(ctx, real); - - session_state->create_timestamp = now; - session_state->last_packet_timestamp = now; - session_state->real_id = real->stable_idx; - session_state->timeout = ctx->session_timeout; - - // register new session in the active sessions tracker - struct active_sessions_tracker_shard *tracker_shards = - ADDR_OF(&real->tracker_shards); - active_sessions_tracker_new_session( - tracker_shards, worker_idx, now, ctx->session_timeout - ); - - session_unlock(session_lock); - - // update stats - VS_STATS_INC(created_sessions, ctx); - REAL_STATS_INC(created_sessions, ctx); - - return real; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/l4/tunnel.c b/modules/balancer/dataplane/l4/tunnel.c new file mode 100644 index 000000000..1177745ec --- /dev/null +++ b/modules/balancer/dataplane/l4/tunnel.c @@ -0,0 +1,253 @@ +#include + +#include + +#include +#include +#include + +#include "common/checksum.h" +#include "common/network.h" + +#include "lib/dataplane/packet/data.h" +#include "lib/dataplane/packet/encap.h" + +#include "gre.h" +#include "mss.h" +#include "packet.h" +#include "tunnel.h" + +/* + * Embed client source IP into the outer IPv6 tunnel source address + * using SSE2 SIMD when the client address is a full 16-byte IPv6. + * + * The operation per byte is: + * outer_src[i] |= client_src[i] & ~mask[i] + * + * For 16-byte (IPv6) clients this maps to three 128-bit instructions: + * andnot, or, store. + * + * For 4-byte (IPv4) clients a scalar fallback is used. + */ +static inline void +embed_client_ip6( + struct rte_mbuf *mbuf, + uint16_t network_offset, + const uint8_t *client_src, + uint8_t client_len, + const struct net6 *src_net +) { + struct rte_ipv6_hdr *outer = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv6_hdr *, network_offset + ); + uint8_t *dst = outer->src_addr; + + if (client_len == NET6_LEN) { + __m128i v_dst = _mm_loadu_si128((__m128i *)dst); + __m128i v_cli = _mm_loadu_si128((const __m128i *)client_src); + __m128i v_mask = + _mm_loadu_si128((const __m128i *)src_net->mask); + /* v_cli & ~v_mask */ + __m128i v_bits = _mm_andnot_si128(v_mask, v_cli); + v_dst = _mm_or_si128(v_dst, v_bits); + _mm_storeu_si128((__m128i *)dst, v_dst); + } else { + for (uint8_t i = 0; i < client_len; i++) { + dst[i] |= client_src[i] & ~src_net->mask[i]; + } + } +} + +/* + * Embed client source IP into the outer IPv4 tunnel source address + * and recompute the IPv4 header checksum. + * + * Uses a single 32-bit word operation instead of a byte loop: + * src_addr = (client & ~mask) | net_addr + */ +static inline void +embed_client_ip4( + struct rte_mbuf *mbuf, + uint16_t network_offset, + const uint8_t *client_src, + const struct net4 *src_net +) { + struct rte_ipv4_hdr *outer = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv4_hdr *, network_offset + ); + + uint32_t client, mask, addr; + __builtin_memcpy(&client, client_src, NET4_LEN); + __builtin_memcpy(&mask, src_net->mask, NET4_LEN); + __builtin_memcpy(&addr, src_net->addr, NET4_LEN); + + uint32_t old_src = outer->src_addr; + outer->src_addr = (client & ~mask) | addr; + + uint16_t cksum = ~outer->hdr_checksum; + cksum = csum_minus(cksum, (uint16_t)old_src); + cksum = csum_minus(cksum, (uint16_t)(old_src >> 16)); + cksum = csum_plus(cksum, (uint16_t)outer->src_addr); + cksum = csum_plus(cksum, (uint16_t)(outer->src_addr >> 16)); + outer->hdr_checksum = (cksum == 0xffff) ? cksum : ~cksum; +} + +/* + * Encapsulate an inner-IPv4 packet in an IP tunnel and embed the + * client source address into the outer header. + */ +static void +encapsulate_ipv4(struct packet *packet, struct balancer_real *real) { + struct rte_mbuf *mbuf = packet_to_mbuf(packet); + bool is_outer_ipv6 = real->flags & balancer_real_ipv6; + + uint16_t inner_offset = packet->network_header.offset; + + if (is_outer_ipv6) { + packet_ip6_encap( + packet, real->addr.v6.bytes, real->src.v6.addr + ); + + uint16_t outer_size = sizeof(struct rte_ipv6_hdr); + struct rte_ipv4_hdr *inner = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv4_hdr *, inner_offset + outer_size + ); + const uint8_t *client_src = (const uint8_t *)&inner->src_addr; + + embed_client_ip6( + mbuf, + packet->network_header.offset, + client_src, + NET4_LEN, + &real->src.v6 + ); + } else { + packet_ip4_encap( + packet, real->addr.v4.bytes, real->src.v4.addr + ); + + uint16_t outer_size = sizeof(struct rte_ipv4_hdr); + struct rte_ipv4_hdr *inner = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv4_hdr *, inner_offset + outer_size + ); + const uint8_t *client_src = (const uint8_t *)&inner->src_addr; + + embed_client_ip4( + mbuf, + packet->network_header.offset, + client_src, + &real->src.v4 + ); + } +} + +/* + * Encapsulate an inner-IPv6 packet in an IP tunnel and embed the + * client source address into the outer header. + */ +static void +encapsulate_ipv6(struct packet *packet, struct balancer_real *real) { + struct rte_mbuf *mbuf = packet_to_mbuf(packet); + bool is_outer_ipv6 = real->flags & balancer_real_ipv6; + + uint16_t inner_offset = packet->network_header.offset; + + if (is_outer_ipv6) { + packet_ip6_encap( + packet, real->addr.v6.bytes, real->src.v6.addr + ); + + uint16_t outer_size = sizeof(struct rte_ipv6_hdr); + struct rte_ipv6_hdr *inner = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv6_hdr *, inner_offset + outer_size + ); + const uint8_t *client_src = (const uint8_t *)inner->src_addr; + + embed_client_ip6( + mbuf, + packet->network_header.offset, + client_src, + NET6_LEN, + &real->src.v6 + ); + } else { + packet_ip4_encap( + packet, real->addr.v4.bytes, real->src.v4.addr + ); + + uint16_t outer_size = sizeof(struct rte_ipv4_hdr); + struct rte_ipv6_hdr *inner = rte_pktmbuf_mtod_offset( + mbuf, struct rte_ipv6_hdr *, inner_offset + outer_size + ); + const uint8_t *client_src = (const uint8_t *)inner->src_addr; + + embed_client_ip4( + mbuf, + packet->network_header.offset, + client_src, + &real->src.v4 + ); + } +} + +/* + * Update per-VS and per-real forwarding counters. + */ +static inline void +update_tunnel_stats(struct l4_packet_context *pkt_ctx) { + struct balancer_vs_stats *vs_stats = pkt_ctx->matched_vs_stats; + struct balancer_real_stats *real_stats = pkt_ctx->resolved_real_stats; + uint64_t pkt_bytes = pkt_ctx->packet->mbuf->pkt_len; + + vs_stats->outgoing_packets += 1; + vs_stats->outgoing_bytes += pkt_bytes; + real_stats->packets += 1; + real_stats->bytes += pkt_bytes; +} + +/* + * Tunnel a packet whose inner layer is IPv4. + */ +void +tunnel_ipv4_packet(struct l4_packet_context *pkt_ctx) { + struct packet *packet = pkt_ctx->packet; + struct balancer_real *real = pkt_ctx->resolved_real; + uint16_t vs_flags = pkt_ctx->matched_vs->flags; + + /* No MSS clamping for inner IPv4. */ + + encapsulate_ipv4(packet, real); + + if (unlikely(vs_flags & balancer_vs_gre)) { + bool is_outer_ipv6 = real->flags & balancer_real_ipv6; + insert_gre_header(packet, is_outer_ipv6, true); + } + + update_tunnel_stats(pkt_ctx); +} + +/* + * Tunnel a packet whose inner layer is IPv6. + * + * MSS clamping is applied when the VS has balancer_vs_fix_mss set. + */ +void +tunnel_ipv6_packet(struct l4_packet_context *pkt_ctx) { + struct packet *packet = pkt_ctx->packet; + struct balancer_real *real = pkt_ctx->resolved_real; + uint16_t vs_flags = pkt_ctx->matched_vs->flags; + + /* Clamp MSS for IPv6 SYN packets before encapsulation. */ + if (unlikely(vs_flags & balancer_vs_fix_mss)) { + fix_mss_ipv6(packet); + } + + encapsulate_ipv6(packet, real); + + if (unlikely(vs_flags & balancer_vs_gre)) { + bool is_outer_ipv6 = real->flags & balancer_real_ipv6; + insert_gre_header(packet, is_outer_ipv6, false); + } + + update_tunnel_stats(pkt_ctx); +} diff --git a/modules/balancer/dataplane/l4/tunnel.h b/modules/balancer/dataplane/l4/tunnel.h new file mode 100644 index 000000000..2219c3c76 --- /dev/null +++ b/modules/balancer/dataplane/l4/tunnel.h @@ -0,0 +1,27 @@ +#pragma once + +struct l4_packet_context; + +/* + * Tunnel a packet whose inner layer is IPv4. + * + * Encapsulates the packet in an IP-in-IP tunnel towards the resolved + * real server, embeds the client source IP into the outer header, + * optionally wraps in GRE, and updates forwarding statistics. + * + * MSS clamping is not performed — it only applies to inner IPv6. + */ +void +tunnel_ipv4_packet(struct l4_packet_context *pkt_ctx); + +/* + * Tunnel a packet whose inner layer is IPv6. + * + * Encapsulates the packet in an IP-in-IP tunnel towards the resolved + * real server, embeds the client source IP into the outer header, + * optionally wraps in GRE, and updates forwarding statistics. + * + * MSS clamping is applied when the VS has balancer_vs_fix_mss set. + */ +void +tunnel_ipv6_packet(struct l4_packet_context *pkt_ctx); diff --git a/modules/balancer/dataplane/lookup.h b/modules/balancer/dataplane/lookup.h deleted file mode 100644 index 064e41124..000000000 --- a/modules/balancer/dataplane/lookup.h +++ /dev/null @@ -1,186 +0,0 @@ -#pragma once - -#include "common/lpm.h" -#include "common/memory_address.h" -#include "common/network.h" - -#include "counters/counters.h" -#include "flow/helpers.h" -#include "lib/dataplane/packet/packet.h" - -#include - -#include - -#include -#include - -#include "handler/handler.h" - -#include "flow/common.h" -#include "flow/context.h" - -//////////////////////////////////////////////////////////////////////////////// - -FILTER_QUERY_DECLARE( - vs_lookup_ipv4, net4_fast_dst, port_fast_dst, proto_range_fast -); -FILTER_QUERY_DECLARE(vs_acl_ipv4, net4_fast_src, port_fast_src); - -static inline uint32_t -vs_v4_table_lookup(struct packet_handler *handler, struct packet *packet) { - uint32_t result; - filter_query( - ADDR_OF(&handler->vs_ipv4.filter), - vs_lookup_ipv4, - &packet, - &result, - 1 - ); - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -FILTER_QUERY_DECLARE( - vs_lookup_ipv6, net6_fast_dst, port_fast_dst, proto_range_fast -); -FILTER_QUERY_DECLARE(vs_acl_ipv6, net6_fast_src, port_fast_src); - -static inline uint32_t -vs_v6_table_lookup(struct packet_handler *handler, struct packet *packet) { - uint32_t result; - filter_query( - ADDR_OF(&handler->vs_ipv6.filter), - vs_lookup_ipv6, - &packet, - &result, - 1 - ); - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline struct vs * -vs_v4_lookup(struct packet_ctx *ctx) { - struct packet_handler *handler = ctx->handler; - // get id of the virtual service - uint32_t service_id = vs_v4_table_lookup(handler, ctx->packet); - if (service_id == (uint32_t)-1) { - return NULL; - } - struct vs *vs = ADDR_OF(&handler->vs_ipv4.vs) + service_id; - - // set virtual service - packet_ctx_set_vs(ctx, vs); - - return vs; -} - -static inline bool -check_fw_and_inc_stats(struct packet_ctx *ctx, struct vs *vs, uint32_t result) { - if (result != FILTER_RULE_INVALID) { - uint32_t rule_idx = result; - uint64_t counter_id = ADDR_OF(&vs->rule_counters)[rule_idx]; - if (counter_id != (uint64_t)-1) { - counter_get_address( - counter_id, ctx->worker_idx, ctx->stats.storage - )[0] += 1; - } - return true; - } - return false; -} - -static inline bool -vs_v4_fw(struct packet_ctx *ctx, struct vs *vs, struct packet *packet) { - (void)ctx; - uint32_t result; - filter_query(ADDR_OF(&vs->acl), vs_acl_ipv4, &packet, &result, 1); - return check_fw_and_inc_stats(ctx, vs, result); -} - -static inline bool -vs_v4_announced(struct packet_ctx *ctx) { - struct packet_handler *handler = ctx->handler; - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - struct rte_ipv4_hdr *ipv4_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv4_hdr *, packet->network_header.offset - ); - return lpm_lookup( - &handler->vs_ipv4.announce, - NET4_LEN, - (uint8_t *)&ipv4_hdr->dst_addr - ) != LPM_VALUE_INVALID; -} - -static inline bool -vs_v6_announced(struct packet_ctx *ctx) { - struct packet_handler *handler = ctx->handler; - struct packet *packet = ctx->packet; - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - struct rte_ipv6_hdr *ipv6_hdr = rte_pktmbuf_mtod_offset( - mbuf, struct rte_ipv6_hdr *, packet->network_header.offset - ); - return lpm_lookup( - &handler->vs_ipv6.announce, - NET6_LEN, - (uint8_t *)&ipv6_hdr->dst_addr - ) != LPM_VALUE_INVALID; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline struct vs * -vs_v6_lookup(struct packet_ctx *ctx) { - struct packet_handler *handler = ctx->handler; - uint32_t service_id = vs_v6_table_lookup(handler, ctx->packet); - if (service_id == (uint32_t)-1) { - return NULL; - } - struct vs *vs = ADDR_OF(&handler->vs_ipv6.vs) + service_id; - - // set virtual service - packet_ctx_set_vs(ctx, vs); - - return vs; -} - -static inline bool -vs_v6_fw(struct packet_ctx *ctx, struct vs *vs, struct packet *packet) { - (void)ctx; - uint32_t result; - filter_query(ADDR_OF(&vs->acl), vs_acl_ipv6, &packet, &result, 1); - return check_fw_and_inc_stats(ctx, vs, result); -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline struct vs * -vs_lookup_and_fw(struct packet_ctx *ctx) { - struct packet *packet = ctx->packet; - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - struct vs *vs = vs_v4_lookup(ctx); - if (vs == NULL) { - return NULL; - } - if (!vs_v4_fw(ctx, vs, packet)) { - packet_ctx_vs_stats(ctx)->packet_src_not_allowed += 1; - return NULL; - } - return vs; - } else { // ipv6 - struct vs *vs = vs_v6_lookup(ctx); - if (vs == NULL) { - return NULL; - } - if (!vs_v6_fw(ctx, vs, packet)) { - packet_ctx_vs_stats(ctx)->packet_src_not_allowed += 1; - return NULL; - } - return vs; - } -} diff --git a/modules/balancer/dataplane/meson.build b/modules/balancer/dataplane/meson.build index c9ab5aacb..ea2a57d99 100644 --- a/modules/balancer/dataplane/meson.build +++ b/modules/balancer/dataplane/meson.build @@ -6,10 +6,17 @@ dp_dependencies = [ lib_filter_query_dep, ] -includes = include_directories('.', '../controlplane') +includes = include_directories('.', '../../../') dp_sources = files( 'dataplane.c', + 'l4/handle.c', + 'l4/resolve.c', + 'l4/tunnel.c', + 'l4/mss.c', + 'l4/gre.c', + 'l4/group.c', + 'icmp/handle.c', ) lib_balancer_dp = static_library( diff --git a/modules/balancer/dataplane/meta.h b/modules/balancer/dataplane/meta.h deleted file mode 100644 index db93cf220..000000000 --- a/modules/balancer/dataplane/meta.h +++ /dev/null @@ -1,237 +0,0 @@ -#pragma once - -#include "common/network.h" -#include "dataplane/packet/data.h" -#include "dataplane/packet/packet.h" -#include "rte_byteorder.h" -#include "rte_ether.h" -#include -#include - -#include -#include -#include -#include -#include - -#include "handler/handler.h" -#include "handler/vs.h" -#include "state/session.h" - -//////////////////////////////////////////////////////////////////////////////// - -struct packet_metadata { - uint8_t network_proto; - uint8_t transport_proto; - - uint8_t *src_addr; - uint8_t *dst_addr; - uint16_t src_port; - uint16_t dst_port; - - uint8_t tcp_flags; -}; - -struct packet_metadata_copy { - struct packet_metadata meta; - uint8_t src_addr[NET6_LEN]; - uint8_t dst_addr[NET6_LEN]; -}; - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -fill_packet_metadata_ipv4( - struct rte_ipv4_hdr *ip_hdr, struct packet_metadata *metadata -) { - metadata->network_proto = IPPROTO_IP; - metadata->dst_addr = (uint8_t *)&ip_hdr->dst_addr; - metadata->src_addr = (uint8_t *)&ip_hdr->src_addr; -} - -static inline void -fill_packet_metadata_copy_ipv4( - struct rte_ipv4_hdr *ip_hdr, struct packet_metadata_copy *copy -) { - copy->meta.network_proto = IPPROTO_IP; - memcpy(copy->dst_addr, &ip_hdr->dst_addr, NET4_LEN); - memcpy(copy->src_addr, &ip_hdr->src_addr, NET4_LEN); - copy->meta.dst_addr = copy->dst_addr; - copy->meta.src_addr = copy->src_addr; -} - -static inline void -fill_packet_metadata_ipv6( - struct rte_ipv6_hdr *ip_hdr, struct packet_metadata *metadata -) { - metadata->network_proto = IPPROTO_IPV6; - metadata->dst_addr = ip_hdr->dst_addr; - metadata->src_addr = ip_hdr->src_addr; -} - -static inline void -fill_packet_metadata_copy_ipv6( - struct rte_ipv6_hdr *ip_hdr, struct packet_metadata_copy *copy -) { - copy->meta.network_proto = IPPROTO_IPV6; - memcpy(copy->dst_addr, ip_hdr->dst_addr, NET6_LEN); - memcpy(copy->src_addr, ip_hdr->src_addr, NET6_LEN); - copy->meta.dst_addr = copy->dst_addr; - copy->meta.src_addr = copy->src_addr; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -fill_packet_metadata_tcp( - struct rte_tcp_hdr *tcp_header, struct packet_metadata *metadata -) { - metadata->transport_proto = IPPROTO_TCP; - metadata->dst_port = tcp_header->dst_port; - metadata->src_port = tcp_header->src_port; - metadata->tcp_flags = tcp_header->tcp_flags; -} - -static inline void -fill_packet_metadata_udp( - struct rte_udp_hdr *udp_header, struct packet_metadata *metadata -) { - metadata->transport_proto = IPPROTO_UDP; - metadata->dst_port = udp_header->dst_port; - metadata->src_port = udp_header->src_port; - metadata->tcp_flags = 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline int -fill_packet_metadata(struct packet *packet, struct packet_metadata *metadata) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - struct rte_ipv4_hdr *ipv4_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - fill_packet_metadata_ipv4(ipv4_header, metadata); - } else if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6)) { - struct rte_ipv6_hdr *ipv6_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - fill_packet_metadata_ipv6(ipv6_header, metadata); - } else { // unsupported - return -1; - } - - if (packet->transport_header.type == IPPROTO_TCP) { - struct rte_tcp_hdr *tcp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_tcp_hdr *, - packet->transport_header.offset - ); - fill_packet_metadata_tcp(tcp_header, metadata); - } else if (packet->transport_header.type == IPPROTO_UDP) { - struct rte_udp_hdr *udp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_udp_hdr *, - packet->transport_header.offset - ); - fill_packet_metadata_udp(udp_header, metadata); - } else { // unsupported - return -1; - } - - return 0; -} - -static inline int -fill_packet_metadata_copy( - struct packet *packet, struct packet_metadata_copy *copy -) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4)) { - struct rte_ipv4_hdr *ipv4_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - fill_packet_metadata_copy_ipv4(ipv4_header, copy); - } else if (packet->network_header.type == - rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV6)) { - struct rte_ipv6_hdr *ipv6_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - fill_packet_metadata_copy_ipv6(ipv6_header, copy); - } else { // unsupported - return -1; - } - - if (packet->transport_header.type == IPPROTO_TCP) { - struct rte_tcp_hdr *tcp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_tcp_hdr *, - packet->transport_header.offset - ); - fill_packet_metadata_tcp(tcp_header, ©->meta); - } else if (packet->transport_header.type == IPPROTO_UDP) { - struct rte_udp_hdr *udp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_udp_hdr *, - packet->transport_header.offset - ); - fill_packet_metadata_udp(udp_header, ©->meta); - } else { // unsupported - return -1; - } - - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline uint32_t -session_timeout( - struct sessions_timeouts *timeouts, struct packet_metadata *metadata -) { - if (metadata->transport_proto == IPPROTO_UDP) { - return timeouts->udp; - } - if (metadata->transport_proto != IPPROTO_TCP) { - return timeouts->def; - } - - if ((metadata->tcp_flags & RTE_TCP_SYN_FLAG) == RTE_TCP_SYN_FLAG) { - if ((metadata->tcp_flags & RTE_TCP_ACK_FLAG) == - RTE_TCP_ACK_FLAG) { - return timeouts->tcp_syn_ack; - } - return timeouts->tcp_syn; - } - if (metadata->tcp_flags & RTE_TCP_FIN_FLAG) { - return timeouts->tcp_fin; - } - return timeouts->tcp; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -fill_session_id( - struct session_id *id, struct packet_metadata *data, struct vs *vs -) { - memset(id, 0, sizeof(*id)); - memcpy(&id->client_ip, - data->src_addr, - data->network_proto == IPPROTO_IPV6 ? NET6_LEN : NET4_LEN); - id->client_port = data->src_port; - id->vs_id = vs->stable_idx; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/mss.h b/modules/balancer/dataplane/mss.h deleted file mode 100644 index 45efc1949..000000000 --- a/modules/balancer/dataplane/mss.h +++ /dev/null @@ -1,163 +0,0 @@ -#pragma once - -#include "lib/dataplane/packet/packet.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "checksum.h" - -//////////////////////////////////////////////////////////////////////////////// - -#define TCP_OPTION_MSS_LEN (4) -#define TCP_OPTION_KIND_MSS (2) -#define TCP_OPTION_KIND_EOL (0) -#define TCP_OPTION_KIND_NOP (1) - -#define DEFAULT_MSS_SIZE 536 -#define FIX_MSS_SIZE 1220 - -//////////////////////////////////////////////////////////////////////////////// - -struct tcp_option { - uint8_t kind; - uint8_t len; - char data[0]; -} __attribute__((__packed__)); - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -fix_mss_ipv6(struct packet *packet) { - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - if (packet->transport_header.type == IPPROTO_TCP) { - struct rte_tcp_hdr *tcp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_tcp_hdr *, - packet->transport_header.offset - ); - - if ((tcp_header->tcp_flags & - (RTE_TCP_SYN_FLAG | RTE_TCP_RST_FLAG)) != - RTE_TCP_SYN_FLAG) { - return; - } - - uint16_t tcp_data_offset = (tcp_header->data_off >> 4) * 4; - if (tcp_data_offset < sizeof(struct rte_tcp_hdr) || - packet->transport_header.offset + tcp_data_offset > - rte_pktmbuf_pkt_len(mbuf)) { - // Data offset is out of bounds of the packet, nothing - // to do here - return; - } - - // Option lookup - uint16_t tcp_option_offset = sizeof(struct rte_tcp_hdr); - while (tcp_option_offset + TCP_OPTION_MSS_LEN <= tcp_data_offset - ) { - const struct tcp_option *option = - rte_pktmbuf_mtod_offset( - mbuf, - struct tcp_option *, - packet->transport_header.offset + - tcp_option_offset - ); - - if (option->kind == TCP_OPTION_KIND_MSS) { - /// mss could not be increased so check the - /// value first - uint16_t old_mss = rte_be_to_cpu_16( - *(uint16_t *)option->data - ); - if (old_mss <= FIX_MSS_SIZE) { - return; - } - uint16_t cksum = ~tcp_header->cksum; - cksum = csum_minus( - cksum, *(uint16_t *)option->data - ); - *(uint16_t *)option->data = - rte_cpu_to_be_16(FIX_MSS_SIZE); - cksum = csum_plus( - cksum, *(uint16_t *)option->data - ); - tcp_header->cksum = - (cksum == 0xffff) ? cksum : ~cksum; - return; - } else if (option->kind == TCP_OPTION_KIND_EOL || - option->kind == TCP_OPTION_KIND_NOP) { - tcp_option_offset++; - } else { - if (option->len == 0) { - /// packet header is broken - return; - } - tcp_option_offset += option->len; - } - } - - /// try to insert option - if (tcp_data_offset > (0x0f << 2) - TCP_OPTION_MSS_LEN) { - /// no space to insert the option - return; - } - - /// insert option just after regular tcp header - rte_pktmbuf_prepend(mbuf, TCP_OPTION_MSS_LEN); - memmove(rte_pktmbuf_mtod(mbuf, char *), - rte_pktmbuf_mtod_offset( - mbuf, char *, TCP_OPTION_MSS_LEN - ), - packet->transport_header.offset + - sizeof(struct rte_tcp_hdr)); - struct tcp_option *option = rte_pktmbuf_mtod_offset( - mbuf, - struct tcp_option *, - packet->transport_header.offset + - sizeof(struct rte_tcp_hdr) - ); - option->kind = TCP_OPTION_KIND_MSS; - option->len = TCP_OPTION_MSS_LEN; - *(uint16_t *)option->data = rte_cpu_to_be_16(DEFAULT_MSS_SIZE); - - /// adjust tcp and ip lengths and update checksums - tcp_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_tcp_hdr *, - packet->transport_header.offset - ); - tcp_header->data_off += 0x1 << 4; - uint16_t cksum = ~tcp_header->cksum; - /// data_off is the leading byte of corresponding 2-byte - /// sequence inside a tcp header so there is no rte_cpu_to_be_16 - cksum = csum_plus(cksum, 0x1 << 4); - cksum = csum_plus(cksum, *(uint16_t *)option); - cksum = csum_plus(cksum, *(uint16_t *)option->data); - cksum = csum_plus(cksum, rte_cpu_to_be_16(TCP_OPTION_MSS_LEN)); - tcp_header->cksum = (cksum == 0xffff) ? cksum : ~cksum; - - struct rte_ipv6_hdr *ipv6_header = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - ipv6_header->payload_len = rte_cpu_to_be_16( - rte_be_to_cpu_16(ipv6_header->payload_len) + - TCP_OPTION_MSS_LEN - ); - } -} - -#undef TCP_OPTION_MSS_LEN -#undef TCP_OPTION_KIND_MSS -#undef TCP_OPTION_KIND_EOL -#undef TCP_OPTION_KIND_NOP - -#undef DEFAULT_MSS_SIZE -#undef FIX_MSS_SIZE \ No newline at end of file diff --git a/modules/balancer/dataplane/packet.h b/modules/balancer/dataplane/packet.h new file mode 100644 index 000000000..ff626bf6a --- /dev/null +++ b/modules/balancer/dataplane/packet.h @@ -0,0 +1,17 @@ +#pragma once + +#include "types/real.h" +#include "types/session.h" +#include "types/vs.h" + +struct l4_packet_context { + struct packet *packet; + struct balancer_vs *matched_vs; + struct balancer_vs_stats *matched_vs_stats; + struct balancer_real *resolved_real; + struct balancer_real_stats *resolved_real_stats; + struct balancer_session_id session_id; + uint8_t session_timeout; + bool can_reschedule; + bool is_dropped; +}; diff --git a/modules/balancer/dataplane/real.h b/modules/balancer/dataplane/real.h deleted file mode 100644 index 204bb50c7..000000000 --- a/modules/balancer/dataplane/real.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "counters/counters.h" -#include - -#include "handler/real.h" - -static inline struct real_stats * -real_counter( - struct real *real, size_t worker, struct counter_storage *storage -) { - uint64_t *counter = - counter_get_address(real->counter_id, worker, storage); - return (struct real_stats *)counter; -} diff --git a/modules/balancer/dataplane/real_helpers.h b/modules/balancer/dataplane/real_helpers.h new file mode 100644 index 000000000..0c63387b0 --- /dev/null +++ b/modules/balancer/dataplane/real_helpers.h @@ -0,0 +1,32 @@ +#pragma once + +#include "lib/counters/counters.h" + +#include "types/real.h" + +static inline struct balancer_real_stats * +real_get_stats( + struct balancer_real *real, + uint32_t worker, + struct counter_storage *counter_storage +) { + return (struct balancer_real_stats *)counter_get_address( + real->counter_id, worker, counter_storage + ); +} + +/* Extract config index (lower 32 bits) from a real's stable_idx. */ +static inline uint32_t +real_idx_from_stable_idx(uint64_t stable_idx) { + return (uint32_t)stable_idx; +} + +static inline bool +real_is_enabled(struct balancer_real *real) { + return real->flags & balancer_real_enabled; +} + +static inline bool +real_is_removed(struct balancer_real *real) { + return real->flags & balancer_real_removed; +} \ No newline at end of file diff --git a/modules/balancer/dataplane/selector.h b/modules/balancer/dataplane/selector.h deleted file mode 100644 index d26756239..000000000 --- a/modules/balancer/dataplane/selector.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#include "common/memory_address.h" - -#include "handler/selector.h" - -#include -#include - -#define SELECTOR_VALUE_INVALID ((uint32_t)-1) - -// Selects a real server based on passed index. -static inline uint32_t -ring_get(struct ring *ring, uint64_t index) { - if (ring->len > 0) { - uint32_t idx = index % ring->len; - return *(ADDR_OF(&ring->ids) + idx); - } else { - return SELECTOR_VALUE_INVALID; - } -} - -static inline uint32_t -selector_select(struct real_selector *selector, size_t worker, uint32_t hash) { - size_t ring_id = - RCU_READ_BEGIN(&selector->rcu, worker, &selector->ring_id); - struct ring *ring = &selector->rings[ring_id]; - size_t idx = selector->use_rr ? selector->workers[worker].rr_counter++ - : hash; - uint32_t res = ring_get(ring, idx); - RCU_READ_END(&selector->rcu, worker); - return res; -} \ No newline at end of file diff --git a/modules/balancer/dataplane/session/interval_counter.h b/modules/balancer/dataplane/session/interval_counter.h new file mode 100644 index 000000000..182777871 --- /dev/null +++ b/modules/balancer/dataplane/session/interval_counter.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include + +#include "common/likely.h" + +#include "../types/interval_counter.h" + +/* Reset the whole ring when all slots are older than the current time. */ +static inline int64_t +balancer_ic_try_reset(struct balancer_interval_counter *counter, uint32_t now) { + int64_t sum = 0; + if (unlikely(now - counter->last_timestamp >= BALANCER_IC_RING_SIZE)) { + /* + * The entire ring is stale. Sum all remaining deltas so + * the caller's running count stays consistent, then clear. + */ + for (size_t i = 0; i < BALANCER_IC_RING_SIZE; ++i) { + sum += counter->diff[i]; + } + memset(counter->diff, 0, BALANCER_IC_RING_SIZE * sizeof(int32_t) + ); + counter->last_timestamp = now; + } + return sum; +} + +/* Expire slots up to `now` and return the net change for the running total. */ +static inline int64_t +balancer_ic_advance(struct balancer_interval_counter *counter, uint32_t now) { + int64_t change = 0; + + /* Sweep past slots: [last_timestamp, now) */ + while (unlikely(counter->last_timestamp < now)) { + uint32_t idx = counter->last_timestamp & BALANCER_IC_RING_MASK; + counter->last_timestamp++; + change += counter->diff[idx]; + counter->diff[idx] = 0; + } + + /* + * Consume the current slot (now). Any +1/-1 written by the + * caller for this timestamp is picked up here and the slot is + * cleared so subsequent calls at the same `now` start fresh. + */ + uint32_t idx = counter->last_timestamp & BALANCER_IC_RING_MASK; + change += counter->diff[idx]; + counter->diff[idx] = 0; + return change; +} + +/* Start a new interval `[now, until)` and return the change visible at `now`. + */ +static inline int64_t +balancer_ic_make( + struct balancer_interval_counter *counter, uint32_t now, uint32_t until +) { + assert(until - now < BALANCER_IC_RING_SIZE); + + int64_t change = balancer_ic_try_reset(counter, now); + + counter->diff[now & BALANCER_IC_RING_MASK] += 1; + counter->diff[until & BALANCER_IC_RING_MASK] -= 1; + + return change + balancer_ic_advance(counter, now); +} + +/* Move an existing interval end from `prev_until` to `new_until`. */ +static inline int64_t +balancer_ic_prolong( + struct balancer_interval_counter *counter, + uint32_t now, + uint32_t prev_until, + uint32_t new_until +) { + assert(prev_until >= now); + assert(new_until - now < BALANCER_IC_RING_SIZE); + + int64_t change = balancer_ic_try_reset(counter, now); + + counter->diff[prev_until & BALANCER_IC_RING_MASK] += 1; + counter->diff[new_until & BALANCER_IC_RING_MASK] -= 1; + + return change + balancer_ic_advance(counter, now); +} \ No newline at end of file diff --git a/modules/balancer/dataplane/session/table.h b/modules/balancer/dataplane/session/table.h new file mode 100644 index 000000000..cb5ca7621 --- /dev/null +++ b/modules/balancer/dataplane/session/table.h @@ -0,0 +1,139 @@ +#pragma once + +#include "common/ttlmap/detail/lock.h" +#include "common/ttlmap/ttlmap.h" + +#include "types/session.h" + +/* + * Session table operations for the dataplane. + * + * All session access must happen within a critical section + * (st_begin_cs / st_end_cs). The critical section pins the + * current generation, ensuring the controlplane does not + * free a map while workers are still reading it. + */ + +#define SESSION_FOUND TTLMAP_FOUND +#define SESSION_CREATED 2 +#define SESSION_TABLE_OVERFLOW TTLMAP_FAILED + +/* + * Enter a session table critical section. + * Returns the current generation, which must be passed to + * all subsequent st_* calls within this critical section. + */ +static inline uint64_t +st_begin_cs(struct balancer_session_table *st, uint32_t worker) { + return RCU_READ_BEGIN(&st->rcu, worker, &st->current_gen); +} + +/* + * Leave a session table critical section. + * After this call, the controlplane may free the map that + * was pinned by st_begin_cs. + */ +static inline void +st_end_cs(struct balancer_session_table *st, uint32_t worker) { + RCU_READ_END(&st->rcu, worker); +} + +static inline void +st_prefetch_session( + struct balancer_session_table *st, + uint64_t current_table_gen, + struct balancer_session_id *session_id +) { + struct ttlmap *map = balancer_st_cur_map(st, current_table_gen); + TTLMAP_PREFETCH(map, session_id, struct balancer_session_state, 1, 0); +} + +/* + * Whether the previous map should be consulted. + * True during transition generations (odd), when sessions may + * still reside in the old map pending migration. + */ +static inline int +st_prev_map_used(uint32_t table_gen) { + return table_gen & 1; +} + +/* + * Look up or create a session entry. + * + * First tries the current map. If the session is found, returns + * SESSION_FOUND with session_state pointing to the existing entry. + * + * If the session is not found, a new slot is allocated in the + * current map. During transition generations (odd), the previous + * map is also checked -- if the session exists there, its state + * is copied into the new slot and SESSION_FOUND is returned. + * This handles sessions that have not yet been migrated. + * + * If neither map has the session, returns SESSION_CREATED with + * a blank slot allocated in the current map. + * + * Returns SESSION_TABLE_OVERFLOW if the current map is full and + * no slot could be allocated. In this case no lock is held. + * + * On SESSION_FOUND or SESSION_CREATED, the returned session_state + * is locked via *lock. The caller must call st_unlock_session + * after modifying the state. + */ +static inline int +st_get_or_create_session( + struct balancer_session_table *st, + uint64_t current_table_gen, + uint32_t now, + uint32_t timeout, + struct balancer_session_id *session_id, + struct balancer_session_state **session_state, + ttlmap_lock_t **lock +) { + struct ttlmap *map = balancer_st_cur_map(st, current_table_gen); + + int res = + TTLMAP_GET(map, session_id, session_state, lock, now, timeout); + int status = TTLMAP_STATUS(res); + + int result_status; + if (status == TTLMAP_FOUND) { + result_status = SESSION_FOUND; + } else if (status == TTLMAP_INSERTED || status == TTLMAP_REPLACED) { + if (!st_prev_map_used(current_table_gen)) { + result_status = SESSION_CREATED; + } else { + /* + * New slot allocated in current map. During transition + * (odd gen), check the previous map for a session that + * has not been migrated yet. + */ + struct ttlmap *prev_map = + balancer_st_prev_map(st, current_table_gen); + int lookup_res = TTLMAP_LOOKUP( + prev_map, session_id, *session_state, now + ); + if (TTLMAP_STATUS(lookup_res) == TTLMAP_FAILED) { + result_status = SESSION_CREATED; + } else { + result_status = SESSION_FOUND; + } + } + } else { // status == TTLMAP_FAILED + result_status = SESSION_TABLE_OVERFLOW; + } + + return result_status; +} + +/* Mark a session slot as empty. Must be called while the lock is held. */ +static inline void +st_remove_session(struct balancer_session_state *session_state) { + TTLMAP_REMOVE(struct balancer_session_id, session_state); +} + +/* Release the lock acquired by st_get_or_create_session. */ +static inline void +st_unlock_session(ttlmap_lock_t *lock) { + ttlmap_release_lock(lock); +} \ No newline at end of file diff --git a/modules/balancer/dataplane/session/tracker.h b/modules/balancer/dataplane/session/tracker.h new file mode 100644 index 000000000..948f07c95 --- /dev/null +++ b/modules/balancer/dataplane/session/tracker.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include "../types/sessions_tracker.h" +#include "interval_counter.h" + +/* Convert a packet timestamp to the current tracker tick. */ +static inline uint32_t +sessions_tracker_now(uint32_t timestamp) { + return timestamp / BALANCER_SESSIONS_TRACKER_PRECISION; +} + +/* Round a packet timestamp up to the tick where the session expires. */ +static inline uint32_t +sessions_tracker_until(uint32_t timestamp) { + return (timestamp + BALANCER_SESSIONS_TRACKER_PRECISION - 1) / + BALANCER_SESSIONS_TRACKER_PRECISION; +} + +/* Account for a newly created session on the selected worker shard. */ +static inline void +sessions_tracker_new_session( + struct balancer_sessions_tracker_shard *tracker_shards, + uint32_t worker_idx, + uint32_t now, + uint32_t timeout +) { + struct balancer_sessions_tracker_shard *shard = + &tracker_shards[worker_idx]; + shard->count += balancer_ic_make( + &shard->counter, + sessions_tracker_now(now), + sessions_tracker_until(now + timeout) + ); + shard->last_timestamp = now; +} + +/* Extend an existing session and move its scheduled expiration. */ +static inline void +sessions_tracker_prolong_session( + struct balancer_sessions_tracker_shard *tracker_shards, + uint32_t worker_idx, + uint32_t last_packet_timestamp, + uint32_t prev_timeout, + uint32_t now, + uint32_t new_timeout +) { + struct balancer_sessions_tracker_shard *shard = + &tracker_shards[worker_idx]; + shard->count += balancer_ic_prolong( + &shard->counter, + sessions_tracker_now(now), + sessions_tracker_until(last_packet_timestamp + prev_timeout), + sessions_tracker_until(now + new_timeout) + ); + shard->last_timestamp = now; +} diff --git a/modules/balancer/dataplane/session_table.h b/modules/balancer/dataplane/session_table.h deleted file mode 100644 index 5a65514d1..000000000 --- a/modules/balancer/dataplane/session_table.h +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include "state/session.h" -#include "state/session_table.h" - -#include "common/ttlmap/ttlmap.h" - -static inline uint64_t -session_table_begin_cs(struct session_table *session_table, uint32_t worker) { - return RCU_READ_BEGIN( - &session_table->rcu, worker, &session_table->current_gen - ); -} - -static inline void -session_table_end_cs(struct session_table *table, uint32_t worker) { - RCU_READ_END(&table->rcu, worker); -} - -static inline int -worker_use_prev_map(uint32_t table_gen) { - return table_gen & 1; -} - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -prefetch_session( - struct session_table *table, - uint64_t current_table_gen, - struct session_id *session_id -) { - struct ttlmap *map = session_table_map(table, current_table_gen); - TTLMAP_PREFETCH(map, session_id, struct session_state, 1, 0); -} - -static inline int -get_or_create_session( - struct session_table *session_table, - uint64_t current_table_gen, - uint32_t now, - uint32_t timeout, - struct session_id *session_id, - struct session_state **session_state, - session_lock_t **lock -) { - // Get ttlmap - struct ttlmap *map = - session_table_map(session_table, current_table_gen); - - int res = - TTLMAP_GET(map, session_id, session_state, lock, now, timeout); - int status = TTLMAP_STATUS(res); - - int result_status; - if (status == TTLMAP_FOUND) { - result_status = SESSION_FOUND; - } else if (status == TTLMAP_INSERTED || status == TTLMAP_REPLACED) { - if (worker_use_prev_map(current_table_gen - )) { // if worker in this gen should use prev map - struct ttlmap *prev_map = session_table_prev_map( - session_table, current_table_gen - ); - int lookup_res = TTLMAP_LOOKUP( - prev_map, session_id, *session_state, now - ); - if (TTLMAP_STATUS(lookup_res) == TTLMAP_FOUND) { - result_status = SESSION_FOUND; - } else { - result_status = SESSION_CREATED; - } - } else { - result_status = SESSION_CREATED; - } - } else { // status == TTLMAP_FAILED - result_status = SESSION_TABLE_OVERFLOW; - } - - return result_status; -} - -static inline uint32_t -get_session_real( - struct session_table *session_table, - uint32_t current_table_gen, - struct session_id *session_id, - uint32_t now -) { - // Get ttlmap - struct ttlmap *map = - session_table_map(session_table, current_table_gen); - - struct session_state session_state; - int res = TTLMAP_LOOKUP(map, session_id, &session_state, now); - int status = TTLMAP_STATUS(res); - - uint32_t real_id = -1; - if (status == TTLMAP_FOUND) { - real_id = session_state.real_id; - } else { - assert(status == TTLMAP_FAILED); - if (worker_use_prev_map(current_table_gen - )) { // if worker in this gen should use prev map - struct ttlmap *prev = session_table_prev_map( - session_table, current_table_gen - ); - int res = TTLMAP_LOOKUP( - prev, session_id, &session_state, now - ); - status = TTLMAP_STATUS(res); - if (status == TTLMAP_FOUND) { - real_id = session_state.real_id; - } - }; - } - - return real_id; -} - -static inline void -session_remove(struct session_state *session_state) { - TTLMAP_REMOVE(struct session_id, session_state); -} - -static inline void -session_unlock(session_lock_t *lock) { - ttlmap_release_lock(lock); -} \ No newline at end of file diff --git a/modules/balancer/dataplane/tunnel.h b/modules/balancer/dataplane/tunnel.h deleted file mode 100644 index 2a9d0f3a9..000000000 --- a/modules/balancer/dataplane/tunnel.h +++ /dev/null @@ -1,166 +0,0 @@ -#pragma once - -#include "handler/vs.h" - -#include "dataplane/packet/packet.h" -#include "lib/dataplane/packet/encap.h" -#include "mss.h" -#include "rte_gre.h" -#include "rte_ip.h" - -#include - -//////////////////////////////////////////////////////////////////////////////// - -static inline void -tunnel_packet(struct vs *vs, struct real *real, struct packet *packet) { - int vs_ip_proto = vs->identifier.ip_proto; - uint8_t vs_flags = vs->flags; - - // fix packet MSS if flag is specified and vs is IPv6 - if ((vs_flags & VS_FIX_MSS_FLAG) && (vs_ip_proto == IPPROTO_IPV6)) { - fix_mss_ipv6(packet); - } - - // encapsulate packet - - struct rte_mbuf *mbuf = packet_to_mbuf(packet); - - struct rte_ipv6_hdr *ipv6_header_inner = NULL; - struct rte_ipv4_hdr *ipv4_header_inner = NULL; - if (vs_ip_proto == IPPROTO_IPV6) { - ipv6_header_inner = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - } else { - ipv4_header_inner = rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - } - - const int real_ipv6 = - real->identifier.relative.ip_proto == IPPROTO_IPV6 ? 1 : 0; - - if (real_ipv6) { // IPv6 - // rs->src_addr is already masked. - - const struct net6 *n6 = &real->src.v6; - - packet_ip6_encap( - packet, - real->identifier.relative.addr.v6.bytes, - n6->addr - ); - - struct rte_ipv6_hdr *ipv6_header_outer = - rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - - uint8_t *src = ipv6_header_outer->src_addr; - uint8_t len = (ipv4_header_inner != NULL ? NET4_LEN : NET6_LEN); - uint8_t *src_user = - (ipv4_header_inner != NULL - ? (uint8_t *)&ipv4_header_inner->src_addr - : ipv6_header_inner->src_addr); - for (uint8_t i = 0; i < len; i++) { - src[i] |= src_user[i] & (~n6->mask[i]); - } - } else { // IPv4 - // rs->src_addr is already masked. - const struct net4 *n4 = &real->src.v4; - - packet_ip4_encap( - packet, - real->identifier.relative.addr.v4.bytes, - n4->addr - ); - - struct rte_ipv4_hdr *ipv4_header_outer = - rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - - uint8_t *src = (uint8_t *)&ipv4_header_outer->src_addr; - uint8_t *src_user = - (ipv4_header_inner != NULL) - ? (uint8_t *)&ipv4_header_inner->src_addr - : ipv6_header_inner->src_addr; - for (size_t i = 0; i < 4; ++i) { - src[i] = (src_user[i] & ~n4->mask[i]) | n4->addr[i]; - } - } - - // use GRE for encap - if (vs_flags & VS_GRE_FLAG) { - const uint16_t gre_hdr_size = sizeof(struct rte_gre_hdr); - - if (rte_pktmbuf_prepend(mbuf, gre_hdr_size) == NULL) { - // not enough headroom to insert GRE - assert(false); - } - - const uint16_t len_before_gre = - packet->network_header.offset + - (real_ipv6 ? sizeof(struct rte_ipv6_hdr) - : sizeof(struct rte_ipv4_hdr)); - - // move L2 + outer L3 back to head to open a gap right after - // outer L3 - memmove(rte_pktmbuf_mtod(mbuf, char *), - rte_pktmbuf_mtod_offset(mbuf, char *, gre_hdr_size), - len_before_gre); - - if (real_ipv6) { - struct rte_ipv6_hdr *ipv6_header = - rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv6_hdr *, - packet->network_header.offset - ); - ipv6_header->proto = IPPROTO_GRE; - ipv6_header->payload_len = rte_cpu_to_be_16( - rte_be_to_cpu_16(ipv6_header->payload_len) + - gre_hdr_size - ); - } else { - struct rte_ipv4_hdr *ipv4_header = - rte_pktmbuf_mtod_offset( - mbuf, - struct rte_ipv4_hdr *, - packet->network_header.offset - ); - ipv4_header->next_proto_id = IPPROTO_GRE; - ipv4_header->total_length = rte_cpu_to_be_16( - rte_be_to_cpu_16(ipv4_header->total_length) + - gre_hdr_size - ); - - ipv4_header->hdr_checksum = 0; - ipv4_header->hdr_checksum = rte_ipv4_cksum(ipv4_header); - } - - // place GRE header in the created gap (right after outer L3) - struct rte_gre_hdr *gre_header = rte_pktmbuf_mtod_offset( - mbuf, struct rte_gre_hdr *, len_before_gre - ); - memset(gre_header, 0, sizeof(struct rte_gre_hdr)); - gre_header->ver = 0; // default version - gre_header->proto = rte_cpu_to_be_16( - ipv4_header_inner != NULL ? RTE_ETHER_TYPE_IPV4 - : RTE_ETHER_TYPE_IPV6 - ); - - // advance transport offset past GRE header (inner transport - // shifts forward) - packet->transport_header.offset += gre_hdr_size; - } -} diff --git a/modules/balancer/dataplane/types/interval_counter.h b/modules/balancer/dataplane/types/interval_counter.h new file mode 100644 index 000000000..c56e07841 --- /dev/null +++ b/modules/balancer/dataplane/types/interval_counter.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +#define BALANCER_IC_RING_SIZE_EXP 3u +#define BALANCER_IC_RING_SIZE (1u << BALANCER_IC_RING_SIZE_EXP) +#define BALANCER_IC_RING_MASK (BALANCER_IC_RING_SIZE - 1u) + +/* + * Ring-based interval counter that stores per-timestamp deltas. + */ +struct balancer_interval_counter { + int32_t diff[BALANCER_IC_RING_SIZE]; + uint32_t last_timestamp; +}; diff --git a/modules/balancer/dataplane/types/real.h b/modules/balancer/dataplane/types/real.h new file mode 100644 index 000000000..dccc26a2f --- /dev/null +++ b/modules/balancer/dataplane/types/real.h @@ -0,0 +1,93 @@ +#pragma once + +#include + +#include "common/network.h" + +struct balancer_sessions_tracker_shard; + +enum balancer_real_flags { + balancer_real_enabled = 1u << 0, + balancer_real_removed = 1u << 1, + balancer_real_ipv6 = 1u << 2, +}; + +/* + * A real (backend) server within a virtual service. + * + * Stored in a per-VS contiguous array indexed by config index. + * See struct balancer_vs for the array layout and indexing scheme. + */ +struct balancer_real { + uint64_t counter_id; + + /* + * Encodes (epoch << 32) | config_index. + * Sessions store this value; on reuse the dataplane compares + * the stored value against the current one to detect + * replacements without any map lookup. + */ + uint64_t stable_idx; + + struct balancer_sessions_tracker_shard *tracker_shards; + + /* Destination IP address of the real server (IPv4 or IPv6). */ + struct net_addr addr; + + /* + * Source network for the outer tunnel header. + * + * Contains both the base address (src.v4.addr / src.v6.addr) + * and the mask (src.v4.mask / src.v6.mask). + * + * INVARIANT: the address bytes must be pre-masked by the + * controlplane, i.e. (addr[i] & mask[i]) == addr[i] for every + * byte i. The tunnel code relies on this to embed client source + * IP bits into the unmasked positions without an extra AND: + * + * outer_src[i] = addr[i] | (client_src[i] & ~mask[i]) + * + * Use v4 when balancer_real_ipv6 is clear, v6 when set. + */ + struct net src; + + uint8_t flags; + + /* ---Controlplane data --- */ + + uint32_t weight; + uint32_t effective_weight; +}; + +/* + * Per-real-server statistics. + * + * Tracks packet processing and session creation for a specific real + * server within a virtual service. + */ +struct balancer_real_stats { + /* + * Packets for sessions assigned to this real when it was disabled. + * + * Incremented when: + * - A session exists for this real + * - The real is currently disabled + * - A packet arrives for that session + * + * This indicates packets that were dropped or rescheduled because + * the real was disabled after the session was created. + */ + uint64_t packets_real_disabled; + + /* ICMP error packets forwarded to this real server. */ + uint64_t error_icmp_packets; + + /* Total number of new sessions created with this real as backend. */ + uint64_t created_sessions; + + /* Total packets forwarded to this real server. */ + uint64_t packets; + + /* Total bytes forwarded to this real server. */ + uint64_t bytes; +}; diff --git a/modules/balancer/dataplane/types/selector.h b/modules/balancer/dataplane/types/selector.h new file mode 100644 index 000000000..64d0f4016 --- /dev/null +++ b/modules/balancer/dataplane/types/selector.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +#include "common/big_array.h" +#include "common/rcu.h" + +/** + * Ring containing backend ("real") indices. + * + * Each backend appears multiple times according to its weight. + * The ring is shuffled to distribute selections evenly. + */ +struct balancer_ring { + /* + * Indices of backend servers. + * + * Stored in a big array because the weighted list can exceed + * the allocator's maximum block size. + * + * TODO: give more accurate comment why we use big array here and why it + * is valid. It is OK to use big array here because every real index + * size is 4 bytes. + */ + struct big_array real_ids; +}; + +/** + * Round-robin counter. + * + * Used to track the position of the current real in the ring. + */ +struct balancer_rr_counter { + uint64_t value; +} __attribute__((aligned(64))); + +/** + * Real backend selector. + * + * Maintains two rings for RCU-swapped updates and per-worker RR counters. + * Uses either round-robin or hash-based selection, + * depending on the virtual server scheduler. + */ +struct balancer_real_selector { + /* Double-buffered rings. */ + struct balancer_ring rings[2]; + + /* Active ring index. */ + _Atomic size_t ring_id; + + /* Non-zero for RR scheduler, zero for hash scheduler. */ + int use_rr; + + /* Array of per-worker round-robin counters. */ + struct balancer_rr_counter workers[RCU_WORKERS]; +}; diff --git a/modules/balancer/dataplane/types/session.h b/modules/balancer/dataplane/types/session.h new file mode 100644 index 000000000..6a4f77fb3 --- /dev/null +++ b/modules/balancer/dataplane/types/session.h @@ -0,0 +1,83 @@ +#pragma once + +#include "common/network.h" +#include "common/rcu.h" +#include "common/ttlmap/detail/ttlmap.h" + +#define BALANCER_SESSION_ID_PADDING \ + (32 - (sizeof(uint64_t) + sizeof(uint16_t) + NET6_LEN)) + +struct balancer_session_id { + uint64_t vs_stable_idx; + uint16_t client_port; + uint8_t client_ip[NET6_LEN]; + uint8_t padding[BALANCER_SESSION_ID_PADDING]; +}; + +struct balancer_session_state { + /* + * stable_idx of the assigned real; see struct balancer_vs. + */ + uint64_t real_stable_idx; + uint32_t last_packet_timestamp; + uint32_t create_timestamp; + uint8_t timeout; +}; + +struct balancer_session_timeouts { + uint8_t tcp_syn_ack; + uint8_t tcp_syn; + uint8_t tcp_fin; + uint8_t tcp; + uint8_t udp; +}; + +/* + * Double-buffered session table with RCU-based resizing. + * + * The table contains two ttlmaps. At any given time, one is the + * "current" map (where workers insert new sessions) and the other + * is the "previous" map (which may still hold sessions from before + * a resize). + * + * The controlplane resizes the table by: + * 1. Preparing the new map in the inactive slot. + * 2. Incrementing current_gen (atomically). + * 3. Waiting for all workers to observe the new generation (RCU). + * 4. Incrementing current_gen again once migration is complete. + * + * Generation parity controls dual-map behavior: + * - Even generation (stable): workers use only the current map. + * The previous map is not consulted and may be rebuilt. + * - Odd generation (transition): workers write to the current map + * but also check the previous map for existing sessions that + * have not yet been migrated. + * + * This allows resizing without dropping active sessions: during + * the transition (odd gen), a session not found in the new map + * may still exist in the old one. Once all sessions have been + * migrated or expired, the controlplane moves to the next even + * generation, and the old map is freed and can be reused. + */ +struct balancer_session_table { + struct ttlmap maps[2]; + rcu_t rcu; + _Atomic uint64_t current_gen; + struct memory_context mctx; + uint32_t workers; +}; + +static inline int +balancer_st_map_idx(uint32_t gen) { + return ((gen + 1) & 0b11) >> 1; +} + +static inline struct ttlmap * +balancer_st_cur_map(struct balancer_session_table *table, uint32_t gen) { + return &table->maps[balancer_st_map_idx(gen)]; +} + +static inline struct ttlmap * +balancer_st_prev_map(struct balancer_session_table *table, uint32_t gen) { + return &table->maps[balancer_st_map_idx(gen) ^ 1]; +} diff --git a/modules/balancer/dataplane/types/sessions_tracker.h b/modules/balancer/dataplane/types/sessions_tracker.h new file mode 100644 index 000000000..d19538b1d --- /dev/null +++ b/modules/balancer/dataplane/types/sessions_tracker.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +#include "interval_counter.h" + +#define BALANCER_SESSIONS_TRACKER_PRECISION 16 + +/* + * The underlying rt_interval_counter ring has size R, so the tick + * distance (until_tick - now_tick) must be < R. With precision P + * this means the session timeout must satisfy: + * (ts + timeout + P-1)/P - ts/P < R + * This means the session timeout must be at most (R-2) * P + 1. + * For example, with R=8 (production ring size) and P=16 (production precision), + * the session timeout must be at most 97s. + */ +#define BALANCER_MAX_SESSION_TIMEOUT \ + ((BALANCER_IC_RING_SIZE - 2) * BALANCER_SESSIONS_TRACKER_PRECISION + 1) + +static const uint8_t balancer_max_session_timeout = + BALANCER_MAX_SESSION_TIMEOUT; + +/* + * Per-worker active-session tracker. + */ +struct balancer_sessions_tracker_shard { + struct balancer_interval_counter counter; + uint32_t count; + uint32_t last_timestamp; +} __attribute__((aligned(64))); \ No newline at end of file diff --git a/modules/balancer/dataplane/types/stats.h b/modules/balancer/dataplane/types/stats.h new file mode 100644 index 000000000..37dff70a3 --- /dev/null +++ b/modules/balancer/dataplane/types/stats.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +struct balancer_common_stats { + uint64_t incoming_packets; + uint64_t incoming_bytes; + uint64_t unexpected_network_proto; + uint64_t decap_successful; + uint64_t decap_failed; + uint64_t outgoing_packets; + uint64_t outgoing_bytes; +}; + +struct balancer_l4_stats { + uint64_t incoming_packets; + uint64_t select_vs_failed; + uint64_t invalid_packets; + uint64_t select_real_failed; + uint64_t outgoing_packets; +}; + +struct balancer_icmp_stats { + uint64_t incoming_packets; + uint64_t src_not_allowed; + uint64_t echo_responses; + uint64_t payload_too_short_ip; + uint64_t unmatching_src_from_original; + uint64_t payload_too_short_port; + uint64_t unexpected_transport; + uint64_t unrecognized_vs; + uint64_t forwarded_packets; + uint64_t broadcasted_packets; + uint64_t packet_clones_sent; + uint64_t packet_clones_received; + uint64_t packet_clone_failures; +}; diff --git a/modules/balancer/dataplane/types/vs.h b/modules/balancer/dataplane/types/vs.h new file mode 100644 index 000000000..3e0c259b0 --- /dev/null +++ b/modules/balancer/dataplane/types/vs.h @@ -0,0 +1,199 @@ +#pragma once + +#include +#include +#include + +#include "common/network.h" + +#include "selector.h" + +enum balancer_vs_flags { + balancer_vs_pure_l3 = 1u << 0, + balancer_vs_fix_mss = 1u << 1, + balancer_vs_gre = 1u << 2, + balancer_vs_ops = 1u << 3, + balancer_vs_wlc = 1u << 4, + balancer_vs_removed = 1u << 5, + balancer_vs_round_robin = 1u << 6, +}; + +struct filter; +struct filter_port_range; +struct balancer_real; +struct balancer_real_selector; + +enum { + balancer_vs_acl_max_tag_len = 20, +}; + +struct balancer_vs_allowed_source { + struct net *nets; + uint32_t nets_count; + + struct filter_port_range *port_ranges; + uint32_t port_ranges_count; + + char tag[balancer_vs_acl_max_tag_len + 1]; +}; + +/* + * Reals within a VS are stored in a contiguous array with stable + * positions. Each position has a fixed config index that never + * changes across config updates: + * + * - When a real is replaced at position N, the new real occupies + * the same slot with an incremented epoch in its stable_idx. + * - When a real is removed, the slot is marked removed=true + * but not compacted. + * - When a new real is added, it may reuse a removed slot or + * be appended at the end. + * + * Sessions in the session table store the real's stable_idx, + * which encodes (epoch << 32) | config_index. On session reuse, + * the dataplane extracts the config index for O(1) array access + * and compares the full stable_idx to detect replacements. + */ +struct balancer_vs { + struct balancer_real *reals; + uint32_t reals_count; + + uint64_t stable_idx; + uint64_t counter_id; + + struct balancer_real_selector *selector; + struct filter *acl; + + /* + * Controlplane limits the number of rules per VS to billion, + * so we can allocate a single array for all rules + * in the shared memory without fragmentation issues. + */ + uint64_t *rule_counter_ids; + + uint16_t flags; + + /* ---Controlplane data --- */ + + /* To separate dataplane and controlplane cachelines. */ + uint8_t __padding[64]; + + struct net_addr addr; + uint8_t ip_proto; + uint16_t port; + uint8_t transport_proto; + + struct net4_addr *peers_v4; + uint32_t peers_v4_count; + + struct net6_addr *peers_v6; + uint32_t peers_v6_count; + + struct balancer_vs_allowed_source *allowed_sources; + uint32_t allowed_sources_count; +}; + +/** + * Per-virtual-service runtime counters. + * + * Tracks packet processing statistics for a specific virtual service, + * including successful forwards, various failure conditions, and + * session management metrics. + */ +struct balancer_vs_stats { + /* Total packets received matching this virtual service. */ + uint64_t incoming_packets; + + /* Total bytes received matching this virtual service. */ + uint64_t incoming_bytes; + + /* Packets dropped due to source address not in allowlist. */ + uint64_t packet_src_not_allowed; + + /* + * Packets that failed real server selection. + * + * Incremented when: + * - No real servers are configured + * - All real servers are disabled + * - All real servers have zero weight. + */ + uint64_t no_reals; + + /* Session creation failures due to table capacity. */ + uint64_t session_table_overflow; + + /* ICMP echo packets processed. */ + uint64_t echo_icmp_packets; + + /* + * ICMP error packets forwarded to real servers. + * + * Tracks ICMP errors (destination unreachable, time exceeded, + * etc.) that were matched to sessions and forwarded to the + * appropriate real server. + */ + uint64_t error_icmp_packets; + + /* + * Packets for sessions where the real server is disabled. + * + * Incremented when: + * - Session exists for a specific real + * - That real is currently disabled + * - Packet arrives for the session + * + * These packets are dropped. + */ + uint64_t real_is_disabled; + + /* + * Packets for sessions where the real server was removed. + * + * Incremented when: + * - Session exists for a specific real + * - That real is no longer in the configuration + * - Packet arrives for the session + * + * These packets are dropped. + */ + uint64_t real_is_removed; + + /* + * Packets that couldn't be rescheduled. + * + * Incremented when: + * - No existing session found + * - Packet doesn't start a new session (e.g., TCP non-SYN) + * + * Common for: + * - TCP packets without SYN flag when no session exists + * - Packets arriving after session timeout + */ + uint64_t not_rescheduled_packets; + + /* + * ICMP packets broadcasted to peer balancers. + * + * Incremented when: + * - ICMP error has this VS as source + * - Packet is cloned and sent to configured peers + * - Used for distributed ICMP error handling + */ + uint64_t broadcasted_icmp_packets; + + /* + * Total sessions created for this virtual service. + * + * Tracks the cumulative number of sessions created since + * the balancer started or statistics were reset. Does not + * include OPS packets (which don't create sessions). + */ + uint64_t created_sessions; + + /* Packets successfully forwarded to real servers. */ + uint64_t outgoing_packets; + + /* Bytes successfully forwarded to real servers (IP layer). */ + uint64_t outgoing_bytes; +}; diff --git a/modules/balancer/dataplane/vs.h b/modules/balancer/dataplane/vs.h deleted file mode 100644 index d3e623e46..000000000 --- a/modules/balancer/dataplane/vs.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include "handler/vs.h" - -// Counter for the virtual service, -// which depends on the placement of the -// module config in the controlplane topology. -static inline struct vs_stats * -vs_counter(struct vs *vs, size_t worker, struct counter_storage *storage) { - uint64_t *counter = - counter_get_address(vs->counter_id, worker, storage); - return (struct vs_stats *)counter; -} diff --git a/modules/balancer/dataplane/vs_helpers.h b/modules/balancer/dataplane/vs_helpers.h new file mode 100644 index 000000000..e6f4c8b54 --- /dev/null +++ b/modules/balancer/dataplane/vs_helpers.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/memory_address.h" +#include "lib/counters/counters.h" + +#include "dataplane.h" +#include "types/vs.h" + +static inline struct balancer_vs_stats * +vs_get_stats( + struct balancer_vs *vs, + uint32_t worker, + struct counter_storage *counter_storage +) { + return (struct balancer_vs_stats *)counter_get_address( + vs->counter_id, worker, counter_storage + ); +} + +static inline uint64_t * +vs_get_acl_stats( + struct balancer_vs *vs, + uint32_t worker, + struct counter_storage *counter_storage, + uint32_t rule_idx +) { + // Rule counter is undefined if tag is empty + uint64_t id = ADDR_OF(&vs->rule_counter_ids)[rule_idx]; + return id != (uint64_t)-1 + ? counter_get_address(id, worker, counter_storage) + : NULL; +} diff --git a/modules/balancer/dataplane/worker.h b/modules/balancer/dataplane/worker.h deleted file mode 100644 index 969f39f81..000000000 --- a/modules/balancer/dataplane/worker.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "flow/context.h" -#include - -#define MAX_WORKERS_NUM 8 - -enum { batch_size = 32 }; - -static thread_local struct packet_ctx packet_ctxs[batch_size]; \ No newline at end of file diff --git a/modules/balancer/meson.build b/modules/balancer/meson.build index 5dcc00428..48e1c19a0 100644 --- a/modules/balancer/meson.build +++ b/modules/balancer/meson.build @@ -2,7 +2,4 @@ subdir('dataplane') if not dataplane_only subdir('controlplane') - subdir('agent') - subdir('bench') - subdir('tests') -endif \ No newline at end of file +endif diff --git a/modules/balancer/tests/go/active_sessions_test.go b/modules/balancer/tests/go/active_sessions_test.go deleted file mode 100644 index fbbf59ff2..000000000 --- a/modules/balancer/tests/go/active_sessions_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package balancer_test - -import ( - "fmt" - "net/netip" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" -) - -type activeSessionsExpected struct { - total uint64 - lastPacketTS time.Time - vs map[string]activeSessionsVsExpected -} - -type activeSessionsVsExpected struct { - activeSessions uint64 - lastPacketTS time.Time - reals map[string]activeSessionsRealExpected -} - -type activeSessionsRealExpected struct { - activeSessions uint64 - lastPacketTS time.Time -} - -type activeSessionsPacketPlan struct { - vsIP netip.Addr - vsPort uint16 - count int - start int -} - -func TestActiveSessions(t *testing.T) { - vs1IP := netip.MustParseAddr("1.1.1.1") - vs2IP := netip.MustParseAddr("1.1.1.2") - vs1Port := uint16(80) - vs2Port := uint16(8080) - - vs1Real1 := netip.MustParseAddr("2.2.2.1") - vs1Real2 := netip.MustParseAddr("2.2.2.2") - vs1Real3 := netip.MustParseAddr("2.2.2.3") - vs2Real1 := netip.MustParseAddr("3.3.3.1") - vs2Real2 := netip.MustParseAddr("3.3.3.2") - vs2Real3 := netip.MustParseAddr("3.3.3.3") - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - makeActiveSessionsVS(vs1IP, vs1Port, []netip.Addr{vs1Real1, vs1Real2, vs1Real3}), - makeActiveSessionsVS(vs2IP, vs2Port, []netip.Addr{vs2Real1, vs2Real2, vs2Real3}), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 30, - TcpSyn: 30, - TcpFin: 30, - Tcp: 30, - Udp: 30, - Default: 30, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1024); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - utils.EnableAllReals(t, ts) - - initialPlans := []activeSessionsPacketPlan{ - {vsIP: vs1IP, vsPort: vs1Port, count: 18, start: 0}, - {vsIP: vs2IP, vsPort: vs2Port, count: 18, start: 1000}, - } - - t.Run("Initial_Send_And_Check", func(t *testing.T) { - sendActiveSessionsPackets(t, ts, initialPlans) - currentTime := ts.Mock.CurrentTime() - expected := buildActiveSessionsExpected(t, ts, currentTime) - checkActiveSessionsState(t, ts, currentTime, expected) - }) - - t.Run("Advance_29s_And_Check", func(t *testing.T) { - currentTime := ts.Mock.AdvanceTime(29 * time.Second) - expected := buildActiveSessionsExpected(t, ts, currentTime) - checkActiveSessionsState(t, ts, currentTime, expected) - }) - - t.Run("Advance_100s_Send_Again_And_Check", func(t *testing.T) { - ts.Mock.AdvanceTime(100 * time.Second) - sendActiveSessionsPackets(t, ts, initialPlans) - currentTime := ts.Mock.CurrentTime() - expected := buildActiveSessionsExpected(t, ts, currentTime) - checkActiveSessionsState(t, ts, currentTime, expected) - }) -} - -func makeActiveSessionsVS( - vsIP netip.Addr, - vsPort uint16, - reals []netip.Addr, -) *balancerpb.VirtualService { - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: vsIP.AsSlice()}, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0").AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Peers: []*balancerpb.Addr{}, - } - - for _, realIP := range reals { - vs.Reals = append(vs.Reals, &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: realIP.AsSlice()}, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIP.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255").AsSlice(), - }, - }) - } - - return vs -} - -func sendActiveSessionsPackets( - t *testing.T, - ts *utils.TestSetup, - plans []activeSessionsPacketPlan, -) { - t.Helper() - - packets := make([]gopacket.Packet, 0) - for _, plan := range plans { - for i := range plan.count { - clientIP := activeSessionsClientIP(plan.start + i) - clientPort := uint16(10000 + plan.start + i) - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - plan.vsIP, - plan.vsPort, - &layers.TCP{SYN: true}, - ) - packets = append(packets, xpacket.LayersToPacket(t, packetLayers...)) - } - } - - result, err := ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - require.Equal(t, len(packets), len(result.Output), "all packets should be forwarded") - require.Empty(t, result.Drop, "no packets should be dropped") - - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } -} - -func activeSessionsClientIP(index int) netip.Addr { - return netip.MustParseAddr( - fmt.Sprintf("10.%d.%d.%d", index/(256*256)%256, (index/256)%256, index%256), - ) -} - -func buildActiveSessionsExpected( - t *testing.T, - ts *utils.TestSetup, - currentTime time.Time, -) activeSessionsExpected { - t.Helper() - - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - - expected := activeSessionsExpected{ - total: uint64(len(sessions)), - lastPacketTS: time.Time{}, - vs: map[string]activeSessionsVsExpected{}, - } - - for _, session := range sessions { - vsAddr, ok := netip.AddrFromSlice(session.RealId.Vs.Addr.Bytes) - require.True(t, ok, "failed to decode VS addr from session.RealId.Vs") - realAddr, ok := netip.AddrFromSlice(session.RealId.Real.Ip.Bytes) - require.True(t, ok, "failed to decode real addr from session.RealId.Real.Ip") - - vsKey := activeSessionsVSKey(vsAddr, uint16(session.RealId.Vs.Port)) - realKey := realAddr.String() - lastPacketTS := session.LastPacketTimestamp.AsTime() - - if lastPacketTS.After(expected.lastPacketTS) { - expected.lastPacketTS = lastPacketTS - } - - vsExpected := expected.vs[vsKey] - if vsExpected.reals == nil { - vsExpected.reals = map[string]activeSessionsRealExpected{} - } - vsExpected.activeSessions++ - if lastPacketTS.After(vsExpected.lastPacketTS) { - vsExpected.lastPacketTS = lastPacketTS - } - - realExpected := vsExpected.reals[realKey] - realExpected.activeSessions++ - if lastPacketTS.After(realExpected.lastPacketTS) { - realExpected.lastPacketTS = lastPacketTS - } - vsExpected.reals[realKey] = realExpected - expected.vs[vsKey] = vsExpected - } - - return expected -} - -func checkActiveSessionsState( - t *testing.T, - ts *utils.TestSetup, - currentTime time.Time, - expected activeSessionsExpected, -) { - t.Helper() - - activeInfo := ts.Balancer.ActiveSessions() - require.NotNil(t, activeInfo) - - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - require.Equal(t, int(expected.total), len(sessions), "sessions count should match expected") - - assert.Equal(t, expected.total, activeInfo.ActiveSessions, "total active sessions mismatch") - assert.Equal( - t, - timestamppb.New(expected.lastPacketTS), - activeInfo.LastPacketTimestamp, - "balancer last packet timestamp mismatch", - ) - - actualTotalFromVS := uint64(0) - actualTotalFromReals := uint64(0) - require.Len(t, activeInfo.Vs, 2, "should have exactly two VS entries") - - for _, vsInfo := range activeInfo.Vs { - vsAddr, ok := netip.AddrFromSlice(vsInfo.Id.Addr.Bytes) - require.True(t, ok, "failed to decode VS addr from ActiveSessions info") - vsKey := activeSessionsVSKey(vsAddr, uint16(vsInfo.Id.Port)) - vsExpected, ok := expected.vs[vsKey] - require.True(t, ok, "unexpected VS in ActiveSessions: %s", vsKey) - - assert.Equal( - t, - vsExpected.activeSessions, - vsInfo.ActiveSessions, - "VS active sessions mismatch for %s", - vsKey, - ) - assert.Equal( - t, - timestamppb.New(vsExpected.lastPacketTS), - vsInfo.LastPacketTimestamp, - "VS last packet timestamp mismatch for %s", - vsKey, - ) - actualTotalFromVS += vsInfo.ActiveSessions - - actualVSRealTotal := uint64(0) - require.Len(t, vsInfo.Reals, 3, "VS %s should have exactly three reals", vsKey) - for _, realInfo := range vsInfo.Reals { - realAddr, ok := netip.AddrFromSlice(realInfo.Id.Real.Ip.Bytes) - require.True(t, ok, "failed to decode real addr from ActiveSessions info") - realKey := realAddr.String() - realExpected, ok := vsExpected.reals[realKey] - require.True(t, ok, "unexpected real %s in VS %s", realKey, vsKey) - - assert.NotZero( - t, - realInfo.ActiveSessions, - "real %s in VS %s should have active sessions", - realKey, - vsKey, - ) - assert.Equal( - t, - realExpected.activeSessions, - realInfo.ActiveSessions, - "real active sessions mismatch for %s in %s", - realKey, - vsKey, - ) - assert.Equal( - t, - timestamppb.New(realExpected.lastPacketTS), - realInfo.LastPacketTimestamp, - "real last packet timestamp mismatch for %s in %s", - realKey, - vsKey, - ) - actualVSRealTotal += realInfo.ActiveSessions - actualTotalFromReals += realInfo.ActiveSessions - } - - assert.Equal( - t, - vsInfo.ActiveSessions, - actualVSRealTotal, - "sum of real sessions should match VS total for %s", - vsKey, - ) - } - - assert.Equal( - t, - activeInfo.ActiveSessions, - actualTotalFromVS, - "sum of VS sessions should match balancer total", - ) - assert.Equal( - t, - activeInfo.ActiveSessions, - actualTotalFromReals, - "sum of real sessions should match balancer total", - ) -} - -func activeSessionsVSKey(vsIP netip.Addr, vsPort uint16) string { - return fmt.Sprintf("%s:%d", vsIP, vsPort) -} diff --git a/modules/balancer/tests/go/allowed_src_test.go b/modules/balancer/tests/go/allowed_src_test.go deleted file mode 100644 index b3333d59b..000000000 --- a/modules/balancer/tests/go/allowed_src_test.go +++ /dev/null @@ -1,2629 +0,0 @@ -package balancer_test - -// TestAllowedSrc is a comprehensive test suite for the allowed_src source filtering feature. -// -// This test verifies that the balancer correctly filters packets based on the allowed_srcs -// configuration for each virtual service. It covers: -// -// # Source Filtering Behavior -// - Packets from allowed source ranges are accepted and forwarded -// - Packets from non-allowed source ranges are dropped -// - Empty allowed_srcs list denies all sources -// - 0.0.0.0/0 (IPv4) or ::/0 (IPv6) allows all sources -// -// # Protocol Coverage -// - TCP protocol with source filtering -// - UDP protocol with source filtering -// -// # IP Version Coverage -// - IPv4 virtual services with IPv4 allowed_src ranges -// - IPv6 virtual services with IPv6 allowed_src ranges -// -// # Counter Validation -// - packet_src_not_allowed counter increases when packets are blocked -// - incoming_packets counter increases for all packets (allowed and blocked) -// - outgoing_packets counter increases only for allowed packets -// - created_sessions counter increases only for allowed packets -// -// The test uses 7 virtual services with different configurations: -// - VS1: IPv4 TCP port 80 with allowed_src 10.0.1.0/24 -// - VS2: IPv4 UDP port 5353 with allowed_src 10.0.2.0/24 -// - VS3: IPv6 TCP port 8080 with allowed_src 2001:db8:1::/48 -// - VS4: IPv6 UDP port 5353 with allowed_src 2001:db8:2::/48 -// - VS5: IPv4 TCP port 443 with allowed_src 0.0.0.0/1 + 128.0.0.0/1 (allow all IPv4) -// - VS6: IPv4 TCP port 8443 with allowed_src 0.0.0.0/0 (allow all) -// - VS7: IPv4 TCP port 9443 with empty allowed_src (deny all) - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Virtual service addresses and ports for allowed_src tests -var ( - // VS1: IPv4 TCP with restricted source - allowedSrcVs1IP = netip.MustParseAddr("10.10.1.1") - allowedSrcVs1Port = uint16(80) - - // VS2: IPv4 UDP with restricted source - allowedSrcVs2IP = netip.MustParseAddr("10.10.2.1") - allowedSrcVs2Port = uint16(5353) - - // VS3: IPv6 TCP with restricted source - allowedSrcVs3IP = netip.MustParseAddr("2001:db8:100::1") - allowedSrcVs3Port = uint16(8080) - - // VS4: IPv6 UDP with restricted source - allowedSrcVs4IP = netip.MustParseAddr("2001:db8:200::1") - allowedSrcVs4Port = uint16(5353) - - // VS5: IPv4 TCP with large CIDR ranges (allow all) - allowedSrcVs5IP = netip.MustParseAddr("10.10.5.1") - allowedSrcVs5Port = uint16(443) - - // VS6: IPv4 TCP with 0.0.0.0/0 allowed_src (allow all) - allowedSrcVs6IP = netip.MustParseAddr("10.10.6.1") - allowedSrcVs6Port = uint16(8443) - - // VS7: IPv4 TCP with empty allowed_src (deny all) - allowedSrcVs7IP = netip.MustParseAddr("10.10.7.1") - allowedSrcVs7Port = uint16(9443) - - // Real servers for allowed_src tests - allowedSrcRealIPv4 = netip.MustParseAddr("192.168.100.1") - allowedSrcRealIPv6 = netip.MustParseAddr("fe80::100") - - // Balancer source addresses for allowed_src tests - allowedSrcBalancerSrcIPv4 = netip.MustParseAddr("5.5.5.5") - allowedSrcBalancerSrcIPv6 = netip.MustParseAddr("fe80::5") -) - -// createAllowedSrcTestConfig creates a balancer configuration with 7 virtual services -// covering different allowed_src scenarios -func createAllowedSrcTestConfig() *balancerpb.BalancerConfig { - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // VS1: IPv4 TCP with allowed_src 10.0.1.0/24 - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs1IP.AsSlice(), - }, - Port: uint32(allowedSrcVs1Port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS2: IPv4 UDP with allowed_src 10.0.2.0/24 - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs2IP.AsSlice(), - }, - Port: uint32(allowedSrcVs2Port), - Proto: balancerpb.TransportProto_UDP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.2.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS3: IPv6 TCP with allowed_src 2001:db8:1::/48 - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs3IP.AsSlice(), - }, - Port: uint32(allowedSrcVs3Port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:1::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS4: IPv6 UDP with allowed_src 2001:db8:2::/48 - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs4IP.AsSlice(), - }, - Port: uint32(allowedSrcVs4Port), - Proto: balancerpb.TransportProto_UDP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS5: IPv4 TCP with single large CIDR (effectively allow all IPv4) - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs5IP.AsSlice(), - }, - Port: uint32(allowedSrcVs5Port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("128.0.0.0"). - AsSlice(), - }, - }}, - }, - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("128.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("128.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS6: IPv4 TCP with 0.0.0.0/0 allowed_src (allow all) - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs6IP.AsSlice(), - }, - Port: uint32(allowedSrcVs6Port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS7: IPv4 TCP with empty allowed_src (deny all) - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: allowedSrcVs7IP.AsSlice(), - }, - Port: uint32(allowedSrcVs7Port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{}, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// findVsStats finds statistics for a specific virtual service by its identifier -func findVsStats( - stats *balancerpb.BalancerStats, - vsIP netip.Addr, - vsPort uint16, - proto balancerpb.TransportProto, -) *balancerpb.VsStats { - for _, namedVsStats := range stats.Vs { - if namedVsStats.Vs == nil { - continue - } - addr, _ := netip.AddrFromSlice(namedVsStats.Vs.Addr.Bytes) - if addr == vsIP && - namedVsStats.Vs.Port == uint32(vsPort) && - namedVsStats.Vs.Proto == proto { - return namedVsStats.Stats - } - } - return nil -} - -// TestAllowedSrc is the main test function -func TestAllowedSrc(t *testing.T) { - config := createAllowedSrcTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Get packet handler reference for stats - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - // Test IPv4 TCP with allowed source - t.Run("IPv4_TCP_Allowed", func(t *testing.T) { - testIPv4TCPAllowed(t, ts, statsRef) - }) - - // Test IPv4 TCP with blocked source - t.Run("IPv4_TCP_Blocked", func(t *testing.T) { - testIPv4TCPBlocked(t, ts, statsRef) - }) - - // Test IPv4 UDP with allowed source - t.Run("IPv4_UDP_Allowed", func(t *testing.T) { - testIPv4UDPAllowed(t, ts, statsRef) - }) - - // Test IPv4 UDP with blocked source - t.Run("IPv4_UDP_Blocked", func(t *testing.T) { - testIPv4UDPBlocked(t, ts, statsRef) - }) - - // Test IPv6 TCP with allowed source - t.Run("IPv6_TCP_Allowed", func(t *testing.T) { - testIPv6TCPAllowed(t, ts, statsRef) - }) - - // Test IPv6 TCP with blocked source - t.Run("IPv6_TCP_Blocked", func(t *testing.T) { - testIPv6TCPBlocked(t, ts, statsRef) - }) - - // Test IPv6 UDP with allowed source - t.Run("IPv6_UDP_Allowed", func(t *testing.T) { - testIPv6UDPAllowed(t, ts, statsRef) - }) - - // Test IPv6 UDP with blocked source - t.Run("IPv6_UDP_Blocked", func(t *testing.T) { - testIPv6UDPBlocked(t, ts, statsRef) - }) - - // Test empty allowed_src (allow all) - t.Run("Empty_AllowedSrc_AllowsAll", func(t *testing.T) { - testEmptyAllowedSrcAllowsAll(t, ts, statsRef) - }) - - // Test 0.0.0.0/0 allowed_src (allow all) - t.Run("Zero_CIDR_AllowsAll", func(t *testing.T) { - testZeroCIDRAllowsAll(t, ts, statsRef) - }) - - // Test empty allowed_src (deny all) - t.Run("Empty_AllowedSrc_DeniesAll", func(t *testing.T) { - testEmptyAllowedSrcDeniesAll(t, ts, statsRef) - }) -} - -// testEmptyAllowedSrcDeniesAll tests that empty allowed_src denies all sources -func testEmptyAllowedSrcDeniesAll( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs7IP, - allowedSrcVs7Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS7 stats should exist") - - // Send packets from various sources - all should be denied - testSources := []string{ - "10.0.1.1", - "10.0.99.99", - "192.168.1.1", - "1.2.3.4", - "172.16.0.1", - } - - for _, srcIP := range testSources { - clientIP := netip.MustParseAddr(srcIP) - clientPort := uint16(60000) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs7IP, - allowedSrcVs7Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty( - t, - result.Output, - "expected no output packets for source %s when allowed_src is empty", - srcIP, - ) - require.Equal( - t, - 1, - len(result.Drop), - "expected 1 dropped packet for source %s when allowed_src is empty", - srcIP, - ) - } - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs7IP, - allowedSrcVs7Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS7 stats should exist") - - // Verify counters - all packets should be blocked - assert.Equal( - t, - initialVsStats.PacketSrcNotAllowed+uint64(len(testSources)), - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should increase by number of test sources when allowed_src is empty", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets, - finalVsStats.OutgoingPackets, - "outgoing_packets should not increase when allowed_src is empty", - ) - assert.Equal(t, - initialVsStats.CreatedSessions, - finalVsStats.CreatedSessions, - "created_sessions should not increase when allowed_src is empty", - ) -} - -// testIPv4TCPAllowed tests that packets from allowed IPv4 source are accepted -func testIPv4TCPAllowed( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs1IP, - allowedSrcVs1Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS1 stats should exist") - - // Send packet from allowed source (10.0.1.50) - clientIP := netip.MustParseAddr("10.0.1.50") - clientPort := uint16(12345) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs1IP, - allowedSrcVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - - // Validate the output packet - utils.ValidatePacket(t, ts.Balancer.Config(), packet, result.Output[0]) - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs1IP, - allowedSrcVs1Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS1 stats should exist") - - // Verify counters - assert.Equal(t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase for allowed source", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+1, - finalVsStats.IncomingPackets, - "incoming_packets should increase by 1", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+1, - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by 1", - ) - assert.Equal(t, - initialVsStats.CreatedSessions+1, - finalVsStats.CreatedSessions, - "created_sessions should increase by 1", - ) -} - -// testIPv4TCPBlocked tests that packets from non-allowed IPv4 source are blocked -func testIPv4TCPBlocked( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs1IP, - allowedSrcVs1Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS1 stats should exist") - - // Send packet from non-allowed source (10.0.99.50) - clientIP := netip.MustParseAddr("10.0.99.50") - clientPort := uint16(12346) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs1IP, - allowedSrcVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty(t, result.Output, "expected no output packets") - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet") - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs1IP, - allowedSrcVs1Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS1 stats should exist") - - // Verify counters - assert.Equal(t, - initialVsStats.PacketSrcNotAllowed+1, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should increase by 1", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets, - finalVsStats.OutgoingPackets, - "outgoing_packets should not increase for blocked packet", - ) - assert.Equal(t, - initialVsStats.CreatedSessions, - finalVsStats.CreatedSessions, - "created_sessions should not increase for blocked packet", - ) -} - -// testIPv4UDPAllowed tests that UDP packets from allowed IPv4 source are accepted -func testIPv4UDPAllowed( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs2IP, - allowedSrcVs2Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, initialVsStats, "VS2 stats should exist") - - // Send packet from allowed source (10.0.2.50) - clientIP := netip.MustParseAddr("10.0.2.50") - clientPort := uint16(54321) - - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - allowedSrcVs2IP, - allowedSrcVs2Port, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - - // Validate the output packet - utils.ValidatePacket(t, ts.Balancer.Config(), packet, result.Output[0]) - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs2IP, - allowedSrcVs2Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, finalVsStats, "VS2 stats should exist") - - // Verify counters - assert.Equal(t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase for allowed source", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+1, - finalVsStats.IncomingPackets, - "incoming_packets should increase by 1", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+1, - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by 1", - ) -} - -// testIPv4UDPBlocked tests that UDP packets from non-allowed IPv4 source are blocked -func testIPv4UDPBlocked( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs2IP, - allowedSrcVs2Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, initialVsStats, "VS2 stats should exist") - - // Send packet from non-allowed source (10.0.99.50) - clientIP := netip.MustParseAddr("10.0.99.50") - clientPort := uint16(54322) - - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - allowedSrcVs2IP, - allowedSrcVs2Port, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty(t, result.Output, "expected no output packets") - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet") - - // Get stats after the blocked packet - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - vsStats := findVsStats( - stats, - allowedSrcVs2IP, - allowedSrcVs2Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, vsStats, "VS2 stats should exist") - - // Verify that packet_src_not_allowed counter increased - assert.Greater(t, - vsStats.PacketSrcNotAllowed, - uint64(0), - "packet_src_not_allowed should be greater than 0 for blocked packets", - ) -} - -// testIPv6TCPAllowed tests that TCP packets from allowed IPv6 source are accepted -func testIPv6TCPAllowed( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs3IP, - allowedSrcVs3Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS3 stats should exist") - - // Send packet from allowed source (2001:db8:1::50) - clientIP := netip.MustParseAddr("2001:db8:1::50") - clientPort := uint16(23456) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs3IP, - allowedSrcVs3Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - - // Validate the output packet - utils.ValidatePacket(t, ts.Balancer.Config(), packet, result.Output[0]) - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs3IP, - allowedSrcVs3Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS3 stats should exist") - - // Verify counters - assert.Equal(t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase for allowed source", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+1, - finalVsStats.IncomingPackets, - "incoming_packets should increase by 1", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+1, - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by 1", - ) -} - -// testIPv6TCPBlocked tests that TCP packets from non-allowed IPv6 source are blocked -func testIPv6TCPBlocked( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs3IP, - allowedSrcVs3Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS3 stats should exist") - - // Send packet from non-allowed source (2001:db8:99::50) - clientIP := netip.MustParseAddr("2001:db8:99::50") - clientPort := uint16(23457) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs3IP, - allowedSrcVs3Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty(t, result.Output, "expected no output packets") - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet") - - // Get stats after the blocked packet - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - vsStats := findVsStats( - stats, - allowedSrcVs3IP, - allowedSrcVs3Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, vsStats, "VS3 stats should exist") - - // Verify that packet_src_not_allowed counter increased - assert.Greater(t, - vsStats.PacketSrcNotAllowed, - uint64(0), - "packet_src_not_allowed should be greater than 0 for blocked packets", - ) -} - -// testIPv6UDPAllowed tests that UDP packets from allowed IPv6 source are accepted -func testIPv6UDPAllowed( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs4IP, - allowedSrcVs4Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, initialVsStats, "VS4 stats should exist") - - // Send packet from allowed source (2001:db8:2::50) - clientIP := netip.MustParseAddr("2001:db8:2::50") - clientPort := uint16(34567) - - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - allowedSrcVs4IP, - allowedSrcVs4Port, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - - // Validate the output packet - utils.ValidatePacket(t, ts.Balancer.Config(), packet, result.Output[0]) - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs4IP, - allowedSrcVs4Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, finalVsStats, "VS4 stats should exist") - - // Verify counters - assert.Equal(t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase for allowed source", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+1, - finalVsStats.IncomingPackets, - "incoming_packets should increase by 1", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+1, - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by 1", - ) -} - -// testIPv6UDPBlocked tests that UDP packets from non-allowed IPv6 source are blocked -func testIPv6UDPBlocked( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs4IP, - allowedSrcVs4Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, initialVsStats, "VS4 stats should exist") - - // Send packet from non-allowed source (2001:db8:99::50) - clientIP := netip.MustParseAddr("2001:db8:99::50") - clientPort := uint16(34568) - - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - allowedSrcVs4IP, - allowedSrcVs4Port, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty(t, result.Output, "expected no output packets") - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet") - - // Get stats after the blocked packet - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - vsStats := findVsStats( - stats, - allowedSrcVs4IP, - allowedSrcVs4Port, - balancerpb.TransportProto_UDP, - ) - require.NotNil(t, vsStats, "VS4 stats should exist") - - // Verify that packet_src_not_allowed counter increased - assert.Greater(t, - vsStats.PacketSrcNotAllowed, - uint64(0), - "packet_src_not_allowed should be greater than 0 for blocked packets", - ) -} - -// testEmptyAllowedSrcAllowsAll tests that large CIDR ranges allow all sources -func testEmptyAllowedSrcAllowsAll( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs5IP, - allowedSrcVs5Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS5 stats should exist") - - // Send packets from various sources - all should be allowed - testSources := []string{ - "10.0.1.1", - "10.0.99.99", - "192.168.1.1", - "1.2.3.4", - } - - for _, srcIP := range testSources { - clientIP := netip.MustParseAddr(srcIP) - clientPort := uint16(45678) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs5IP, - allowedSrcVs5Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for source %s", - srcIP, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for source %s", - srcIP, - ) - } - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs5IP, - allowedSrcVs5Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS5 stats should exist") - - // Verify counters - no packets should be blocked - assert.Equal( - t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase when allowed_src covers all IPs", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+uint64(len(testSources)), - finalVsStats.IncomingPackets, - "incoming_packets should increase by number of test sources", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+uint64(len(testSources)), - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by number of test sources", - ) -} - -// testZeroCIDRAllowsAll tests that 0.0.0.0/0 allowed_src allows all sources -func testZeroCIDRAllowsAll( - t *testing.T, - ts *utils.TestSetup, - statsRef *balancerpb.PacketHandlerRef, -) { - t.Helper() - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - initialVsStats := findVsStats( - initialStats, - allowedSrcVs6IP, - allowedSrcVs6Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, initialVsStats, "VS6 stats should exist") - - // Send packets from various sources - all should be allowed - testSources := []string{ - "10.0.1.1", - "10.0.99.99", - "192.168.1.1", - "1.2.3.4", - "172.16.0.1", - } - - for _, srcIP := range testSources { - clientIP := netip.MustParseAddr(srcIP) - clientPort := uint16(56789) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - allowedSrcVs6IP, - allowedSrcVs6Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for source %s", - srcIP, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for source %s", - srcIP, - ) - } - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - finalVsStats := findVsStats( - finalStats, - allowedSrcVs6IP, - allowedSrcVs6Port, - balancerpb.TransportProto_TCP, - ) - require.NotNil(t, finalVsStats, "VS6 stats should exist") - - // Verify counters - no packets should be blocked - assert.Equal( - t, - initialVsStats.PacketSrcNotAllowed, - finalVsStats.PacketSrcNotAllowed, - "packet_src_not_allowed should not increase when allowed_src is 0.0.0.0/0", - ) - assert.Equal(t, - initialVsStats.IncomingPackets+uint64(len(testSources)), - finalVsStats.IncomingPackets, - "incoming_packets should increase by number of test sources", - ) - assert.Equal(t, - initialVsStats.OutgoingPackets+uint64(len(testSources)), - finalVsStats.OutgoingPackets, - "outgoing_packets should increase by number of test sources", - ) -} - -// TestAllowedSrcWithPorts tests source filtering with port range restrictions -func TestAllowedSrcWithPorts(t *testing.T) { - // Create configuration with port range restrictions - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // VS with single port range restriction - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.1.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - { - From: 1024, - To: 65535, - }, // Only high ports allowed - }, - }, - }, - Flags: &balancerpb.VsFlags{}, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS with multiple specific port ranges - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.2.1").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, // HTTP - {From: 443, To: 443}, // HTTPS - {From: 8000, To: 9000}, // Custom range - }, - }, - }, - Flags: &balancerpb.VsFlags{}, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - t.Run("HighPortAllowed", func(t *testing.T) { - // Test packet from high port (within 1024-65535 range) - vsIP := netip.MustParseAddr("10.20.1.1") - clientIP := netip.MustParseAddr("192.168.1.100") - clientPort := uint16(50000) // High port - should be allowed - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vsIP, - 80, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for high port", - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for high port", - ) - }) - - t.Run("LowPortBlocked", func(t *testing.T) { - // Test packet from low port (below 1024) - vsIP := netip.MustParseAddr("10.20.1.1") - clientIP := netip.MustParseAddr("192.168.1.100") - clientPort := uint16(80) // Low port - should be blocked - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vsIP, - 80, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty( - t, - result.Output, - "expected no output packets for low port", - ) - require.Equal( - t, - 1, - len(result.Drop), - "expected 1 dropped packet for low port", - ) - - // Verify counter increased - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - vsStats := findVsStats(stats, vsIP, 80, balancerpb.TransportProto_TCP) - require.NotNil(t, vsStats) - assert.Greater(t, vsStats.PacketSrcNotAllowed, uint64(0), - "packet_src_not_allowed should increase for blocked port") - }) - - t.Run("SpecificPortsAllowed", func(t *testing.T) { - // Test packets from specific allowed ports - vsIP := netip.MustParseAddr("10.20.2.1") - clientIP := netip.MustParseAddr("10.1.1.100") - - allowedPorts := []uint16{80, 443, 8500} // All within allowed ranges - for _, port := range allowedPorts { - packetLayers := utils.MakeTCPPacket( - clientIP, - port, - vsIP, - 443, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), - "expected 1 output packet for allowed port %d", port) - require.Empty(t, result.Drop, - "expected no dropped packets for allowed port %d", port) - } - }) - - t.Run("SpecificPortsBlocked", func(t *testing.T) { - // Test packets from ports outside allowed ranges - vsIP := netip.MustParseAddr("10.20.2.1") - clientIP := netip.MustParseAddr("10.1.1.100") - - blockedPorts := []uint16{22, 3306, 10000} // Outside allowed ranges - for _, port := range blockedPorts { - packetLayers := utils.MakeTCPPacket( - clientIP, - port, - vsIP, - 443, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Empty(t, result.Output, - "expected no output packets for blocked port %d", port) - require.Equal(t, 1, len(result.Drop), - "expected 1 dropped packet for blocked port %d", port) - } - }) -} - -// TestAllowedSrcWithTags tests that allowed_sources stats are correctly tracked per tag -func TestAllowedSrcWithTags(t *testing.T) { - // Create configuration with multiple allowed sources with different tags - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // VS with multiple allowed sources with tags - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.30.1.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - // Tag 100: Internal network - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "100"; return &s }(), - }, - { - // Tag 200: Partner network - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }}, - Tag: func() *string { s := "200"; return &s }(), - }, - { - // Tag 300: Public network range (for testing untracked sources) - // Using a specific range that doesn't overlap with tags 100 and 200 - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("8.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - Tag: nil, // nil means no tracking - }, - }, - Flags: &balancerpb.VsFlags{}, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // VS with multiple networks per allowed source and multiple port ranges - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.30.2.1").AsSlice(), - }, - Port: 443, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - // Tag 300: Multiple networks with port restrictions - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.240.0.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.32.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.240.0.0"). - AsSlice(), - }, - }, - }, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, // High ports only - }, - Tag: func() *string { s := "300"; return &s }(), - }, - { - // Tag 400: Different network with different port ranges - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("203.0.113.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - {From: 8080, To: 8080}, - }, - Tag: func() *string { s := "400"; return &s }(), - }, - }, - Flags: &balancerpb.VsFlags{}, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(256*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 128 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - t.Run("TaggedSourcesTracking", func(t *testing.T) { - vsIP := netip.MustParseAddr("10.30.1.1") - vsPort := uint16(80) - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var initialVsStats *balancerpb.NamedVsStats - for _, vs := range initialStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - initialVsStats = vs - break - } - } - require.NotNil(t, initialVsStats, "VS stats should exist") - - // Send packets from different sources - testCases := []struct { - name string - srcIP string - srcPort uint16 - expectedTag uint32 - }{ - {"InternalNetwork", "10.5.5.5", 50000, 100}, - {"PartnerNetwork", "192.168.10.10", 50001, 200}, - {"PublicNetwork", "8.8.8.8", 50002, 0}, // Tag 0 - no tracking - } - - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError( - t, - err, - "packet from %s should be processed", - tc.name, - ) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for %s", - tc.name, - ) - } - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var finalVsStats *balancerpb.NamedVsStats - for _, vs := range finalStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - finalVsStats = vs - break - } - } - require.NotNil(t, finalVsStats, "VS stats should exist") - - // Verify allowed_sources stats - require.NotNil( - t, - finalVsStats.AllowedSources, - "allowed_sources stats should exist", - ) - - // Build a map of tag -> passes for easier verification - tagStats := make(map[string]uint64) - for _, allowedSrc := range finalVsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify tag 100 (Internal network) has 1 pass - assert.Equal(t, uint64(1), tagStats["100"], "tag 100 should have 1 pass") - - // Verify tag 200 (Partner network) has 1 pass - assert.Equal(t, uint64(1), tagStats["200"], "tag 200 should have 1 pass") - - // Verify nil tag (Public network) is NOT tracked - _, exists := tagStats[""] - assert.False(t, exists, "nil tag should not be tracked in stats") - }) - - t.Run("MultipleNetworksAndPorts", func(t *testing.T) { - vsIP := netip.MustParseAddr("10.30.2.1") - vsPort := uint16(443) - - // Get initial stats - initialStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var initialVsStats *balancerpb.NamedVsStats - for _, vs := range initialStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - initialVsStats = vs - break - } - } - require.NotNil(t, initialVsStats, "VS stats should exist") - - // Test packets from different networks and ports - testCases := []struct { - name string - srcIP string - srcPort uint16 - shouldPass bool - expectedTag uint32 - }{ - // Tag 300 tests (172.16.0.0/12 and 172.32.0.0/12 with high ports) - {"Network1HighPort", "172.16.1.1", 50000, true, 300}, - {"Network2HighPort", "172.32.1.1", 50001, true, 300}, - {"Network1LowPort", "172.16.1.1", 80, false, 0}, // Low port blocked - - // Tag 400 tests (203.0.113.0/24 with specific ports) - {"Network3Port80", "203.0.113.10", 80, true, 400}, - {"Network3Port443", "203.0.113.10", 443, true, 400}, - {"Network3Port8080", "203.0.113.10", 8080, true, 400}, - { - "Network3Port22", - "203.0.113.10", - 22, - false, - 0, - }, // Port not in allowed list - } - - passedPackets := make(map[uint32]int) - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err, "packet processing failed for %s", tc.name) - - if tc.shouldPass { - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for %s", - tc.name, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for %s", - tc.name, - ) - passedPackets[tc.expectedTag]++ - } else { - require.Empty(t, result.Output, "expected no output packets for %s", tc.name) - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet for %s", tc.name) - } - } - - // Get final stats - finalStats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var finalVsStats *balancerpb.NamedVsStats - for _, vs := range finalStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - finalVsStats = vs - break - } - } - require.NotNil(t, finalVsStats, "VS stats should exist") - - // Verify allowed_sources stats - require.NotNil( - t, - finalVsStats.AllowedSources, - "allowed_sources stats should exist", - ) - - // Build a map of tag -> passes - tagStats := make(map[string]uint64) - for _, allowedSrc := range finalVsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify tag 300 has correct number of passes (2 packets from different networks) - assert.Equal(t, uint64(passedPackets[300]), tagStats["300"], - "tag 300 should have %d passes", passedPackets[300]) - - // Verify tag 400 has correct number of passes (3 packets from different ports) - assert.Equal(t, uint64(passedPackets[400]), tagStats["400"], - "tag 400 should have %d passes", passedPackets[400]) - }) -} - -// TestAllowedSrcMultipleNetworksWithTags tests ACL behavior with multiple networks per allowed source -// and verifies stats tracking using tags -func TestAllowedSrcMultipleNetworksWithTags(t *testing.T) { - // Create configuration with allowed sources containing multiple networks each - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: allowedSrcBalancerSrcIPv6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // VS with multiple allowed sources, each containing multiple networks - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.40.1.1").AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - // Tag 500: Corporate networks (4 different subnets) - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.10.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.20.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.30.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.40.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0"). - AsSlice(), - }, - }, - }, - Tag: func() *string { s := "500"; return &s }(), - }, - { - // Tag 600: Partner networks (3 different subnets) - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.1.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.2.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.3.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - }, - Tag: func() *string { s := "600"; return &s }(), - }, - { - // Tag 700: External networks (4 different subnets with port restrictions) - Nets: []*balancerpb.Net{ - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("203.0.113.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("198.51.100.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("198.18.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.254.0.0"). - AsSlice(), - }, - }, - { - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("100.64.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.192.0.0"). - AsSlice(), - }, - }, - }, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, // Only high ports - }, - Tag: func() *string { s := "700"; return &s }(), - }, - }, - Flags: &balancerpb.VsFlags{}, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: allowedSrcRealIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(256*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 128 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - vsIP := netip.MustParseAddr("10.40.1.1") - vsPort := uint16(80) - - t.Run("CorporateNetworks_Tag500", func(t *testing.T) { - // Test packets from all 4 corporate networks - testCases := []struct { - name string - srcIP string - srcPort uint16 - }{ - {"Network1", "10.10.5.5", 50000}, - {"Network2", "10.20.10.10", 50001}, - {"Network3", "10.30.15.15", 50002}, - {"Network4", "10.40.20.20", 50003}, - } - - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError( - t, - err, - "packet from %s should be processed", - tc.name, - ) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for %s", - tc.name, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for %s", - tc.name, - ) - } - - // Verify stats for tag 500 - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var vsStats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - vsStats = vs - break - } - } - require.NotNil(t, vsStats, "VS stats should exist") - - // Build tag stats map - tagStats := make(map[string]uint64) - for _, allowedSrc := range vsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify tag 500 has 4 passes (one from each network) - assert.Equal( - t, - uint64(4), - tagStats["500"], - "tag 500 should have 4 passes", - ) - }) - - t.Run("PartnerNetworks_Tag600", func(t *testing.T) { - // Test packets from all 3 partner networks - testCases := []struct { - name string - srcIP string - srcPort uint16 - }{ - {"Partner1", "192.168.1.100", 51000}, - {"Partner2", "192.168.2.200", 51001}, - {"Partner3", "192.168.3.50", 51002}, - } - - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError( - t, - err, - "packet from %s should be processed", - tc.name, - ) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for %s", - tc.name, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for %s", - tc.name, - ) - } - - // Verify stats for tag 600 - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var vsStats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - vsStats = vs - break - } - } - require.NotNil(t, vsStats, "VS stats should exist") - - // Build tag stats map - tagStats := make(map[string]uint64) - for _, allowedSrc := range vsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify tag 600 has 3 passes (one from each partner network) - assert.Equal( - t, - uint64(3), - tagStats["600"], - "tag 600 should have 3 passes", - ) - }) - - t.Run("ExternalNetworks_Tag700_WithPortFiltering", func(t *testing.T) { - // Test packets from all 4 external networks with different port scenarios - testCases := []struct { - name string - srcIP string - srcPort uint16 - shouldPass bool - }{ - // High ports - should pass - {"External1_HighPort", "203.0.113.10", 50000, true}, - {"External2_HighPort", "198.51.100.20", 50001, true}, - {"External3_HighPort", "198.18.5.5", 50002, true}, - {"External4_HighPort", "100.64.10.10", 50003, true}, - - // Low ports - should be blocked - {"External1_LowPort", "203.0.113.11", 80, false}, - {"External2_LowPort", "198.51.100.21", 443, false}, - {"External3_LowPort", "198.18.5.6", 22, false}, - {"External4_LowPort", "100.64.10.11", 1023, false}, - } - - passedCount := 0 - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err, "packet processing failed for %s", tc.name) - - if tc.shouldPass { - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for %s", - tc.name, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for %s", - tc.name, - ) - passedCount++ - } else { - require.Empty(t, result.Output, "expected no output packets for %s", tc.name) - require.Equal(t, 1, len(result.Drop), "expected 1 dropped packet for %s", tc.name) - } - } - - // Verify stats for tag 700 - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var vsStats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - vsStats = vs - break - } - } - require.NotNil(t, vsStats, "VS stats should exist") - - // Build tag stats map - tagStats := make(map[string]uint64) - for _, allowedSrc := range vsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify tag 700 has correct number of passes (only high port packets) - assert.Equal( - t, - uint64(passedCount), - tagStats["700"], - "tag 700 should have %d passes (only high port packets)", - passedCount, - ) - }) - - t.Run("BlockedSources", func(t *testing.T) { - // Test packets from networks not in any allowed source - testCases := []struct { - name string - srcIP string - srcPort uint16 - }{ - {"Blocked1", "172.16.1.1", 50000}, - {"Blocked2", "8.8.8.8", 50001}, - {"Blocked3", "1.1.1.1", 50002}, - } - - for _, tc := range testCases { - clientIP := netip.MustParseAddr(tc.srcIP) - packetLayers := utils.MakeTCPPacket( - clientIP, - tc.srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err, "packet processing failed for %s", tc.name) - require.Empty( - t, - result.Output, - "expected no output packets for %s", - tc.name, - ) - require.Equal( - t, - 1, - len(result.Drop), - "expected 1 dropped packet for %s", - tc.name, - ) - } - - // Verify these blocked packets don't appear in any tag stats - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var vsStats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - vsStats = vs - break - } - } - require.NotNil(t, vsStats, "VS stats should exist") - - // Verify packet_src_not_allowed counter increased - assert.Greater(t, vsStats.Stats.PacketSrcNotAllowed, uint64(0), - "packet_src_not_allowed should be greater than 0") - }) - - t.Run("VerifyAllTagsPresent", func(t *testing.T) { - // Final verification that all tags are present with correct counts - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - - var vsStats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == vsIP && vs.Vs.Port == uint32(vsPort) { - vsStats = vs - break - } - } - require.NotNil(t, vsStats, "VS stats should exist") - - // Build tag stats map - tagStats := make(map[string]uint64) - for _, allowedSrc := range vsStats.AllowedSources { - tagStats[allowedSrc.Tag] = allowedSrc.Passes - } - - // Verify all three tags are present - assert.Contains(t, tagStats, "500", "tag 500 should be present") - assert.Contains(t, tagStats, "600", "tag 600 should be present") - assert.Contains(t, tagStats, "700", "tag 700 should be present") - - // Verify tag 500 (4 corporate networks) - assert.Equal( - t, - uint64(4), - tagStats["500"], - "tag 500 should have 4 passes", - ) - - // Verify tag 600 (3 partner networks) - assert.Equal( - t, - uint64(3), - tagStats["600"], - "tag 600 should have 3 passes", - ) - - // Verify tag 700 (4 external networks with port filtering - only high ports) - assert.Equal( - t, - uint64(4), - tagStats["700"], - "tag 700 should have 4 passes", - ) - - // Verify total allowed sources count - assert.Equal(t, 3, len(vsStats.AllowedSources), - "should have exactly 3 allowed source stats entries") - }) -} diff --git a/modules/balancer/tests/go/basic_test.go b/modules/balancer/tests/go/basic_test.go index 43df1b79e..768bff818 100644 --- a/modules/balancer/tests/go/basic_test.go +++ b/modules/balancer/tests/go/basic_test.go @@ -3,246 +3,383 @@ package balancer_test import ( "net/netip" "testing" + "time" "github.com/c2h5oh/datasize" "github.com/gopacket/gopacket/layers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" ) -func TestBasicOperations(t *testing.T) { - // Define test addresses - vsIP := netip.MustParseAddr("1.1.1.1") - vsPort := uint16(80) - realAddr := netip.MustParseAddr("2.2.2.2") - clientIP := netip.MustParseAddr("3.3.3.3") - - // Create balancer configuration - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIP.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("3.3.3.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } +// Test addresses. +var ( + // Virtual services. + vs1Addr = netip.MustParseAddr("10.0.0.1") // TCP IPv4 + vs2Addr = netip.MustParseAddr("10.0.0.2") // UDP IPv4 + vs3Addr = netip.MustParseAddr("2001:db8::1") // TCP IPv6 GRE + vs4Addr = netip.MustParseAddr("10.0.0.4") // TCP IPv4 OPS + vs5Addr = netip.MustParseAddr("10.0.0.5") + + // Reals for VS1. + real1a = netip.MustParseAddr("192.168.1.1") + real1b = netip.MustParseAddr("192.168.1.2") + real1c = netip.MustParseAddr("192.168.1.3") + + // Reals for VS2. + real2a = netip.MustParseAddr("192.168.2.1") + real2b = netip.MustParseAddr("192.168.2.2") + + // Reals for VS3 (IPv6). + real3a = netip.MustParseAddr("fd00::1") + real3b = netip.MustParseAddr("fd00::2") + + // Reals for VS4. + real4a = netip.MustParseAddr("192.168.4.1") + real4b = netip.MustParseAddr("192.168.4.2") + + // Reals for VS5. + real5a = netip.MustParseAddr("192.168.5.1") + real5b = netip.MustParseAddr("192.168.5.2") + + // Client addresses. + clientV4 = netip.MustParseAddr("3.3.3.1") + clientV6 = netip.MustParseAddr("2001:db8::3") +) + +func buildInitialConfig() *balancerpb.BalancerConfig { + return utils.NewConfigBuilder(). + AddVS( + // VS1: TCP IPv4, source hash, 3 reals + utils.NewTCPVS(vs1Addr.String(), 80). + AllowAll(). + AddReal( + utils.R(real1a.String()), + utils.R(real1b.String()), + utils.R(real1c.String()), + ).Build(), + + // VS2: UDP IPv4, round robin, 2 reals with different weights + utils.NewUDPVS(vs2Addr.String(), 12345). + WithScheduler(balancerpb.VsScheduler_WRR). + AllowAll(). + AddReal( + utils.RW(real2a.String(), 2), + utils.RW(real2b.String(), 1), + ).Build(), + + // VS3: TCP IPv6, GRE encapsulation, 2 reals + utils.NewTCPVS(vs3Addr.String(), 443). + GRE(). + AllowAll(). + AddReal( + utils.R(real3a.String()), + utils.R(real3b.String()), + ).Build(), + + // VS4: TCP IPv4, OPS mode (no sessions) + utils.NewTCPVS(vs4Addr.String(), 8080). + OPS(). + AllowAll(). + AddReal( + utils.R(real4a.String()), + utils.R(real4b.String()), + ).Build(), + ). + Build() +} + +func TestBasic(t *testing.T) { + config := buildInitialConfig() - // Setup test ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), + Mock: utils.SingleWorkerMockConfig(128*datasize.MB, 4*datasize.MB), + Balancer: config, + AgentMemory: 64 * datasize.MB, }) require.NoError(t, err) defer ts.Free() - mock := ts.Mock - balancer := ts.Balancer - - // Enable all reals before sending packets - enableTrue := true - realUpdates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: vsIP.AsSlice()}, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: realAddr.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - } - _, err = balancer.UpdateReals(realUpdates, false) - require.NoError(t, err, "failed to enable reals") - - // Create and send TCP SYN packet - packetLayers := utils.MakeTCPPacket( - clientIP, - 1000, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") + utils.EnableAllReals(t, ts) + ts.Mock.SetCurrentTime(time.Unix(1000, 0)) - // Validate response packet - response := result.Output[0] - utils.ValidatePacket(t, balancer.Config(), packet, response) + vs1 := utils.VsIDFromPb(ts.Balancer.Config().PacketHandler.Vs[0].Id) + vs2 := utils.VsIDFromPb(ts.Balancer.Config().PacketHandler.Vs[1].Id) + vs3 := utils.VsIDFromPb(ts.Balancer.Config().PacketHandler.Vs[2].Id) - // Check balancer info and stats - t.Run("Read_Balancer_Info", func(t *testing.T) { - info, err := balancer.Info(mock.CurrentTime()) - require.NoError(t, err) + t.Run("InitialTraffic", func(t *testing.T) { + // TCP IPv4 => VS1 + _, err := utils.SendAndValidateTCP(ts, clientV4, 10000, vs1Addr, 80, &layers.TCP{SYN: true}) + require.NoError(t, err, "failed to send packet to vs1: %w", err) - // Basic validation that info is populated - assert.NotNil(t, info, "balancer info should not be nil") + // UDP IPv4 => VS2 + _, err = utils.SendAndValidateUDP(ts, clientV4, 10001, vs2Addr, 12345) + require.NoError(t, err, "failed to send packet to vs2: %w", err) - // Check that we have session information - assert.Equal( - t, - uint64(1), - info.ActiveSessions, - "should have exactly one active session", + // TCP IPv6 => VS3 (GRE) + _, err = utils.SendAndValidateTCP(ts, clientV6, 10002, vs3Addr, 443, &layers.TCP{SYN: true}) + require.NoError(t, err, "failed to send packet to vs3: %w", err) + + // TCP IPv4 => VS4 + // OPS mode => no session created + _, err = utils.SendAndValidateTCP( + ts, + clientV4, + 10003, + vs4Addr, + 8080, + &layers.TCP{SYN: true}, ) + require.NoError(t, err, "failed to send packet to vs4") }) - t.Run("Read_Balancer_Stats", func(t *testing.T) { - // Get stats for the specific packet handler - ref := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, + t.Run("SessionAffinity", func(t *testing.T) { + packetsCount := 10 + var rl *utils.RealID + for idx := range packetsCount { + pkt, err := utils.SendAndValidateTCP(ts, clientV4, 10000, vs1Addr, 80, &layers.TCP{}) + require.NoError(t, err, "failed to send packet %d", idx) + if rl == nil { + rl = &pkt.RealID + } else if pkt.RealID.Compare(rl) != 0 { + t.Fatalf("expected all packets to go to the same real, got %s and %s", rl, &pkt.RealID) + } } + }) + + t.Run("ListSessions", func(t *testing.T) { + now := ts.Mock.CurrentTime() + + count := 0 + meetVs1 := false + meetVs2 := false + meetVs3 := false + err := ts.Balancer.ListSessions(nil, now, func(s *balancerpb.Session) error { + pkt, err := utils.PacketInfoFromSessionPb(s) + require.NoError(t, err) + count++ + + switch { + case pkt.VsID.Compare(&vs1) == 0: + meetVs1 = true + case pkt.VsID.Compare(&vs2) == 0: + meetVs2 = true + case pkt.VsID.Compare(&vs3) == 0: + meetVs3 = true + } - stats, err := balancer.Stats(ref) + return nil + }) require.NoError(t, err) - require.NotNil(t, stats, "stats should not be nil") - - // Validate that we have VS stats - require.NotEmpty(t, stats.Vs, "should have VS stats") - - // Check VS stats - vsStats := stats.Vs[0] - assert.Equal( - t, - uint64(1), - vsStats.Stats.IncomingPackets, - "should have 1 incoming packet", - ) - assert.Equal( - t, - uint64(1), - vsStats.Stats.OutgoingPackets, - "should have 1 outgoing packet", - ) - assert.Equal( - t, - uint64(1), - vsStats.Stats.CreatedSessions, - "should have 1 created session", - ) - assert.Equal( - t, - uint64(len(packet.Data())), - vsStats.Stats.IncomingBytes, - "incoming bytes should match packet size", + + assert.Equal(t, count, 3, "expected to meet 3 sessions (one for each no-ops-VS)") + assert.True(t, meetVs1, "expected to meet VS1") + assert.True(t, meetVs2, "expected to meet VS2") + assert.True(t, meetVs3, "expected to meet VS3") + }) + + t.Run("Update", func(t *testing.T) { + // Update: change VS2 real weights. + updatedConfig := buildInitialConfig() + updatedConfig.PacketHandler.Vs[1].Reals[0].Weight = 5 // real2a: 5 + updatedConfig.PacketHandler.Vs[1].Reals[1].Weight = 5 // real2b: 5 + + now := ts.Mock.CurrentTime() + _, err := ts.Balancer.Update(updatedConfig, &now) + require.NoError(t, err) + + utils.EnableAllReals(t, ts) + + for range 10 { + _, err := utils.SendAndValidateUDP(ts, clientV4, 20000, vs2Addr, 12345) + require.NoError(t, err, "failed to send packet: %w", err) + } + }) + + t.Run("UpdateVS", func(t *testing.T) { + // Add a new VS5. + newVS := utils.NewTCPVS(vs5Addr.String(), 9090). + AllowAll(). + AddReal( + utils.R(real5a.String()), + utils.R(real5b.String()), + ).Build() + + _, err := ts.Balancer.UpdateVS( + []*balancerpb.VirtualService{newVS}, ) - assert.Equal( - t, - uint64(len(response.RawData)), - vsStats.Stats.OutgoingBytes, - "outgoing bytes should match packet size", + require.NoError(t, err) + + // Enable reals for the new VS. + utils.EnableAllReals(t, ts) + + // Send traffic to VS5 and validate. + for idx := range 10 { + _, err := utils.SendAndValidateTCP( + ts, + clientV4, + 20000, + vs5Addr, + 9090, + &layers.TCP{SYN: true}, + ) + require.NoError(t, err, "failed to send packet %d: %w", idx, err) + } + + // Existing VS1 still works. + for idx := range 10 { + _, err := utils.SendAndValidateTCP( + ts, + clientV4, + 20001, + vs1Addr, + 80, + &layers.TCP{SYN: true}, + ) + require.NoError(t, err, "failed to send packet %d: %w", idx, err) + } + }) + + t.Run("DeleteVS", func(t *testing.T) { + // Delete VS4 (OPS). + vs4ToDelete := &balancerpb.VirtualService{ + Id: &balancerpb.VsIdentifier{ + Addr: vs4Addr.AsSlice(), + Port: 8080, + Proto: balancerpb.TransportProto_TCP, + }, + } + + _, err := ts.Balancer.DeleteVS( + []*balancerpb.VirtualService{vs4ToDelete}, ) + require.NoError(t, err) - // Check Real stats - require.NotEmpty(t, vsStats.Reals, "should have Real stats") - realStats := vsStats.Reals[0] - assert.Equal( - t, - uint64(1), - realStats.Stats.CreatedSessions, - "real should have 1 created session", + // Traffic to deleted VS4 should be dropped. + pkt := xpacket.LayersToPacket(t, + utils.MakeTCPPacketLayers(clientV4, 30000, vs4Addr, 8080, &layers.TCP{SYN: true})..., ) - assert.Equal( - t, - uint64(1), - realStats.Stats.Packets, - "real should have 1 packet", + result, err := ts.Mock.HandlePackets(pkt) + require.NoError(t, err) + assert.Empty(t, result.Output, "expected no output for deleted VS") + assert.NotEmpty(t, result.Drop, "expected drop for deleted VS") + + // VS1 still works. + for idx := range 10 { + _, err := utils.SendAndValidateTCP( + ts, + clientV4, + 30001, + vs1Addr, + 80, + &layers.TCP{SYN: true}, + ) + require.NoError(t, err, "failed to send packet %d: %w", idx, err) + } + }) + + t.Run("UpdateReals", func(t *testing.T) { + // Disable real1b in VS1. + config := ts.Balancer.Config() + _, err := ts.Balancer.UpdateReals([]*balancerpb.RealUpdate{ + utils.DisableReal( + config.PacketHandler.Vs[0].Id, + config.PacketHandler.Vs[0].Reals[1].Id, + ), + }, false) + require.NoError(t, err) + + // Send many packets with unique sources -> should only go to real1a and real1c. + results, err := utils.SendAndValidateRandomSrcPorts( + ts, + clientV4, + vs1Addr, + 80, + &layers.TCP{SYN: true}, + 1000, ) - assert.Equal( - t, - uint64(len(response.RawData)), - realStats.Stats.Bytes, - "real bytes should match packet size", + require.NoError(t, err, "failed to send packets: %w", err) + + counts, err := utils.CountPacketsPerReal(results) + require.NoError(t, err, "failed to count packets per real: %w") + + _, hasDisabled := counts[real1b] + assert.False(t, hasDisabled, "disabled real1b should receive no traffic, got %v", counts) + assert.Equal(t, 2, len(counts)) + + config = ts.Balancer.Config() + + // Re-enable real1b. + _, err = ts.Balancer.UpdateReals([]*balancerpb.RealUpdate{ + utils.EnableReal( + config.PacketHandler.Vs[0].Id, + config.PacketHandler.Vs[0].Reals[1].Id, + ), + }, false) + require.NoError(t, err) + + results, err = utils.SendAndValidateRandomSrcPorts( + ts, + clientV4, + vs1Addr, + 80, + &layers.TCP{SYN: true}, + 1000, ) + require.NoError(t, err, "failed to send packets: %w", err) + counts, err = utils.CountPacketsPerReal(results) + require.NoError(t, err) + + _, hasEnabled := counts[real1b] + assert.True(t, hasEnabled, "enabled real1b should receive traffic, got %v", counts) + assert.Equal(t, 3, len(counts)) + }) + + t.Run("GetState", func(t *testing.T) { + ref := utils.PacketHandlerRef() + states, err := ts.Balancer.GetState(ref, nil, true, ts.Mock.CurrentTime()) + require.NoError(t, err) + require.NotEmpty(t, states) + + state := states[0] + + // L4 stats should show some processed packets. + require.NotNil(t, state.L4Stats) + assert.Greater(t, state.L4Stats.IncomingPackets, uint64(0), + "expected non-zero incoming packets") + assert.Greater(t, state.L4Stats.OutgoingPackets, uint64(0), + "expected non-zero outgoing packets") + + // Should have VS states. + assert.Equal(t, 4, len(state.VirtualServices), "expected 4 virtual services") + assert.Greater(t, state.ActiveSessions, uint64(0), + "expected non-zero active sessions") + assert.NotNil(t, state.LastPacketTimestamp, "expected last packet timestamp") + }) + + ts.Mock.AdvanceTime(time.Second * 200) + + t.Run("GetStateAfterTimeAdvance", func(t *testing.T) { + states, err := ts.Balancer.GetState(nil, nil, false, ts.Mock.CurrentTime()) + require.NoError(t, err) + require.NotEmpty(t, states) + + state := states[0] + + assert.Equal(t, state.ActiveSessions, uint64(0), + "expected zero active sessions after time advance") + }) + + t.Run("ListSessionsAfterTimeAdvance", func(t *testing.T) { + now := ts.Mock.CurrentTime() + found := false + err := ts.Balancer.ListSessions(nil, now, func(_ *balancerpb.Session) error { + found = true + return nil + }) + require.NoError(t, err) + require.False(t, found, "expected no sessions after time advance") }) } diff --git a/modules/balancer/tests/go/big_config_test.go b/modules/balancer/tests/go/big_config_test.go deleted file mode 100644 index eaf2fc828..000000000 --- a/modules/balancer/tests/go/big_config_test.go +++ /dev/null @@ -1,785 +0,0 @@ -package balancer_test - -// TestBigConfig tests the balancer with large configurations to verify: -// -// # Configuration Scalability -// - Phase 1: 10 virtual services with 50 reals each (500 total reals) -// - Phase 2: 20 virtual services with 20 reals each (400 total reals) -// - Random IPv4/IPv6 addresses for virtual services and reals -// - Random TCP/UDP protocols -// - Random schedulers (WRR, PRR, WLC) -// - Random flags (GRE, FixMSS, OPS) -// -// # Packet Processing -// - Sends 10 batches of 10,000 packets per phase -// - 90% packets to existing virtual services -// - 10% packets to non-existent virtual services (should be dropped) -// - Validates packet distribution to enabled reals only -// -// # Real Server Management -// - Randomly disables half of reals before each batch -// - Re-enables all reals after each batch -// - Validates that disabled reals receive no new packets -// -// # Session Management -// - Tracks active sessions across batches -// - Validates session table capacity and load factor -// - Syncs active sessions between batches -// -// # Performance Metrics -// - Measures configuration creation time -// - Measures configuration update time -// - Measures real enable/disable time -// - Measures packet handling time and RPS (requests per second) - -import ( - "math/rand" - "net/netip" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "github.com/yanet-platform/yanet2/tests/functional/framework" - "google.golang.org/protobuf/types/known/durationpb" -) - -//////////////////////////////////////////////////////////////////////////////// -// Random IP generation helpers - -// generateRandomIPv4 generates a random IPv4 address -func generateRandomIPv4(rng *rand.Rand) netip.Addr { - return netip.AddrFrom4([4]byte{ - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - }) -} - -// generateRandomIPv6 generates a random IPv6 address -func generateRandomIPv6(rng *rand.Rand) netip.Addr { - var bytes [16]byte - for i := range bytes { - bytes[i] = byte(rng.Intn(256)) - } - return netip.AddrFrom16(bytes) -} - -// generateRandomIP generates a random IP address (IPv4 or IPv6) -func generateRandomIP(rng *rand.Rand, forceIPv6 bool) netip.Addr { - if forceIPv6 || rng.Intn(2) == 0 { - return generateRandomIPv6(rng) - } - return generateRandomIPv4(rng) -} - -//////////////////////////////////////////////////////////////////////////////// -// Configuration creation - -// createBigConfig generates a balancer configuration with the specified number -// of virtual services and reals per VS. Virtual services can have different -// flags (GRE, FixMSS, OPS, PureL3), IPv4/IPv6 addresses, and TCP/UDP protocols. -// FixMSS flag is only set for IPv6 virtual services. -func createBigConfig( - vsCount int, - realsPerVs int, - rng *rand.Rand, -) *balancerpb.BalancerConfig { - virtualServices := make([]*balancerpb.VirtualService, 0, vsCount) - - for range vsCount { - // Randomly choose IPv4 or IPv6 for VS - isIPv6 := rng.Intn(2) == 0 - var vsIP netip.Addr - if isIPv6 { - vsIP = generateRandomIPv6(rng) - } else { - vsIP = generateRandomIPv4(rng) - } - - // Randomly choose TCP or UDP - var proto balancerpb.TransportProto - if rng.Intn(2) == 0 { - proto = balancerpb.TransportProto_TCP - } else { - proto = balancerpb.TransportProto_UDP - } - - // Random flags - useGRE := rng.Intn(2) == 0 - useOPS := rng.Intn(2) == 0 - usePureL3 := false - - // FixMSS only for IPv6 VS - useFixMSS := isIPv6 && rng.Intn(2) == 0 - - vsPort := uint32(1 + rng.Intn(65535)) - - // Random scheduler - schedulers := []balancerpb.VsScheduler{ - balancerpb.VsScheduler_ROUND_ROBIN, - balancerpb.VsScheduler_SOURCE_HASH, - } - scheduler := schedulers[rng.Intn(len(schedulers))] - - // Generate reals for this VS - reals := make([]*balancerpb.Real, 0, realsPerVs) - for range realsPerVs { - // Reals can be IPv4 or IPv6 independently of VS - realIP := generateRandomIP(rng, false) - - // Generate source address and mask - var srcAddr, srcMask netip.Addr - if realIP.Is4() { - srcAddr = generateRandomIPv4(rng) - srcMask = netip.AddrFrom4([4]byte{255, 255, 255, 255}) - } else { - srcAddr = generateRandomIPv6(rng) - srcMask = netip.AddrFrom16([16]byte{ - 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, - }) - } - - reals = append(reals, &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: realIP.AsSlice()}, - Port: 0, - }, - Weight: uint32(1 + rng.Intn(100)), - SrcAddr: &balancerpb.Addr{ - Bytes: srcAddr.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: srcMask.AsSlice(), - }, - }) - } - - // Create allowed sources (allow all traffic for simplicity) - // Only add allowed sources that match the VS IP version - var allowedSrcs []*balancerpb.AllowedSources - if isIPv6 { - // IPv6 VS - only allow IPv6 sources - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{}).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{}).AsSlice(), - }, - }}, - }, - } - } else { - // IPv4 VS - only allow IPv4 sources - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{0, 0, 0, 0}).AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{0, 0, 0, 0}).AsSlice(), - }, - }}, - }, - } - } - - virtualServices = append(virtualServices, &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: vsIP.AsSlice()}, - Port: vsPort, - Proto: proto, - }, - Scheduler: scheduler, - AllowedSrcs: allowedSrcs, - Reals: reals, - Flags: &balancerpb.VsFlags{ - Gre: useGRE, - FixMss: useFixMSS, - Ops: useOPS, - PureL3: usePureL3, - Wlc: false, - }, - Peers: []*balancerpb.Addr{}, - }) - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: virtualServices, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { - v := uint64(200_000) - return &v - }(), - SessionTableMaxLoadFactor: func() *float32 { - v := float32(0.75) - return &v - }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { - v := uint64(10) - return &v - }(), - MaxWeight: func() *uint32 { - v := uint32(1000) - return &v - }(), - }, - }, - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Virtual service key for tracking - -// vsKey uniquely identifies a virtual service -type vsKey struct { - ip netip.Addr - port uint16 - proto balancerpb.TransportProto -} - -func (key *vsKey) String() string { - protoStr := "TCP" - if key.proto == balancerpb.TransportProto_UDP { - protoStr = "UDP" - } - return netip.AddrPortFrom(key.ip, key.port).String() + "/" + protoStr -} - -// vsKeyFromPacket extracts the VS key from a tunneled packet -func vsKeyFromPacket( - t *testing.T, - packet *framework.PacketInfo, -) (*vsKey, *netip.Addr) { - t.Helper() - - if !packet.IsTunneled { - t.Errorf("Output packet is not tunneled") - return nil, nil - } - - if packet.InnerPacket == nil { - t.Errorf("Output packet has no inner packet") - return nil, nil - } - - innerPkt := packet.InnerPacket - - // Get destination IP of the tunneled packet (should be a real) - realIP, ok := netip.AddrFromSlice(packet.DstIP) - if !ok { - t.Errorf("Invalid real IP in output packet") - return nil, nil - } - - // Get inner packet destination (VS IP) - innerDstIP, ok := netip.AddrFromSlice(innerPkt.DstIP) - if !ok { - t.Errorf("Invalid inner dst IP") - return nil, nil - } - - dstPort := packet.DstPort - - // Determine protocol from inner packet - transportProto, ok := innerPkt.GetTransportProtocol() - if !ok { - t.Errorf("Unable to determine transport protocol from inner packet") - return nil, nil - } - - var proto balancerpb.TransportProto - switch transportProto { - case layers.IPProtocolTCP: - proto = balancerpb.TransportProto_TCP - case layers.IPProtocolUDP: - proto = balancerpb.TransportProto_UDP - default: - t.Errorf("Unknown transport protocol: %v", transportProto) - return nil, nil - } - - // Find the VS this packet was sent to - key := vsKey{ - ip: innerDstIP, - port: dstPort, - proto: proto, - } - - return &key, &realIP -} - -// vsInfo contains information about a virtual service for validation -type vsInfo struct { - realAddrs map[netip.Addr]bool - enabledReals map[netip.Addr]bool -} - -// buildVSMaps builds lookup maps for virtual services and their reals -func buildVSMaps(config *balancerpb.BalancerConfig) map[vsKey]*vsInfo { - vsMap := make(map[vsKey]*vsInfo) - - if config.PacketHandler == nil { - return vsMap - } - - for _, vs := range config.PacketHandler.Vs { - vsIP, ok := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if !ok { - continue - } - - key := vsKey{ - ip: vsIP, - port: uint16(vs.Id.Port), - proto: vs.Id.Proto, - } - - info := &vsInfo{ - realAddrs: make(map[netip.Addr]bool), - enabledReals: make(map[netip.Addr]bool), - } - - for _, real := range vs.Reals { - realIP, ok := netip.AddrFromSlice(real.Id.Ip.Bytes) - if ok { - info.realAddrs[realIP] = true - // Initially all reals are enabled (EnableAllReals is called before testPacketSending) - info.enabledReals[realIP] = true - } - } - - vsMap[key] = info - } - - return vsMap -} - -// updateVSMapsWithRealStates updates the enabled state of reals in VS maps -func updateVSMapsWithRealStates( - vsMap map[vsKey]*vsInfo, - updates []*balancerpb.RealUpdate, -) { - for _, update := range updates { - if update.RealId == nil || update.RealId.Vs == nil || - update.RealId.Real == nil { - continue - } - - vsIP, ok := netip.AddrFromSlice(update.RealId.Vs.Addr.Bytes) - if !ok { - continue - } - - key := vsKey{ - ip: vsIP, - port: uint16(update.RealId.Vs.Port), - proto: update.RealId.Vs.Proto, - } - - info, exists := vsMap[key] - if !exists { - continue - } - - realIP, ok := netip.AddrFromSlice(update.RealId.Real.Ip.Bytes) - if !ok { - continue - } - - if update.Enable != nil { - info.enabledReals[realIP] = *update.Enable - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Main test function - -// TestBigConfig tests the balancer with large configurations -func TestBigConfig(t *testing.T) { - // Use fixed seed for reproducibility - rng := rand.New(rand.NewSource(42)) - - // Create initial config: 10 VS with 50 reals each - initialConfig := createBigConfig(10, 50, rng) - - // Setup test with appropriate memory - agentMemory := 128 * datasize.MB - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig( - 256*datasize.MB, - 16*datasize.MB, - ), - Balancer: initialConfig, - AgentMemory: &agentMemory, - }) - require.NoError(t, err) - defer ts.Free() - - t.Logf( - "Setup initial config with 10 VS and 50 reals per VS (total 500 reals)", - ) - - // Set initial time - ts.Mock.SetCurrentTime(time.Unix(0, 0)) - - // Enable all reals initially - utils.EnableAllReals(t, ts) - - // Phase 1: Test with 10 VS and 50 reals each - t.Run("Phase1_10VS_50Reals", func(t *testing.T) { - testPacketSending(t, ts, initialConfig, rng, 10, 8000) - }) - - // Phase 2: Update to 20 VS with 20 reals each - t.Run("Phase2_20VS_20Reals", func(t *testing.T) { - newConfig := createBigConfig(20, 20, rng) - - updateStart := time.Now() - _, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - t.Logf( - "Updated config to 20 VS and 20 reals per VS (total 400 reals), elapsed: %v", - time.Since(updateStart), - ) - - utils.EnableAllReals(t, ts) - - testPacketSending(t, ts, newConfig, rng, 10, 8000) - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Packet sending and validation - -// testPacketSending sends batches of packets and validates the results -func testPacketSending( - t *testing.T, - ts *utils.TestSetup, - config *balancerpb.BalancerConfig, - rng *rand.Rand, - numBatches int, - packetsPerBatch int, -) { - t.Helper() - - totalOutput := 0 - totalDrop := 0 - totalPackets := 0 - correctPackets := 0 - - // Extract VS information for packet generation - if config.PacketHandler == nil { - t.Fatal("PacketHandler config is nil") - } - virtualServices := config.PacketHandler.Vs - - for batch := range numBatches { - // Build VS maps for validation - vsMap := buildVSMaps(config) - - // Disable random half of reals before batch - var disableUpdates []*balancerpb.RealUpdate - var enableUpdates []*balancerpb.RealUpdate - disabledCount := 0 - - for _, vs := range virtualServices { - // Randomly select half of reals to disable - numToDisable := len(vs.Reals) / 2 - indices := rng.Perm(len(vs.Reals)) - - for i, real := range vs.Reals { - shouldDisable := false - for j := range numToDisable { - if indices[j] == i { - shouldDisable = true - break - } - } - - enableFalse := false - enableTrue := true - - if shouldDisable { - disableUpdates = append( - disableUpdates, - &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: vs.Id, - Real: real.Id, - }, - Enable: &enableFalse, - }, - ) - disabledCount++ - } else { - enableUpdates = append(enableUpdates, &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Vs: vs.Id, - Real: real.Id, - }, - Enable: &enableTrue, - }) - } - } - } - - // Apply disable updates and measure time - disableStartTime := time.Now() - _, err := ts.Balancer.UpdateReals(disableUpdates, false) - require.NoError(t, err) - disableDuration := time.Since(disableStartTime) - - // Update VS maps with new enabled states - updateVSMapsWithRealStates(vsMap, disableUpdates) - updateVSMapsWithRealStates(vsMap, enableUpdates) - - t.Logf( - "Batch %d/%d: Disabled %d reals in %v", - batch+1, - numBatches, - disabledCount, - disableDuration, - ) - - packets := make([]gopacket.Packet, 0, packetsPerBatch) - - // Generate packets to existing VS (90% of packets) - existingVSPackets := packetsPerBatch * 9 / 10 - for range existingVSPackets { - // Pick a random VS - vs := virtualServices[rng.Intn(len(virtualServices))] - - vsIP, ok := netip.AddrFromSlice(vs.Id.Addr.Bytes) - require.True(t, ok, "invalid VS IP") - - vsPort := uint16(vs.Id.Port) - - // Generate random source with matching IP protocol - var srcIP netip.Addr - if vsIP.Is4() { - srcIP = generateRandomIPv4(rng) - } else { - srcIP = generateRandomIPv6(rng) - } - // Use ephemeral port range, avoiding well-known ports - // Range: 32768-61000 to avoid protocol detection issues - srcPort := uint16(32768 + rng.Intn(28232)) - - // Create packet based on protocol - var packetLayers []gopacket.SerializableLayer - if vs.Id.Proto == balancerpb.TransportProto_TCP { - packetLayers = utils.MakeTCPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - } else { - packetLayers = utils.MakeUDPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - ) - } - - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - // Generate packets to non-existent VS (10% of packets) - nonExistentPackets := packetsPerBatch - existingVSPackets - for range nonExistentPackets { - // Generate a random IP that's unlikely to match any VS - nonExistentIP := netip.AddrFrom4([4]byte{ - byte(200 + rng.Intn(55)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - }) - nonExistentPort := uint16(60000 + rng.Intn(5535)) - - srcIP := generateRandomIPv4(rng) - // Use ephemeral port range - srcPort := uint16(32768 + rng.Intn(28232)) - - packetLayers := utils.MakeUDPPacket( - srcIP, - srcPort, - nonExistentIP, - nonExistentPort, - ) - - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - totalPackets += len(packets) - - // Send packets - handleStartTime := time.Now() - result, err := ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - handleDuration := time.Since(handleStartTime) - - totalOutput += len(result.Output) - totalDrop += len(result.Drop) - - assert.Equal( - t, - existingVSPackets, - len(result.Output), - "all packets to existing VS should be tunneled", - ) - - // Validate output packets - for _, outPkt := range result.Output { - // Check packet is tunneled - key, realIP := vsKeyFromPacket(t, outPkt) - require.NotNil(t, key) - require.NotNil(t, realIP) - - vsInfo, exists := vsMap[*key] - if !exists { - // all packets to non existent VS should be dropped - t.Errorf( - "Packet tunneled to non-existent VS %s", - key.String(), - ) - continue - } - - // Check that the real IP is in the VS's real list - if !vsInfo.realAddrs[*realIP] { - t.Errorf( - "Packet tunneled to real %s which is not in VS %s reals", - *realIP, - key.String(), - ) - continue - } - - // Check that the real is enabled - if !vsInfo.enabledReals[*realIP] { - t.Errorf("Packet tunneled to DISABLED real %s in VS %s", - *realIP, key.String()) - continue - } - - // All checks passed - correctPackets++ - } - - // Validate drop packets - for _, dropPkt := range result.Drop { - // Check packet is not tunneled (dropped before tunneling) - if !dropPkt.IsTunneled { - correctPackets += 1 - } else { - t.Errorf("Drop packet should not be tunneled") - } - } - - // Calculate RPS (requests per second) - rps := float64(len(packets)) / handleDuration.Seconds() - - // Log progress - t.Logf( - "Batch %d/%d: Sent %d packets, Output=%d, Drop=%d, Correct=%d/%d, HandleTime=%v, RPS=%.0f", - batch+1, - numBatches, - len(packets), - len(result.Output), - len(result.Drop), - correctPackets, - totalPackets, - handleDuration, - rps, - ) - - // Re-enable all reals after batch and measure time - enableStartTime := time.Now() - _, err = ts.Balancer.UpdateReals(enableUpdates, false) - require.NoError(t, err) - enableDuration := time.Since(enableStartTime) - - t.Logf( - "Batch %d/%d: Re-enabled all reals in %v", - batch+1, - numBatches, - enableDuration, - ) - - // Advance time by 1 second - ts.Mock.AdvanceTime(time.Second) - - // Sync active sessions - err = ts.Balancer.Refresh(ts.Mock.CurrentTime()) - require.NoError(t, err) - } - - // Final statistics - assert.Equal(t, totalPackets, totalOutput+totalDrop, - "Total packets should equal output + drop") - - outputRate := float64(totalOutput) / float64(totalPackets) * 100 - dropRate := float64(totalDrop) / float64(totalPackets) * 100 - correctRate := float64(correctPackets) / float64(totalPackets) * 100 - - t.Logf( - "Final statistics: Total=%d, Output=%d (%.2f%%), Drop=%d (%.2f%%), Correct=%d (%.2f%% of total)", - totalPackets, - totalOutput, - outputRate, - totalDrop, - dropRate, - correctPackets, - correctRate, - ) - - // Get final state info - info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - t.Logf("Active sessions: %d", info.ActiveSessions) - t.Logf("Session table capacity: %d", *config.State.SessionTableCapacity) -} diff --git a/modules/balancer/tests/go/config_update_test.go b/modules/balancer/tests/go/config_update_test.go deleted file mode 100644 index 5f472e086..000000000 --- a/modules/balancer/tests/go/config_update_test.go +++ /dev/null @@ -1,1728 +0,0 @@ -package balancer_test - -// TestConfigUpdateAndStats is a comprehensive test that verifies: -// -// # Initial Configuration -// - 2 virtual services (VS1, VS2) with 3 reals each -// - Packet distribution across all reals -// - Session creation and persistence -// -// # Real Server State Management -// - Disabling reals and verifying traffic routing -// - Enabled/disabled state tracking in Graph -// -// # API Outputs (Initial State) -// - Config(): Configuration retrieval -// - Graph(): Topology with enabled/disabled states -// - Stats(): Exact packet counts per VS and real -// - Info(): Exact session counts per VS -// - Sessions(): Session details and distribution -// -// # Configuration Update -// - Removing VS1 completely -// - Modifying VS2 (keeping Real6, adding Real7, Real8) -// - Adding new VS3 with 3 new reals -// - Verifying Real6 remains enabled while new reals are disabled -// -// # API Outputs (After Update) -// - Config(): Only VS2 and VS3 present -// - Graph(): Correct topology, Real6 enabled, new reals initially disabled -// - Stats(): VS2 cumulative (not reset), VS3 new, NO VS1 -// - Info(): Only VS2 and VS3, NO VS1 -// - Sessions(): Only VS2 and VS3 sessions, NO VS1 or deleted reals -// -// # State Persistence with New Agent -// - Creating new BalancerAgent attached to same shared memory -// - Verifying existing BalancerManager is discovered and accessible -// - Config(), Graph(), Stats(), Info(), Sessions() match previous outputs -// - Sending new packets through new agent (10 to VS2, 10 to VS3) -// - Verifying Stats, Info, Sessions update correctly with new traffic - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/logging" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "github.com/yanet-platform/yanet2/tests/functional/framework" - "go.uber.org/zap/zapcore" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Test addresses for config update test -var ( - // Virtual Services - cfgVs1IP = netip.MustParseAddr("10.10.1.1") - cfgVs1Port = uint16(80) - cfgVs2IP = netip.MustParseAddr("10.10.2.1") - cfgVs2Port = uint16(80) - cfgVs3IP = netip.MustParseAddr("10.10.3.1") - cfgVs3Port = uint16(80) - - // Real servers for VS1 - cfgReal1IP = netip.MustParseAddr("192.168.11.1") - cfgReal2IP = netip.MustParseAddr("192.168.11.2") - cfgReal3IP = netip.MustParseAddr("192.168.11.3") - - // Real servers for VS2 - cfgReal4IP = netip.MustParseAddr("192.168.12.1") - cfgReal5IP = netip.MustParseAddr("192.168.12.2") - cfgReal6IP = netip.MustParseAddr("192.168.12.3") - - // New real servers for VS2 (after update) - cfgReal7IP = netip.MustParseAddr("192.168.12.4") - cfgReal8IP = netip.MustParseAddr("192.168.12.5") - - // Real servers for VS3 (new) - cfgReal9IP = netip.MustParseAddr("192.168.13.1") - cfgReal10IP = netip.MustParseAddr("192.168.13.2") - cfgReal11IP = netip.MustParseAddr("192.168.13.3") - - // Client base IP - cfgClientBaseIP = netip.MustParseAddr("3.3.13.1") - - // Balancer source addresses - cfgBalancerSrcV4 = netip.MustParseAddr("5.5.15.5") - cfgBalancerSrcV6 = netip.MustParseAddr("fe80::15") -) - -// createCfgReal creates a Real configuration -func createCfgReal(ip netip.Addr, weight uint32) *balancerpb.Real { - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: 0, - }, - Weight: weight, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.14.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255").AsSlice(), - }, - } -} - -// createCfgVirtualService creates a VirtualService configuration -func createCfgVirtualService( - ip netip.Addr, - port uint16, - reals []*balancerpb.Real, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - } -} - -// createCfgInitialConfig creates the initial balancer configuration with VS1 and VS2 -func createCfgInitialConfig() *balancerpb.BalancerConfig { - vs1Reals := []*balancerpb.Real{ - createCfgReal(cfgReal1IP, 1), - createCfgReal(cfgReal2IP, 1), - createCfgReal(cfgReal3IP, 1), - } - - vs2Reals := []*balancerpb.Real{ - createCfgReal(cfgReal4IP, 1), - createCfgReal(cfgReal5IP, 1), - createCfgReal(cfgReal6IP, 1), - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: cfgBalancerSrcV4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: cfgBalancerSrcV6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - createCfgVirtualService(cfgVs1IP, cfgVs1Port, vs1Reals), - createCfgVirtualService(cfgVs2IP, cfgVs2Port, vs2Reals), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// createCfgUpdatedConfig creates the updated configuration with VS2 (modified) and VS3 (new) -func createCfgUpdatedConfig() *balancerpb.BalancerConfig { - vs2Reals := []*balancerpb.Real{ - createCfgReal(cfgReal6IP, 1), // OLD - persists - createCfgReal(cfgReal7IP, 1), // NEW - createCfgReal(cfgReal8IP, 1), // NEW - } - - vs3Reals := []*balancerpb.Real{ - createCfgReal(cfgReal9IP, 1), - createCfgReal(cfgReal10IP, 1), - createCfgReal(cfgReal11IP, 1), - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: cfgBalancerSrcV4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: cfgBalancerSrcV6.AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - createCfgVirtualService(cfgVs2IP, cfgVs2Port, vs2Reals), - createCfgVirtualService(cfgVs3IP, cfgVs3Port, vs3Reals), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// generateCfgClientIP generates a unique client IP based on index -func generateCfgClientIP(index int) netip.Addr { - base := cfgClientBaseIP.As4() - // Ensure we don't wrap around and create duplicates - base[3] = byte(index % 256) - base[2] = byte((int(base[2]) + index/256) % 256) - return netip.AddrFrom4(base) -} - -// sendCfgPacketsToVS sends count packets to a virtual service from different client IPs -func sendCfgPacketsToVS( - t *testing.T, - ts *utils.TestSetup, - vsIP netip.Addr, - vsPort uint16, - count int, - clientIPStart int, -) []*framework.PacketInfo { - t.Helper() - - var outputPackets []*framework.PacketInfo - - for i := range count { - clientIP := generateCfgClientIP(clientIPStart + i) - // Use clientIPStart + i to ensure unique ports across all phases - clientPort := uint16(10000 + clientIPStart + i) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - outputPackets = append(outputPackets, result.Output[0]) - } - - return outputPackets -} - -// verifyCfgPacketDistribution verifies that packets are distributed as expected -func verifyCfgPacketDistribution( - t *testing.T, - packets []*framework.PacketInfo, - expectedCounts map[netip.Addr]int, -) { - t.Helper() - - actualCounts := utils.CountPacketsPerReal(packets) - - for realIP, expectedCount := range expectedCounts { - actualCount := actualCounts[realIP] - assert.Equal( - t, - expectedCount, - actualCount, - "packet count mismatch for real %s", - realIP, - ) - } -} - -// findCfgVsInGraph finds a virtual service in the graph by IP -func findCfgVsInGraph( - graph *balancerpb.Graph, - vsIP netip.Addr, -) *balancerpb.GraphVs { - for _, vs := range graph.VirtualServices { - addr, _ := netip.AddrFromSlice(vs.Identifier.Addr.Bytes) - if addr == vsIP { - return vs - } - } - return nil -} - -// findCfgRealInVs finds a real in a virtual service by IP -func findCfgRealInVs( - vs *balancerpb.GraphVs, - realIP netip.Addr, -) *balancerpb.GraphReal { - for _, real := range vs.Reals { - addr, _ := netip.AddrFromSlice(real.Identifier.Ip.Bytes) - if addr == realIP { - return real - } - } - return nil -} - -// TestConfigUpdateAndStats is the main test function -func TestConfigUpdateAndStats(t *testing.T) { - // Create initial configuration - config := createCfgInitialConfig() - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Phase 1: Initial State - t.Run("Phase1_InitialState", func(t *testing.T) { - testCfgPhase1InitialState(t, ts) - }) - - // Phase 2: Disable Reals - t.Run("Phase2_DisableReals", func(t *testing.T) { - testCfgPhase2DisableReals(t, ts) - }) - - // Phase 3: Verify APIs (Initial State) - t.Run("Phase3_VerifyAPIs", func(t *testing.T) { - testCfgPhase3VerifyAPIs(t, ts) - }) - - // Phase 4: Update Configuration - t.Run("Phase4_UpdateConfig", func(t *testing.T) { - testCfgPhase4UpdateConfig(t, ts) - }) - - // Phase 5: Verify APIs (After Update) - t.Run("Phase5_VerifyUpdatedAPIs", func(t *testing.T) { - testCfgPhase5VerifyUpdatedAPIs(t, ts) - }) - - // Phase 6: Verify State with New Agent - t.Run("Phase6_StateWithNewAgent", func(t *testing.T) { - testCfgPhase6StateWithNewAgent(t, ts) - }) -} - -// testCfgPhase1InitialState tests the initial state with all reals enabled -func testCfgPhase1InitialState(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Send 15 packets to VS1 - t.Log("Sending 15 packets to VS1") - vs1Packets := sendCfgPacketsToVS(t, ts, cfgVs1IP, cfgVs1Port, 15, 0) - - // Send 15 packets to VS2 - t.Log("Sending 15 packets to VS2") - vs2Packets := sendCfgPacketsToVS(t, ts, cfgVs2IP, cfgVs2Port, 15, 100) - - // Verify distribution (ROUND_ROBIN with equal weights: 5 packets per real) - t.Log("Verifying packet distribution for VS1") - verifyCfgPacketDistribution(t, vs1Packets, map[netip.Addr]int{ - cfgReal1IP: 5, - cfgReal2IP: 5, - cfgReal3IP: 5, - }) - - t.Log("Verifying packet distribution for VS2") - verifyCfgPacketDistribution(t, vs2Packets, map[netip.Addr]int{ - cfgReal4IP: 5, - cfgReal5IP: 5, - cfgReal6IP: 5, - }) -} - -// testCfgPhase2DisableReals tests disabling reals and verifying traffic routing -func testCfgPhase2DisableReals(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - - // Find VS1 and VS2 - var vs1, vs2 *balancerpb.VirtualService - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - switch addr { - case cfgVs1IP: - vs1 = vs - case cfgVs2IP: - vs2 = vs - } - } - require.NotNil(t, vs1, "VS1 not found") - require.NotNil(t, vs2, "VS2 not found") - - // Disable Real1 and Real2 in VS1, Real4 and Real5 in VS2 - enableFalse := false - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs1.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal1IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs1.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal2IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs2.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal4IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs2.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal5IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - } - - _, err := ts.Balancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to disable reals") - - // Send 5 new packets to VS1 (new client IPs) - t.Log("Sending 5 packets to VS1 after disabling reals") - vs1Packets := sendCfgPacketsToVS(t, ts, cfgVs1IP, cfgVs1Port, 5, 200) - - // Send 5 new packets to VS2 (new client IPs) - t.Log("Sending 5 packets to VS2 after disabling reals") - vs2Packets := sendCfgPacketsToVS(t, ts, cfgVs2IP, cfgVs2Port, 5, 300) - - // Verify all packets go to enabled reals only - t.Log("Verifying packets only go to Real3") - verifyCfgPacketDistribution(t, vs1Packets, map[netip.Addr]int{ - cfgReal3IP: 5, - }) - - t.Log("Verifying packets only go to Real6") - verifyCfgPacketDistribution(t, vs2Packets, map[netip.Addr]int{ - cfgReal6IP: 5, - }) -} - -// testCfgPhase3VerifyAPIs verifies all API outputs in the initial state -func testCfgPhase3VerifyAPIs(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - t.Run("VerifyConfig", func(t *testing.T) { - config := ts.Balancer.Config() - require.NotNil(t, config) - require.NotNil(t, config.PacketHandler) - - // Verify 2 virtual services - assert.Equal( - t, - 2, - len(config.PacketHandler.Vs), - "should have 2 virtual services", - ) - - // Verify each VS has 3 reals - for _, vs := range config.PacketHandler.Vs { - assert.Equal(t, 3, len(vs.Reals), "each VS should have 3 reals") - } - }) - - t.Run("VerifyGraph", func(t *testing.T) { - graph := ts.Balancer.Graph() - require.NotNil(t, graph) - require.NotNil(t, graph.VirtualServices) - - // Verify 2 virtual services - assert.Equal( - t, - 2, - len(graph.VirtualServices), - "should have 2 virtual services", - ) - - // Find VS1 and verify real states - vs1 := findCfgVsInGraph(graph, cfgVs1IP) - require.NotNil(t, vs1, "VS1 not found in graph") - assert.Equal(t, 3, len(vs1.Reals), "VS1 should have 3 reals") - - // Verify Real1: DISABLED - real1 := findCfgRealInVs(vs1, cfgReal1IP) - require.NotNil(t, real1, "Real1 not found") - assert.False(t, real1.Enabled, "Real1 should be disabled") - - // Verify Real2: DISABLED - real2 := findCfgRealInVs(vs1, cfgReal2IP) - require.NotNil(t, real2, "Real2 not found") - assert.False(t, real2.Enabled, "Real2 should be disabled") - - // Verify Real3: ENABLED - real3 := findCfgRealInVs(vs1, cfgReal3IP) - require.NotNil(t, real3, "Real3 not found") - assert.True(t, real3.Enabled, "Real3 should be enabled") - - // Find VS2 and verify real states - vs2 := findCfgVsInGraph(graph, cfgVs2IP) - require.NotNil(t, vs2, "VS2 not found in graph") - assert.Equal(t, 3, len(vs2.Reals), "VS2 should have 3 reals") - - // Verify Real4: DISABLED - real4 := findCfgRealInVs(vs2, cfgReal4IP) - require.NotNil(t, real4, "Real4 not found") - assert.False(t, real4.Enabled, "Real4 should be disabled") - - // Verify Real5: DISABLED - real5 := findCfgRealInVs(vs2, cfgReal5IP) - require.NotNil(t, real5, "Real5 not found") - assert.False(t, real5.Enabled, "Real5 should be disabled") - - // Verify Real6: ENABLED - real6 := findCfgRealInVs(vs2, cfgReal6IP) - require.NotNil(t, real6, "Real6 not found") - assert.True(t, real6.Enabled, "Real6 should be enabled") - - // Verify weights - for _, vs := range graph.VirtualServices { - for _, real := range vs.Reals { - assert.Equal( - t, - uint32(1), - real.Weight, - "all reals should have weight 1", - ) - } - } - }) - - t.Run("VerifyStats", func(t *testing.T) { - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, stats) - - // Find VS1 stats - var vs1Stats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs1IP { - vs1Stats = vs - break - } - } - require.NotNil(t, vs1Stats, "VS1 stats not found") - - // Verify VS1 stats: 20 packets total (15 initial + 5 after disable) - assert.Equal( - t, - uint64(20), - vs1Stats.Stats.IncomingPackets, - "VS1 incoming packets", - ) - assert.Equal( - t, - uint64(20), - vs1Stats.Stats.OutgoingPackets, - "VS1 outgoing packets", - ) - - // Verify Real1: 5 packets - var real1Stats *balancerpb.NamedRealStats - for _, real := range vs1Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal1IP { - real1Stats = real - break - } - } - require.NotNil(t, real1Stats, "Real1 stats not found") - assert.Equal(t, uint64(5), real1Stats.Stats.Packets, "Real1 packets") - - // Verify Real2: 5 packets - var real2Stats *balancerpb.NamedRealStats - for _, real := range vs1Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal2IP { - real2Stats = real - break - } - } - require.NotNil(t, real2Stats, "Real2 stats not found") - assert.Equal(t, uint64(5), real2Stats.Stats.Packets, "Real2 packets") - - // Verify Real3: 10 packets (5 initial + 5 after disable) - var real3Stats *balancerpb.NamedRealStats - for _, real := range vs1Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal3IP { - real3Stats = real - break - } - } - require.NotNil(t, real3Stats, "Real3 stats not found") - assert.Equal(t, uint64(10), real3Stats.Stats.Packets, "Real3 packets") - - // Find VS2 stats - var vs2Stats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs2IP { - vs2Stats = vs - break - } - } - require.NotNil(t, vs2Stats, "VS2 stats not found") - - // Verify VS2 stats: 20 packets total - assert.Equal( - t, - uint64(20), - vs2Stats.Stats.IncomingPackets, - "VS2 incoming packets", - ) - assert.Equal( - t, - uint64(20), - vs2Stats.Stats.OutgoingPackets, - "VS2 outgoing packets", - ) - - // Verify Real4: 5 packets - var real4Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal4IP { - real4Stats = real - break - } - } - require.NotNil(t, real4Stats, "Real4 stats not found") - assert.Equal(t, uint64(5), real4Stats.Stats.Packets, "Real4 packets") - - // Verify Real5: 5 packets - var real5Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal5IP { - real5Stats = real - break - } - } - require.NotNil(t, real5Stats, "Real5 stats not found") - assert.Equal(t, uint64(5), real5Stats.Stats.Packets, "Real5 packets") - - // Verify Real6: 10 packets - var real6Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal6IP { - real6Stats = real - break - } - } - require.NotNil(t, real6Stats, "Real6 stats not found") - assert.Equal(t, uint64(10), real6Stats.Stats.Packets, "Real6 packets") - }) - - t.Run("VerifyInfo", func(t *testing.T) { - info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - require.NotNil(t, info) - - // Verify total active sessions: 40 (20 from VS1 + 20 from VS2) - // Each VS: 15 initial packets + 5 after disabling = 20 sessions - assert.Equal( - t, - uint64(40), - info.ActiveSessions, - "total active sessions", - ) - - // Find VS1 info - var vs1Info *balancerpb.VsInfo - for _, vs := range info.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs1IP { - vs1Info = vs - break - } - } - require.NotNil(t, vs1Info, "VS1 info not found") - assert.Equal( - t, - uint64(20), - vs1Info.ActiveSessions, - "VS1 active sessions", - ) - - // Find VS2 info - var vs2Info *balancerpb.VsInfo - for _, vs := range info.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs2IP { - vs2Info = vs - break - } - } - require.NotNil(t, vs2Info, "VS2 info not found") - assert.Equal( - t, - uint64(20), - vs2Info.ActiveSessions, - "VS2 active sessions", - ) - }) - - t.Run("VerifySessions", func(t *testing.T) { - sessions, err := ts.Balancer.Sessions(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get sessions") - require.NotNil(t, sessions) - - // Count sessions - sessionCount := 0 - vs1Sessions := 0 - vs2Sessions := 0 - - for _, session := range sessions { - sessionCount++ - - vsAddr, _ := netip.AddrFromSlice(session.VsId.Addr.Bytes) - switch vsAddr { - case cfgVs1IP: - vs1Sessions++ - case cfgVs2IP: - vs2Sessions++ - } - } - - // Verify total sessions: 40 (20 per VS) - assert.Equal(t, 40, sessionCount, "total sessions") - assert.Equal(t, 20, vs1Sessions, "VS1 sessions") - assert.Equal(t, 20, vs2Sessions, "VS2 sessions") - }) -} - -// testCfgPhase4UpdateConfig tests updating the configuration -func testCfgPhase4UpdateConfig(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Update configuration - newConfig := createCfgUpdatedConfig() - _, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to update configuration") - - // Verify graph immediately after update (before enabling new reals) - t.Run("VerifyGraphAfterUpdate_BeforeEnabling", func(t *testing.T) { - graph := ts.Balancer.Graph() - require.NotNil(t, graph) - - // Verify 2 virtual services (VS2, VS3) - assert.Equal( - t, - 2, - len(graph.VirtualServices), - "should have 2 virtual services", - ) - - // Find VS2 - vs2 := findCfgVsInGraph(graph, cfgVs2IP) - require.NotNil(t, vs2, "VS2 not found in graph") - assert.Equal(t, 3, len(vs2.Reals), "VS2 should have 3 reals") - - // Verify Real6: ENABLED (persisted from old config) - real6 := findCfgRealInVs(vs2, cfgReal6IP) - require.NotNil(t, real6, "Real6 not found") - assert.True(t, real6.Enabled, "Real6 should be enabled (persisted)") - - // Verify Real7: DISABLED (new real) - real7 := findCfgRealInVs(vs2, cfgReal7IP) - require.NotNil(t, real7, "Real7 not found") - assert.False(t, real7.Enabled, "Real7 should be disabled (new)") - - // Verify Real8: DISABLED (new real) - real8 := findCfgRealInVs(vs2, cfgReal8IP) - require.NotNil(t, real8, "Real8 not found") - assert.False(t, real8.Enabled, "Real8 should be disabled (new)") - - // Find VS3 - vs3 := findCfgVsInGraph(graph, cfgVs3IP) - require.NotNil(t, vs3, "VS3 not found in graph") - assert.Equal(t, 3, len(vs3.Reals), "VS3 should have 3 reals") - - // Verify all VS3 reals are disabled - real9 := findCfgRealInVs(vs3, cfgReal9IP) - require.NotNil(t, real9, "Real9 not found") - assert.False(t, real9.Enabled, "Real9 should be disabled (new)") - - real10 := findCfgRealInVs(vs3, cfgReal10IP) - require.NotNil(t, real10, "Real10 not found") - assert.False(t, real10.Enabled, "Real10 should be disabled (new)") - - real11 := findCfgRealInVs(vs3, cfgReal11IP) - require.NotNil(t, real11, "Real11 not found") - assert.False(t, real11.Enabled, "Real11 should be disabled (new)") - - // Verify weights - for _, vs := range graph.VirtualServices { - for _, real := range vs.Reals { - assert.Equal( - t, - uint32(1), - real.Weight, - "all reals should have weight 1", - ) - } - } - }) - - // Enable new reals - config := ts.Balancer.Config() - var vs2, vs3 *balancerpb.VirtualService - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - switch addr { - case cfgVs2IP: - vs2 = vs - case cfgVs3IP: - vs3 = vs - } - } - require.NotNil(t, vs2, "VS2 not found") - require.NotNil(t, vs3, "VS3 not found") - - enableTrue := true - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs2.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal7IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs2.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal8IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs3.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal9IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs3.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal10IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs3.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: cfgReal11IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - } - - _, err = ts.Balancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to enable new reals") - - // Send 15 packets to VS2 - t.Log("Sending 15 packets to VS2 after update") - vs2Packets := sendCfgPacketsToVS(t, ts, cfgVs2IP, cfgVs2Port, 15, 400) - - // Send 15 packets to VS3 - t.Log("Sending 15 packets to VS3 after update") - vs3Packets := sendCfgPacketsToVS(t, ts, cfgVs3IP, cfgVs3Port, 15, 500) - - // Verify distribution - t.Log("Verifying packet distribution for VS2") - verifyCfgPacketDistribution(t, vs2Packets, map[netip.Addr]int{ - cfgReal6IP: 5, - cfgReal7IP: 5, - cfgReal8IP: 5, - }) - - t.Log("Verifying packet distribution for VS3") - verifyCfgPacketDistribution(t, vs3Packets, map[netip.Addr]int{ - cfgReal9IP: 5, - cfgReal10IP: 5, - cfgReal11IP: 5, - }) -} - -// testCfgPhase5VerifyUpdatedAPIs verifies all API outputs after configuration update -func testCfgPhase5VerifyUpdatedAPIs(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - t.Run("VerifyConfig", func(t *testing.T) { - config := ts.Balancer.Config() - require.NotNil(t, config) - require.NotNil(t, config.PacketHandler) - - // Verify only 2 virtual services (VS2, VS3) - assert.Equal( - t, - 2, - len(config.PacketHandler.Vs), - "should have 2 virtual services", - ) - - // Verify VS1 is NOT present - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - assert.NotEqual(t, cfgVs1IP, addr, "VS1 should not be present") - } - - // Verify VS2 and VS3 are present - foundVs2 := false - foundVs3 := false - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - switch addr { - case cfgVs2IP: - foundVs2 = true - assert.Equal(t, 3, len(vs.Reals), "VS2 should have 3 reals") - case cfgVs3IP: - foundVs3 = true - assert.Equal(t, 3, len(vs.Reals), "VS3 should have 3 reals") - } - } - assert.True(t, foundVs2, "VS2 should be present") - assert.True(t, foundVs3, "VS3 should be present") - }) - - t.Run("VerifyGraph", func(t *testing.T) { - graph := ts.Balancer.Graph() - require.NotNil(t, graph) - - // Verify 2 virtual services - assert.Equal( - t, - 2, - len(graph.VirtualServices), - "should have 2 virtual services", - ) - - // Verify all reals are enabled - for _, vs := range graph.VirtualServices { - for _, real := range vs.Reals { - assert.True(t, real.Enabled, "all reals should be enabled") - assert.Equal( - t, - uint32(1), - real.Weight, - "all reals should have weight 1", - ) - } - } - }) - - t.Run("VerifyStats", func(t *testing.T) { - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, stats) - - // Verify VS1 is NOT present - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - assert.NotEqual( - t, - cfgVs1IP, - addr, - "VS1 stats should not be present", - ) - } - - // Find VS2 stats - var vs2Stats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs2IP { - vs2Stats = vs - break - } - } - require.NotNil(t, vs2Stats, "VS2 stats not found") - - // Verify VS2 stats are CUMULATIVE: 35 packets (20 from old + 15 from new) - assert.Equal( - t, - uint64(35), - vs2Stats.Stats.IncomingPackets, - "VS2 incoming packets (cumulative)", - ) - assert.Equal( - t, - uint64(35), - vs2Stats.Stats.OutgoingPackets, - "VS2 outgoing packets (cumulative)", - ) - - // Verify Real6: 15 packets (10 from old + 5 from new) - var real6Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal6IP { - real6Stats = real - break - } - } - require.NotNil(t, real6Stats, "Real6 stats not found") - assert.Equal( - t, - uint64(15), - real6Stats.Stats.Packets, - "Real6 packets (cumulative)", - ) - - // Verify Real7: 5 packets (new) - var real7Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal7IP { - real7Stats = real - break - } - } - require.NotNil(t, real7Stats, "Real7 stats not found") - assert.Equal(t, uint64(5), real7Stats.Stats.Packets, "Real7 packets") - - // Verify Real8: 5 packets (new) - var real8Stats *balancerpb.NamedRealStats - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - if addr == cfgReal8IP { - real8Stats = real - break - } - } - require.NotNil(t, real8Stats, "Real8 stats not found") - assert.Equal(t, uint64(5), real8Stats.Stats.Packets, "Real8 packets") - - // Verify NO stats for Real4, Real5 (deleted) - for _, real := range vs2Stats.Reals { - addr, _ := netip.AddrFromSlice(real.Real.Real.Ip.Bytes) - assert.NotEqual( - t, - cfgReal4IP, - addr, - "Real4 stats should not be present", - ) - assert.NotEqual( - t, - cfgReal5IP, - addr, - "Real5 stats should not be present", - ) - } - - // Find VS3 stats - var vs3Stats *balancerpb.NamedVsStats - for _, vs := range stats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs3IP { - vs3Stats = vs - break - } - } - require.NotNil(t, vs3Stats, "VS3 stats not found") - - // Verify VS3 stats: 15 packets (new) - assert.Equal( - t, - uint64(15), - vs3Stats.Stats.IncomingPackets, - "VS3 incoming packets", - ) - assert.Equal( - t, - uint64(15), - vs3Stats.Stats.OutgoingPackets, - "VS3 outgoing packets", - ) - - // Verify each real has 5 packets - assert.Equal(t, 3, len(vs3Stats.Reals), "VS3 should have 3 reals") - for _, real := range vs3Stats.Reals { - assert.Equal( - t, - uint64(5), - real.Stats.Packets, - "each VS3 real should have 5 packets", - ) - } - }) - - t.Run("VerifyInfo", func(t *testing.T) { - info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - require.NotNil(t, info) - - // Verify total active sessions: 40 (25 from VS2 + 15 from VS3) - // VS1 deleted: its 20 sessions are removed - // VS2: 10 sessions to Real6 (persisted) + 15 new = 25 - // VS3: 15 new sessions - assert.Equal( - t, - uint64(40), - info.ActiveSessions, - "total active sessions", - ) - - // Verify VS1 is NOT present - for _, vs := range info.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - assert.NotEqual(t, cfgVs1IP, addr, "VS1 info should not be present") - } - - // Find VS2 info - var vs2Info *balancerpb.VsInfo - for _, vs := range info.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs2IP { - vs2Info = vs - break - } - } - require.NotNil(t, vs2Info, "VS2 info not found") - // VS2: 10 old sessions (to Real6) + 15 new = 25 - assert.Equal( - t, - uint64(25), - vs2Info.ActiveSessions, - "VS2 active sessions", - ) - - // Find VS3 info - var vs3Info *balancerpb.VsInfo - for _, vs := range info.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs3IP { - vs3Info = vs - break - } - } - require.NotNil(t, vs3Info, "VS3 info not found") - assert.Equal( - t, - uint64(15), - vs3Info.ActiveSessions, - "VS3 active sessions", - ) - }) - - t.Run("VerifySessions", func(t *testing.T) { - sessions, err := ts.Balancer.Sessions(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get sessions") - require.NotNil(t, sessions) - - // Count sessions - sessionCount := 0 - vs1Sessions := 0 - vs2Sessions := 0 - vs3Sessions := 0 - - for _, session := range sessions { - sessionCount++ - - vsAddr, _ := netip.AddrFromSlice(session.VsId.Addr.Bytes) - switch vsAddr { - case cfgVs1IP: - vs1Sessions++ - case cfgVs2IP: - vs2Sessions++ - case cfgVs3IP: - vs3Sessions++ - } - } - - // Verify total sessions: 40 (25 VS2 + 15 VS3) - assert.Equal(t, 40, sessionCount, "total sessions") - - // Verify NO sessions for VS1 (deleted) - assert.Equal(t, 0, vs1Sessions, "VS1 should have no sessions") - - // Verify sessions for VS2 and VS3 - // VS2: 10 old (to Real6) + 15 new = 25 - assert.Equal(t, 25, vs2Sessions, "VS2 sessions") - assert.Equal(t, 15, vs3Sessions, "VS3 sessions") - }) -} - -// testCfgPhase6StateWithNewAgent tests state persistence by creating a new balancer agent -// that attaches to the same shared memory and verifies all API outputs match Phase 5 -func testCfgPhase6StateWithNewAgent(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Store Phase 5 outputs for comparison - t.Run("StorePhase5Outputs", func(t *testing.T) { - // These will be captured in the parent scope for comparison - t.Log("Phase 5 outputs will be compared with new agent outputs") - }) - - // Get Phase 5 API outputs before creating new agent - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - phase5Config := ts.Balancer.Config() - phase5Graph := ts.Balancer.Graph() - phase5Stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err, "failed to get Phase 5 stats") - phase5Info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get Phase 5 info") - phase5Sessions, err := ts.Balancer.Sessions(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get Phase 5 sessions") - - // Create new balancer agent and attach to existing manager - t.Run("CreateNewAgentAndAttach", func(t *testing.T) { - // Create new BalancerAgent using same shared memory - logLevel := zapcore.InfoLevel - sugaredLogger, _, _ := logging.Init(&logging.Config{ - Level: logLevel, - }) - - agentMemory := 16 * datasize.MB - newAgent, err := balancer.NewBalancerAgent( - ts.Mock.SharedMemory(), // Same shared memory - agentMemory, - sugaredLogger, - ) - require.NoError(t, err, "failed to create new balancer agent") - - // Attach to existing BalancerManager - newBalancer, err := newAgent.BalancerManager(utils.BalancerName) - require.NoError(t, err, "failed to attach to existing balancer manager") - require.NotNil(t, newBalancer, "balancer manager should not be nil") - - // Verify Config matches Phase 5 - t.Run("VerifyConfigMatches", func(t *testing.T) { - newConfig := newBalancer.Config() - require.NotNil(t, newConfig) - require.NotNil(t, newConfig.PacketHandler) - - // Verify same number of virtual services - assert.Equal( - t, - len(phase5Config.PacketHandler.Vs), - len(newConfig.PacketHandler.Vs), - "should have same number of virtual services", - ) - - // Verify VS2 and VS3 are present - assert.Equal( - t, - 2, - len(newConfig.PacketHandler.Vs), - "should have 2 virtual services", - ) - - // Verify each VS has 3 reals - for _, vs := range newConfig.PacketHandler.Vs { - assert.Equal(t, 3, len(vs.Reals), "each VS should have 3 reals") - } - }) - - // Verify Graph matches Phase 5 - t.Run("VerifyGraphMatches", func(t *testing.T) { - newGraph := newBalancer.Graph() - require.NotNil(t, newGraph) - - // Verify same number of virtual services - assert.Equal( - t, - len(phase5Graph.VirtualServices), - len(newGraph.VirtualServices), - "should have same number of virtual services", - ) - - // Verify all reals are enabled (as in Phase 5) - for _, vs := range newGraph.VirtualServices { - for _, real := range vs.Reals { - assert.True(t, real.Enabled, "all reals should be enabled") - assert.Equal( - t, - uint32(1), - real.Weight, - "all reals should have weight 1", - ) - } - } - }) - - // Verify Stats match Phase 5 - t.Run("VerifyStatsMatch", func(t *testing.T) { - newStats, err := newBalancer.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, newStats) - - // Verify same number of VS stats - assert.Equal( - t, - len(phase5Stats.Vs), - len(newStats.Vs), - "should have same number of VS stats", - ) - - // Find VS2 stats - var vs2Stats *balancerpb.NamedVsStats - for _, vs := range newStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs2IP { - vs2Stats = vs - break - } - } - require.NotNil(t, vs2Stats, "VS2 stats not found") - - // Verify VS2 stats: 35 packets (cumulative from Phase 5) - assert.Equal( - t, - uint64(35), - vs2Stats.Stats.IncomingPackets, - "VS2 incoming packets should match Phase 5", - ) - assert.Equal( - t, - uint64(35), - vs2Stats.Stats.OutgoingPackets, - "VS2 outgoing packets should match Phase 5", - ) - - // Find VS3 stats - var vs3Stats *balancerpb.NamedVsStats - for _, vs := range newStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs3IP { - vs3Stats = vs - break - } - } - require.NotNil(t, vs3Stats, "VS3 stats not found") - - // Verify VS3 stats: 15 packets - assert.Equal( - t, - uint64(15), - vs3Stats.Stats.IncomingPackets, - "VS3 incoming packets should match Phase 5", - ) - assert.Equal( - t, - uint64(15), - vs3Stats.Stats.OutgoingPackets, - "VS3 outgoing packets should match Phase 5", - ) - }) - - // Verify Info matches Phase 5 - t.Run("VerifyInfoMatches", func(t *testing.T) { - newInfo, err := newBalancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - require.NotNil(t, newInfo) - - // Verify total active sessions: 40 (from Phase 5) - assert.Equal( - t, - phase5Info.ActiveSessions, - newInfo.ActiveSessions, - "total active sessions should match Phase 5", - ) - assert.Equal( - t, - uint64(40), - newInfo.ActiveSessions, - "total active sessions should be 40", - ) - - // Find VS2 info - var vs2Info *balancerpb.VsInfo - for _, vs := range newInfo.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs2IP { - vs2Info = vs - break - } - } - require.NotNil(t, vs2Info, "VS2 info not found") - assert.Equal( - t, - uint64(25), - vs2Info.ActiveSessions, - "VS2 active sessions should match Phase 5", - ) - - // Find VS3 info - var vs3Info *balancerpb.VsInfo - for _, vs := range newInfo.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs3IP { - vs3Info = vs - break - } - } - require.NotNil(t, vs3Info, "VS3 info not found") - assert.Equal( - t, - uint64(15), - vs3Info.ActiveSessions, - "VS3 active sessions should match Phase 5", - ) - }) - - // Verify Sessions match Phase 5 - t.Run("VerifySessionsMatch", func(t *testing.T) { - newSessions, err := newBalancer.Sessions(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get sessions") - require.NotNil(t, newSessions) - - // Verify same number of sessions - assert.Equal( - t, - len(phase5Sessions), - len(newSessions), - "should have same number of sessions as Phase 5", - ) - assert.Equal(t, 40, len(newSessions), "should have 40 sessions") - - // Count sessions per VS - vs2Sessions := 0 - vs3Sessions := 0 - for _, session := range newSessions { - vsAddr, _ := netip.AddrFromSlice(session.VsId.Addr.Bytes) - switch vsAddr { - case cfgVs2IP: - vs2Sessions++ - case cfgVs3IP: - vs3Sessions++ - } - } - - assert.Equal(t, 25, vs2Sessions, "VS2 should have 25 sessions") - assert.Equal(t, 15, vs3Sessions, "VS3 should have 15 sessions") - }) - - // Send new packets through the new balancer - t.Run("SendNewPackets", func(t *testing.T) { - // Send 10 packets to VS2 (new client IPs starting at 600) - t.Log("Sending 10 packets to VS2 through new balancer") - vs2Packets := sendCfgPacketsToVS( - t, - ts, - cfgVs2IP, - cfgVs2Port, - 10, - 600, - ) - - // Send 10 packets to VS3 (new client IPs starting at 700) - t.Log("Sending 10 packets to VS3 through new balancer") - vs3Packets := sendCfgPacketsToVS( - t, - ts, - cfgVs3IP, - cfgVs3Port, - 10, - 700, - ) - - // Verify distribution (ROUND_ROBIN: ~3-4 packets per real) - t.Log("Verifying packet distribution for VS2") - verifyCfgPacketDistribution(t, vs2Packets, map[netip.Addr]int{ - cfgReal6IP: 4, - cfgReal7IP: 3, - cfgReal8IP: 3, - }) - - t.Log("Verifying packet distribution for VS3") - verifyCfgPacketDistribution(t, vs3Packets, map[netip.Addr]int{ - cfgReal9IP: 4, - cfgReal10IP: 3, - cfgReal11IP: 3, - }) - }) - - // Verify Stats updated correctly - t.Run("VerifyStatsUpdated", func(t *testing.T) { - newStats, err := newBalancer.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, newStats) - - // Find VS2 stats - var vs2Stats *balancerpb.NamedVsStats - for _, vs := range newStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs2IP { - vs2Stats = vs - break - } - } - require.NotNil(t, vs2Stats, "VS2 stats not found") - - // Verify VS2 stats: 45 packets (35 from Phase 5 + 10 new) - assert.Equal( - t, - uint64(45), - vs2Stats.Stats.IncomingPackets, - "VS2 incoming packets should be 45 (35+10)", - ) - assert.Equal( - t, - uint64(45), - vs2Stats.Stats.OutgoingPackets, - "VS2 outgoing packets should be 45 (35+10)", - ) - - // Find VS3 stats - var vs3Stats *balancerpb.NamedVsStats - for _, vs := range newStats.Vs { - addr, _ := netip.AddrFromSlice(vs.Vs.Addr.Bytes) - if addr == cfgVs3IP { - vs3Stats = vs - break - } - } - require.NotNil(t, vs3Stats, "VS3 stats not found") - - // Verify VS3 stats: 25 packets (15 from Phase 5 + 10 new) - assert.Equal( - t, - uint64(25), - vs3Stats.Stats.IncomingPackets, - "VS3 incoming packets should be 25 (15+10)", - ) - assert.Equal( - t, - uint64(25), - vs3Stats.Stats.OutgoingPackets, - "VS3 outgoing packets should be 25 (15+10)", - ) - }) - - // Verify Info updated correctly - t.Run("VerifyInfoUpdated", func(t *testing.T) { - newInfo, err := newBalancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - require.NotNil(t, newInfo) - - // Verify total active sessions: 60 (40 from Phase 5 + 20 new) - assert.Equal( - t, - uint64(60), - newInfo.ActiveSessions, - "total active sessions should be 60 (40+20)", - ) - - // Find VS2 info - var vs2Info *balancerpb.VsInfo - for _, vs := range newInfo.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs2IP { - vs2Info = vs - break - } - } - require.NotNil(t, vs2Info, "VS2 info not found") - assert.Equal( - t, - uint64(35), - vs2Info.ActiveSessions, - "VS2 active sessions should be 35 (25+10)", - ) - - // Find VS3 info - var vs3Info *balancerpb.VsInfo - for _, vs := range newInfo.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == cfgVs3IP { - vs3Info = vs - break - } - } - require.NotNil(t, vs3Info, "VS3 info not found") - assert.Equal( - t, - uint64(25), - vs3Info.ActiveSessions, - "VS3 active sessions should be 25 (15+10)", - ) - }) - - // Verify Sessions updated correctly - t.Run("VerifySessionsUpdated", func(t *testing.T) { - newSessions, err := newBalancer.Sessions(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get sessions") - require.NotNil(t, newSessions) - - // Verify total sessions: 60 (40 from Phase 5 + 20 new) - assert.Equal( - t, - 60, - len(newSessions), - "should have 60 total sessions (40+20)", - ) - - // Count sessions per VS - vs2Sessions := 0 - vs3Sessions := 0 - for _, session := range newSessions { - vsAddr, _ := netip.AddrFromSlice(session.VsId.Addr.Bytes) - switch vsAddr { - case cfgVs2IP: - vs2Sessions++ - case cfgVs3IP: - vs3Sessions++ - } - } - - assert.Equal( - t, - 35, - vs2Sessions, - "VS2 should have 35 sessions (25+10)", - ) - assert.Equal( - t, - 25, - vs3Sessions, - "VS3 should have 25 sessions (15+10)", - ) - }) - }) -} diff --git a/modules/balancer/tests/go/filter_reuse_test.go b/modules/balancer/tests/go/filter_reuse_test.go deleted file mode 100644 index 9183d2e8c..000000000 --- a/modules/balancer/tests/go/filter_reuse_test.go +++ /dev/null @@ -1,1372 +0,0 @@ -package balancer_test - -// TestFilterReuse is a comprehensive test that verifies the balancer's filter reuse logic -// during configuration updates. It tests various scenarios with 20 IPv4 and 20 IPv6 virtual -// services to ensure that: -// -// 1. IPv4 VS matcher is reused when the set of IPv4 virtual services remains the same -// 2. IPv6 VS matcher is reused when the set of IPv6 virtual services remains the same -// 3. ACL filters are reused when allowed_srcs configuration remains the same -// 4. ACL comparison is order-independent (different order = same ACL) -// 5. ACL comparison is duplicate-tolerant (duplicates don't affect equality) -// -// Test Phases: -// - Phase 1: Initial configuration with 20 IPv4 + 20 IPv6 VS -// - Phase 2: Same IPv4 set, different IPv6 set -// - Phase 3: Different IPv4 set, same IPv6 set -// - Phase 4: Same VS sets, different ACL for some VS -// - Phase 5: Same VS sets, same ACL with different order -// - Phase 6: Same VS sets, same ACL with duplicates -// - Phase 7: Completely different configuration -// - Phase 8: Identical configuration (everything reused) -// - Phase 9: Edge cases (empty ACL, mixed protocols, partial changes) - -import ( - "fmt" - "math/rand" - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Test constants for filter reuse tests -const ( - frIPv4VSCount = 20 - frIPv6VSCount = 20 - frRealsPerVS = 2 -) - -// frGenerateIPv4Addr generates a deterministic IPv4 address based on index -func frGenerateIPv4Addr(index int) netip.Addr { - // Generate addresses in 10.0.x.x range - return netip.AddrFrom4([4]byte{10, 0, byte(index / 256), byte(index % 256)}) -} - -// frGenerateIPv6Addr generates a deterministic IPv6 address based on index -func frGenerateIPv6Addr(index int) netip.Addr { - // Generate addresses in 2001:db8::/32 range - return netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, - byte(index >> 24), byte(index >> 16), byte(index >> 8), byte(index), - 0, 0, 0, 0, 0, 0, 0, 1, - }) -} - -// frGenerateClientIPv4Addr generates a client IPv4 address that matches ACL rules -// For simple ACL (0.0.0.0/0), any address works -// For complex ACL, we use 10.0.x.x range which matches the 10.0.0.0/8 rule -func frGenerateClientIPv4Addr(index int) netip.Addr { - // Generate addresses in 10.0.x.x range to match complex ACL rule (10.0.0.0/8) - return netip.AddrFrom4( - [4]byte{10, 0, byte((index / 256) % 256), byte(index % 256)}, - ) -} - -// frGenerateClientIPv6Addr generates a client IPv6 address that matches ACL rules -// For simple ACL (::/0), any address works -// For complex ACL, we use 2001:db8:1::/48 range which matches the ACL rule -func frGenerateClientIPv6Addr(index int) netip.Addr { - // Generate addresses in 2001:db8:1::/48 range to match complex ACL rule - return netip.AddrFrom16([16]byte{ - 0x20, 0x01, 0x0d, 0xb8, - 0x00, 0x01, // 2001:db8:1:: - byte(index >> 8), byte(index), - 0, 0, 0, 0, 0, 0, 0, 1, - }) -} - -// frGetVSProtocol returns the protocol for a VS at the given index -// This matches the logic in frCreateVSSet -func frGetVSProtocol(index int) balancerpb.TransportProto { - if index%4 == 0 { - return balancerpb.TransportProto_UDP - } - return balancerpb.TransportProto_TCP -} - -// frGenerateRealIPv4Addr generates a deterministic IPv4 address for real servers -func frGenerateRealIPv4Addr(vsIndex, realIndex int) netip.Addr { - // Generate addresses in 192.168.x.x range - return netip.AddrFrom4([4]byte{192, 168, byte(vsIndex), byte(realIndex)}) -} - -// frGenerateRealIPv6Addr generates a deterministic IPv6 address for real servers -func frGenerateRealIPv6Addr(vsIndex, realIndex int) netip.Addr { - // Generate addresses in fd00::/8 range - return netip.AddrFrom16([16]byte{ - 0xfd, 0x00, 0, 0, - byte(vsIndex >> 8), byte(vsIndex), byte(realIndex), 0, - 0, 0, 0, 0, 0, 0, 0, 1, - }) -} - -// frCreateReal creates a real server configuration -func frCreateReal(ip netip.Addr, weight uint32) *balancerpb.Real { - var srcAddr, srcMask netip.Addr - if ip.Is4() { - srcAddr = netip.AddrFrom4([4]byte{172, 16, 0, 1}) - srcMask = netip.AddrFrom4([4]byte{255, 255, 255, 255}) - } else { - srcAddr = netip.AddrFrom16([16]byte{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) - srcMask = netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - } - - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: 0, - }, - Weight: weight, - SrcAddr: &balancerpb.Addr{ - Bytes: srcAddr.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: srcMask.AsSlice(), - }, - } -} - -// frCreateSimpleACL creates a simple allow-all ACL -func frCreateSimpleACL(isIPv6 bool) []*balancerpb.AllowedSources { - var addr, mask netip.Addr - if isIPv6 { - addr = netip.AddrFrom16([16]byte{}) - mask = netip.AddrFrom16([16]byte{}) - } else { - addr = netip.AddrFrom4([4]byte{0, 0, 0, 0}) - mask = netip.AddrFrom4([4]byte{0, 0, 0, 0}) - } - - return []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: addr.AsSlice()}, - Mask: &balancerpb.Addr{Bytes: mask.AsSlice()}, - }}, - }, - } -} - -// frCreateComplexACL creates a complex ACL with multiple rules and port ranges -func frCreateComplexACL(index int, isIPv6 bool) []*balancerpb.AllowedSources { - var acl []*balancerpb.AllowedSources - - if isIPv6 { - // Rule 1: Allow from 2001:db8:1::/48 - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, - }, - }) - - // Rule 2: Allow from 2001:db8:2::/48 with specific ports - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - }, - }) - } else { - // Rule 1: Allow from 10.0.0.0/8 - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{10, 0, 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 0, 0, 0}).AsSlice()}, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 1024, To: 65535}, - }, - }) - - // Rule 2: Allow from 192.168.0.0/16 with specific ports - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{192, 168, 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice()}, - }}, - Ports: []*balancerpb.PortsRange{ - {From: 80, To: 80}, - {From: 443, To: 443}, - }, - }) - } - - // Add index-specific rule for variation - if index%2 == 0 { - if isIPv6 { - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - }) - } else { - acl = append(acl, &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{172, 16, 0, 0}).AsSlice()}, - Mask: &balancerpb.Addr{Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}).AsSlice()}, - }}, - }) - } - } - - return acl -} - -// frShuffleACL returns a new ACL with rules in random order -func frShuffleACL( - acl []*balancerpb.AllowedSources, - rng *rand.Rand, -) []*balancerpb.AllowedSources { - shuffled := make([]*balancerpb.AllowedSources, len(acl)) - copy(shuffled, acl) - rng.Shuffle(len(shuffled), func(i, j int) { - shuffled[i], shuffled[j] = shuffled[j], shuffled[i] - }) - return shuffled -} - -// frDuplicateACLRules returns a new ACL with some rules duplicated -func frDuplicateACLRules( - acl []*balancerpb.AllowedSources, -) []*balancerpb.AllowedSources { - if len(acl) == 0 { - return acl - } - - duplicated := make([]*balancerpb.AllowedSources, 0, len(acl)*2) - for i, rule := range acl { - duplicated = append(duplicated, rule) - // Duplicate every other rule - if i%2 == 0 { - duplicated = append(duplicated, rule) - } - } - return duplicated -} - -// frCreateVirtualService creates a virtual service with given parameters -func frCreateVirtualService( - ip netip.Addr, - port uint16, - proto balancerpb.TransportProto, - reals []*balancerpb.Real, - acl []*balancerpb.AllowedSources, - scheduler balancerpb.VsScheduler, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: proto, - }, - Scheduler: scheduler, - AllowedSrcs: acl, - Reals: reals, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Peers: []*balancerpb.Addr{}, - } -} - -// frCreateVSSet creates a set of virtual services (IPv4 or IPv6) -func frCreateVSSet( - count int, - isIPv6 bool, - baseIndex int, - useComplexACL bool, -) []*balancerpb.VirtualService { - vsList := make([]*balancerpb.VirtualService, 0, count) - - for i := 0; i < count; i++ { - var vsIP netip.Addr - if isIPv6 { - vsIP = frGenerateIPv6Addr(baseIndex + i) - } else { - vsIP = frGenerateIPv4Addr(baseIndex + i) - } - - // Create reals for this VS - reals := make([]*balancerpb.Real, 0, frRealsPerVS) - for j := 0; j < frRealsPerVS; j++ { - var realIP netip.Addr - if isIPv6 { - realIP = frGenerateRealIPv6Addr(baseIndex+i, j) - } else { - realIP = frGenerateRealIPv4Addr(baseIndex+i, j) - } - reals = append(reals, frCreateReal(realIP, 1)) - } - - // Create ACL - var acl []*balancerpb.AllowedSources - if useComplexACL && i%3 == 0 { - // Use complex ACL for every 3rd VS - acl = frCreateComplexACL(i, isIPv6) - } else { - // Use simple ACL for others - acl = frCreateSimpleACL(isIPv6) - } - - // Alternate between TCP and UDP - proto := balancerpb.TransportProto_TCP - if i%4 == 0 { - proto = balancerpb.TransportProto_UDP - } - - // Alternate between schedulers - scheduler := balancerpb.VsScheduler_ROUND_ROBIN - if i%2 == 0 { - scheduler = balancerpb.VsScheduler_SOURCE_HASH - } - - vs := frCreateVirtualService(vsIP, 80, proto, reals, acl, scheduler) - vsList = append(vsList, vs) - } - - return vsList -} - -// frCreateConfig creates a balancer configuration with given VS sets -func frCreateConfig( - ipv4VS, ipv6VS []*balancerpb.VirtualService, -) *balancerpb.BalancerConfig { - allVS := make([]*balancerpb.VirtualService, 0, len(ipv4VS)+len(ipv6VS)) - allVS = append(allVS, ipv4VS...) - allVS = append(allVS, ipv6VS...) - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: allVS, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(10000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// frVerifyUpdateInfo verifies the UpdateInfo fields match expectations -func frVerifyUpdateInfo( - t *testing.T, - updateInfo *ffi.UpdateInfo, - expectedIPv4Reused bool, - expectedIPv6Reused bool, - expectedACLReusedCount int, -) { - t.Helper() - - assert.Equal(t, expectedIPv4Reused, updateInfo.VsIpv4MatcherReused, - "IPv4 VS matcher reuse mismatch") - assert.Equal(t, expectedIPv6Reused, updateInfo.VsIpv6MatcherReused, - "IPv6 VS matcher reuse mismatch") - assert.Equal(t, expectedACLReusedCount, len(updateInfo.ACLReusedVs), - "ACL reused count mismatch") -} - -// frVerifyACLReusedVS verifies that specific VS indices have ACL reused -// expectedIndices should be a map of VS index to expected reuse status -func frVerifyACLReusedVS( - t *testing.T, - updateInfo *ffi.UpdateInfo, - vsList []*balancerpb.VirtualService, - expectedIndices map[int]bool, -) { - t.Helper() - - // Build a map of VS identifiers that have ACL reused - reusedVSMap := make(map[string]bool) - for _, vsID := range updateInfo.ACLReusedVs { - addr, _ := netip.AddrFromSlice(vsID.Addr.AsSlice()) - key := fmt.Sprintf("%s:%d/%d", addr, vsID.Port, vsID.TransportProto) - reusedVSMap[key] = true - } - - // Check each expected index - for idx, shouldBeReused := range expectedIndices { - if idx >= len(vsList) { - t.Errorf( - "Index %d out of range (VS list has %d elements)", - idx, - len(vsList), - ) - continue - } - - vs := vsList[idx] - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - key := fmt.Sprintf("%s:%d/%d", addr, vs.Id.Port, vs.Id.Proto) - - isReused := reusedVSMap[key] - if shouldBeReused && !isReused { - t.Errorf( - "VS at index %d (%s) should have ACL reused but doesn't", - idx, - key, - ) - } else if !shouldBeReused && isReused { - t.Errorf("VS at index %d (%s) should NOT have ACL reused but does", idx, key) - } - } -} - -// frSendTestPackets sends test packets to a VS and verifies they are processed -func frSendTestPackets( - t *testing.T, - ts *utils.TestSetup, - vsIP netip.Addr, - vsPort uint16, - proto balancerpb.TransportProto, - count int, - clientBaseIndex int, -) { - t.Helper() - - for i := range count { - var clientIP netip.Addr - if vsIP.Is4() { - // Use client IPs that match ACL rules (10.0.0.0/8 for complex ACL) - clientIP = frGenerateClientIPv4Addr(1000 + clientBaseIndex + i) - } else { - // Use client IPs that match ACL rules (2001:db8:1::/48 for complex ACL) - clientIP = frGenerateClientIPv6Addr(1000 + clientBaseIndex + i) - } - clientPort := uint16(10000 + i) - - var packetLayers []gopacket.SerializableLayer - if proto == balancerpb.TransportProto_TCP { - packetLayers = utils.MakeTCPPacket( - clientIP, - clientPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - } else { - packetLayers = utils.MakeUDPPacket( - clientIP, - clientPort, - vsIP, - vsPort, - ) - } - - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "expected 1 output packet for VS %s", - vsIP, - ) - require.Empty( - t, - result.Drop, - "expected no dropped packets for VS %s", - vsIP, - ) - } -} - -// TestFilterReuse is the main test function -func TestFilterReuse(t *testing.T) { - // Setup test with appropriate memory - agentMemory := 512 * datasize.MB - - // Create initial configuration - ipv4VS := frCreateVSSet(frIPv4VSCount, false, 0, true) - ipv6VS := frCreateVSSet(frIPv6VSCount, true, 0, true) - initialConfig := frCreateConfig(ipv4VS, ipv6VS) - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig( - 1024*datasize.MB, - 16*datasize.MB, - ), - Balancer: initialConfig, - AgentMemory: &agentMemory, - }) - require.NoError(t, err) - defer ts.Free() - - t.Logf( - "Setup initial config with %d IPv4 and %d IPv6 virtual services", - frIPv4VSCount, - frIPv6VSCount, - ) - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Phase 1: Verify initial setup works - t.Run("Phase1_InitialSetup", func(t *testing.T) { - t.Log("Verifying initial configuration is accessible") - - // Test a few IPv4 VS - for i := 0; i < 3; i++ { - vsIP := frGenerateIPv4Addr(i) - proto := balancerpb.TransportProto_TCP - if i%4 == 0 { - proto = balancerpb.TransportProto_UDP - } - frSendTestPackets(t, ts, vsIP, 80, proto, 2, i*100) - } - - // Test a few IPv6 VS - for i := 0; i < 3; i++ { - vsIP := frGenerateIPv6Addr(i) - proto := balancerpb.TransportProto_TCP - if i%4 == 0 { - proto = balancerpb.TransportProto_UDP - } - frSendTestPackets(t, ts, vsIP, 80, proto, 2, i*100+1000) - } - - t.Log("Initial configuration verified successfully") - }) - - // Phase 2: Same IPv4 VS set, Different IPv6 VS set - t.Run("Phase2_SameIPv4_DifferentIPv6", func(t *testing.T) { - t.Log("Testing: Same IPv4 set, different IPv6 set") - - // Keep IPv4 VS the same - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 0, true) - - // Create different IPv6 VS (different base index) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 1000, true) - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Verify: IPv4 matcher reused, IPv6 matcher NOT reused - frVerifyUpdateInfo(t, updateInfo, true, false, frIPv4VSCount) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Enable new reals and test - utils.EnableAllReals(t, ts) - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(1), - 80, - balancerpb.TransportProto_TCP, - 2, - 2000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(1001), - 80, - balancerpb.TransportProto_TCP, - 2, - 2100, - ) - }) - - // Phase 3: Different IPv4 VS set, Same IPv6 VS set - t.Run("Phase3_DifferentIPv4_SameIPv6", func(t *testing.T) { - t.Log("Testing: Different IPv4 set, same IPv6 set") - - // Create different IPv4 VS (different base index) - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 2000, true) - - // Keep IPv6 VS the same as Phase 2 - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 1000, true) - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Verify: IPv4 matcher NOT reused, IPv6 matcher reused - frVerifyUpdateInfo(t, updateInfo, false, true, frIPv6VSCount) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Enable new reals and test - utils.EnableAllReals(t, ts) - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(2001), - 80, - balancerpb.TransportProto_TCP, - 2, - 3000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(1001), - 80, - balancerpb.TransportProto_TCP, - 2, - 3100, - ) - }) - - // Phase 4: Same VS sets, Different ACL for some VS - t.Run("Phase4_SameVS_DifferentACLForSome", func(t *testing.T) { - t.Log("Testing: Same VS sets, different ACL for some VS") - - // Keep VS identifiers the same as Phase 3 - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 2000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 1000, true) - - // Change ACL for 5 IPv4 VS and 5 IPv6 VS by adding a unique rule - // This ensures the ACL is truly different from the original - for i := 0; i < 5; i++ { - // Add a unique network rule that makes this ACL different - uniqueRule := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{100, byte(i), 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv4VS[i].AllowedSrcs = append( - newIPv4VS[i].AllowedSrcs, - uniqueRule, - ) - - uniqueRuleV6 := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0x00, byte(i), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv6VS[i].AllowedSrcs = append( - newIPv6VS[i].AllowedSrcs, - uniqueRuleV6, - ) - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Expected behavior: - // - Both matchers reused (VS identifiers unchanged) - // - 30 VS have ACL reused: 15 unchanged IPv4 + 15 unchanged IPv6 - // - 10 VS have different ACL: 5 modified IPv4 + 5 modified IPv6 - assert.True( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should be reused", - ) - assert.True( - t, - updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused", - ) - assert.Equal( - t, - 30, - len(updateInfo.ACLReusedVs), - "30 VS should have ACL reused (15 unchanged IPv4 + 15 unchanged IPv6)", - ) - - // Verify specific VS indices have correct ACL reuse status - allVS := append(newIPv4VS, newIPv6VS...) - expectedReuse := make(map[int]bool) - // First 5 IPv4 VS (indices 0-4) have modified ACL - should NOT be reused - for i := range 5 { - expectedReuse[i] = false - } - // Remaining 15 IPv4 VS (indices 5-19) have unchanged ACL - should be reused - for i := 5; i < 20; i++ { - expectedReuse[i] = true - } - // First 5 IPv6 VS (indices 20-24) have modified ACL - should NOT be reused - for i := 20; i < 25; i++ { - expectedReuse[i] = false - } - // Remaining 15 IPv6 VS (indices 25-39) have unchanged ACL - should be reused - for i := 25; i < 40; i++ { - expectedReuse[i] = true - } - frVerifyACLReusedVS(t, updateInfo, allVS, expectedReuse) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Test packet processing still works - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(2001), - 80, - balancerpb.TransportProto_TCP, - 2, - 4000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(1001), - 80, - balancerpb.TransportProto_TCP, - 2, - 4100, - ) - }) - - // Phase 5: Same VS sets, Same ACL with different order - t.Run("Phase5_SameVS_SameACL_DifferentOrder", func(t *testing.T) { - t.Log("Testing: Same VS sets, same ACL with different order") - - rng := rand.New(rand.NewSource(42)) - - // Keep VS identifiers the same as Phase 4 - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 2000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 1000, true) - - // Apply the SAME ACL changes as Phase 4 (add unique rules to first 5 VS) - for i := range 5 { - uniqueRule := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{100, byte(i), 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv4VS[i].AllowedSrcs = append( - newIPv4VS[i].AllowedSrcs, - uniqueRule, - ) - - uniqueRuleV6 := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0x00, byte(i), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv6VS[i].AllowedSrcs = append( - newIPv6VS[i].AllowedSrcs, - uniqueRuleV6, - ) - } - - // Shuffle ACL order for all VS - this should NOT affect ACL equality - for i := range newIPv4VS { - newIPv4VS[i].AllowedSrcs = frShuffleACL( - newIPv4VS[i].AllowedSrcs, - rng, - ) - } - for i := range newIPv6VS { - newIPv6VS[i].AllowedSrcs = frShuffleACL( - newIPv6VS[i].AllowedSrcs, - rng, - ) - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Expected behavior: - // - Both matchers reused (VS identifiers unchanged) - // - ALL VS have ACL reused because order doesn't matter for ACL comparison - assert.True( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should be reused", - ) - assert.True( - t, - updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused", - ) - assert.Equal( - t, - frIPv4VSCount+frIPv6VSCount, - len(updateInfo.ACLReusedVs), - "All 40 VS should have ACL reused (order doesn't matter)", - ) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Test packet processing still works - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(2001), - 80, - balancerpb.TransportProto_TCP, - 2, - 5000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(1001), - 80, - balancerpb.TransportProto_TCP, - 2, - 5100, - ) - }) - - // Phase 6: Same VS sets, Same ACL with duplicates - t.Run("Phase6_SameVS_SameACL_WithDuplicates", func(t *testing.T) { - t.Log("Testing: Same VS sets, same ACL with duplicates") - - rng := rand.New(rand.NewSource(42)) - - // Keep VS identifiers the same as Phase 5 - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 2000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 1000, true) - - // Apply the SAME ACL changes as Phase 4/5 (add unique rules to first 5 VS) - for i := 0; i < 5; i++ { - uniqueRule := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{100, byte(i), 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom4([4]byte{255, 255, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv4VS[i].AllowedSrcs = append( - newIPv4VS[i].AllowedSrcs, - uniqueRule, - ) - - uniqueRuleV6 := &balancerpb.AllowedSources{ - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0x00, byte(i), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.AddrFrom16([16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - AsSlice(), - }, - }}, - } - newIPv6VS[i].AllowedSrcs = append( - newIPv6VS[i].AllowedSrcs, - uniqueRuleV6, - ) - } - - // Shuffle ACL order (same as Phase 5) - for i := range newIPv4VS { - newIPv4VS[i].AllowedSrcs = frShuffleACL( - newIPv4VS[i].AllowedSrcs, - rng, - ) - } - for i := range newIPv6VS { - newIPv6VS[i].AllowedSrcs = frShuffleACL( - newIPv6VS[i].AllowedSrcs, - rng, - ) - } - - // Add duplicates to ACL - this should NOT affect ACL equality - for i := range newIPv4VS { - newIPv4VS[i].AllowedSrcs = frDuplicateACLRules( - newIPv4VS[i].AllowedSrcs, - ) - } - for i := range newIPv6VS { - newIPv6VS[i].AllowedSrcs = frDuplicateACLRules( - newIPv6VS[i].AllowedSrcs, - ) - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Expected behavior: - // - Both matchers reused (VS identifiers unchanged) - // - ALL VS have ACL reused because duplicates don't matter for ACL comparison - assert.True( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should be reused", - ) - assert.True( - t, - updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused", - ) - assert.Equal( - t, - frIPv4VSCount+frIPv6VSCount, - len(updateInfo.ACLReusedVs), - "All 40 VS should have ACL reused (duplicates don't matter)", - ) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Test packet processing still works - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(2001), - 80, - balancerpb.TransportProto_TCP, - 2, - 6000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(1001), - 80, - balancerpb.TransportProto_TCP, - 2, - 6100, - ) - }) - - // Phase 7: Completely different configuration - t.Run("Phase7_CompletelyDifferent", func(t *testing.T) { - t.Log("Testing: Completely different configuration") - - // Create completely new VS sets with different base indices - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Verify: Nothing reused - frVerifyUpdateInfo(t, updateInfo, false, false, 0) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Enable new reals and test - utils.EnableAllReals(t, ts) - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(5001), - 80, - balancerpb.TransportProto_TCP, - 2, - 7000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(5001), - 80, - balancerpb.TransportProto_TCP, - 2, - 7100, - ) - }) - - // Phase 8: Identical configuration (everything reused) - t.Run("Phase8_IdenticalConfig", func(t *testing.T) { - t.Log("Testing: Identical configuration (everything should be reused)") - - // Create the exact same configuration as Phase 7 - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update(newConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - - // Verify: Everything reused - frVerifyUpdateInfo( - t, - updateInfo, - true, - true, - frIPv4VSCount+frIPv6VSCount, - ) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - // Test packet processing still works - // Use VS index 1 which has TCP protocol (index 0 has UDP because 0%4==0) - frSendTestPackets( - t, - ts, - frGenerateIPv4Addr(5001), - 80, - balancerpb.TransportProto_TCP, - 2, - 8000, - ) - frSendTestPackets( - t, - ts, - frGenerateIPv6Addr(5001), - 80, - balancerpb.TransportProto_TCP, - 2, - 8100, - ) - }) - - // Phase 9: Edge cases - // Each subtest explicitly sets up a known state first, then makes a specific change - // to test the expected reuse behavior. - t.Run("Phase9_EdgeCases", func(t *testing.T) { - // Phase 9a: Partial VS set changes (add some VS, remove some VS) - t.Run("PartialVSChanges", func(t *testing.T) { - t.Log("Testing: Partial VS set changes") - - // First, establish a known baseline state - baseIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - baseIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - baseConfig := frCreateConfig(baseIPv4VS, baseIPv6VS) - _, err := ts.Balancer.Update(baseConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - utils.EnableAllReals(t, ts) - t.Log( - "Baseline established: 20 IPv4 VS (base 5000) + 20 IPv6 VS (base 5000)", - ) - - // Now make partial changes: keep first 15 VS, add 5 new ones - newIPv4VS := frCreateVSSet(15, false, 5000, true) - additionalIPv4VS := frCreateVSSet(5, false, 6000, true) - newIPv4VS = append(newIPv4VS, additionalIPv4VS...) - - newIPv6VS := frCreateVSSet(15, true, 5000, true) - additionalIPv6VS := frCreateVSSet(5, true, 6000, true) - newIPv6VS = append(newIPv6VS, additionalIPv6VS...) - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update( - newConfig, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Expected behavior: - // - VS matchers NOT reused because the VS set changed (removed 5, added 5 different) - // - ACL reused for the 15 unchanged VS in each family = 30 total - assert.False( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should NOT be reused when VS set changes (removed 5 VS, added 5 new)", - ) - assert.False( - t, - updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should NOT be reused when VS set changes (removed 5 VS, added 5 new)", - ) - assert.Equal( - t, - 30, - len(updateInfo.ACLReusedVs), - "30 VS should have ACL reused (15 unchanged IPv4 + 15 unchanged IPv6)", - ) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - utils.EnableAllReals(t, ts) - }) - - // Phase 9b: Mixed protocol changes - change protocol for some IPv4 VS only - t.Run("MixedProtocolChanges", func(t *testing.T) { - t.Log("Testing: Mixed protocol changes") - - // First, establish a known baseline state - baseIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - baseIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - baseConfig := frCreateConfig(baseIPv4VS, baseIPv6VS) - _, err := ts.Balancer.Update(baseConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - utils.EnableAllReals(t, ts) - t.Log( - "Baseline established: 20 IPv4 VS + 20 IPv6 VS with standard protocols", - ) - - // Now change protocol for first 5 IPv4 VS - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - - // Change protocol for first 5 IPv4 VS from TCP to UDP (or vice versa) - for i := 0; i < 5; i++ { - if newIPv4VS[i].Id.Proto == balancerpb.TransportProto_TCP { - newIPv4VS[i].Id.Proto = balancerpb.TransportProto_UDP - } else { - newIPv4VS[i].Id.Proto = balancerpb.TransportProto_TCP - } - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update( - newConfig, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Expected behavior: - // - IPv4 matcher NOT reused because protocol changed for 5 VS (different VS identifiers) - // - IPv6 matcher REUSED because IPv6 VS set is identical - // - ACL reused for unchanged VS: 15 IPv4 (with same identifier) + 20 IPv6 = 35 - assert.False( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should NOT be reused when protocol changes for some VS", - ) - assert.True(t, updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused (IPv6 VS set unchanged)") - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - utils.EnableAllReals(t, ts) - }) - - // Phase 9c: Port changes - change port for some IPv4 VS only - t.Run("PortChanges", func(t *testing.T) { - t.Log("Testing: Port changes") - - // First, establish a known baseline state - baseIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - baseIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - baseConfig := frCreateConfig(baseIPv4VS, baseIPv6VS) - _, err := ts.Balancer.Update(baseConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - utils.EnableAllReals(t, ts) - t.Log("Baseline established: 20 IPv4 VS + 20 IPv6 VS with port 80") - - // Now change port for first 5 IPv4 VS - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - - for i := range 5 { - newIPv4VS[i].Id.Port = 8080 - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update( - newConfig, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Expected behavior: - // - IPv4 matcher NOT reused because port changed for 5 VS (different VS identifiers) - // - IPv6 matcher REUSED because IPv6 VS set is identical - assert.False( - t, - updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should NOT be reused when port changes for some VS", - ) - assert.True(t, updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused (IPv6 VS set unchanged)") - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - - utils.EnableAllReals(t, ts) - }) - - // Phase 9d: ACL with different port ranges (same VS identifiers, different ACL) - t.Run("ACLPortRangeChanges", func(t *testing.T) { - t.Log("Testing: ACL with different port ranges") - - // First, establish a known baseline state - baseIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - baseIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - baseConfig := frCreateConfig(baseIPv4VS, baseIPv6VS) - _, err := ts.Balancer.Update(baseConfig, ts.Mock.CurrentTime()) - require.NoError(t, err) - utils.EnableAllReals(t, ts) - t.Log( - "Baseline established: 20 IPv4 VS + 20 IPv6 VS with standard ACLs", - ) - - // Now change ACL port ranges for first 5 IPv4 VS - newIPv4VS := frCreateVSSet(frIPv4VSCount, false, 5000, true) - newIPv6VS := frCreateVSSet(frIPv6VSCount, true, 5000, true) - - // Change port ranges in ACL for first 5 IPv4 VS - for i := 0; i < 5; i++ { - if len(newIPv4VS[i].AllowedSrcs) > 0 { - // Add a new port range to the ACL - newIPv4VS[i].AllowedSrcs[0].Ports = append( - newIPv4VS[i].AllowedSrcs[0].Ports, - &balancerpb.PortsRange{From: 8000, To: 9000}, - ) - } - } - - newConfig := frCreateConfig(newIPv4VS, newIPv6VS) - updateInfo, err := ts.Balancer.Update( - newConfig, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Expected behavior: - // - Both matchers REUSED because VS identifiers are identical - // - ACL reused for 15 IPv4 (unchanged ACL) + 20 IPv6 (unchanged) = 35 - assert.True(t, updateInfo.VsIpv4MatcherReused, - "IPv4 matcher should be reused (VS identifiers unchanged)") - assert.True(t, updateInfo.VsIpv6MatcherReused, - "IPv6 matcher should be reused (VS identifiers unchanged)") - assert.Equal( - t, - 35, - len(updateInfo.ACLReusedVs), - "35 VS should have ACL reused (15 unchanged IPv4 + 20 unchanged IPv6)", - ) - - t.Logf( - "IPv4 matcher reused: %v, IPv6 matcher reused: %v, ACL reused count: %d", - updateInfo.VsIpv4MatcherReused, - updateInfo.VsIpv6MatcherReused, - len(updateInfo.ACLReusedVs), - ) - }) - }) -} diff --git a/modules/balancer/tests/go/icmp_bcast_test.go b/modules/balancer/tests/go/icmp_bcast_test.go deleted file mode 100644 index c4c9942e6..000000000 --- a/modules/balancer/tests/go/icmp_bcast_test.go +++ /dev/null @@ -1,1085 +0,0 @@ -package balancer_test - -// TestICMPBroadcast is a comprehensive test suite for ICMP broadcast functionality in the balancer module that covers: -// -// # ICMP Broadcast Logic - Simplified -// - Case 1: Decap + Any ICMP_IDENT → Should NOT broadcast (came from peer) -// - Case 2: Decap + Any ICMP_IDENT → Should NOT broadcast (came from peer) -// - Case 3: No Decap + Any ICMP_IDENT → Should broadcast (external packet) -// - Case 4: No Decap + Any ICMP_IDENT → Should broadcast (external packet) -// -// Note: ICMP_BROADCAST_IDENT marker is no longer used for broadcast decision. -// Only the decap_flag matters: if packet was decapsulated, it came from a peer -// and should not be re-broadcasted to prevent loops. -// -// # ICMP Broadcast Marker -// - ICMP_BROADCAST_IDENT (0x0BDC) magic value set on broadcasted packets -// - Marker set in the unused field of ICMP error messages -// - Used for identification/debugging but not for broadcast decisions -// -// # Tunneled ICMP Packets -// - IP-in-IP tunneling for IPv4 ICMP errors -// - IPv6-in-IPv6 tunneling for ICMPv6 errors -// - Proper decapsulation and marker checking -// -// # Two-Balancer Integration -// - Balancer1 broadcasts ICMP error to Balancer2 -// - Balancer2 receives broadcasted packet with marker -// - Balancer2 forwards to real server (has session) -// - Balancer2 does NOT re-broadcast to Balancer1 (marker prevents loop) -// -// The test validates: -// - Correct broadcast behavior based on decap and marker presence -// - Prevention of broadcast loops using ICMP_BROADCAST_IDENT -// - Proper tunneling and decapsulation -// - Session-based forwarding after broadcast reception -// - IPv4 and IPv6 support for all scenarios - -import ( - "net" - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Broadcast Logic - All Four Cases -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPBroadcastLogic(t *testing.T) { - vsIPv4 := netip.MustParseAddr("10.1.1.1") - realIPv4 := netip.MustParseAddr("10.2.2.2") - clientIPv4 := netip.MustParseAddr("10.0.1.1") - balancerIPv4 := netip.MustParseAddr("5.5.5.5") - peer1IPv4 := netip.MustParseAddr("5.5.5.6") - peer2IPv4 := netip.MustParseAddr("5.5.5.7") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - realIPv6 := netip.MustParseAddr("2001:db8:2::2") - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - balancerIPv6 := netip.MustParseAddr("fe80::5") - peer1IPv6 := netip.MustParseAddr("fe80::6") - peer2IPv6 := netip.MustParseAddr("fe80::7") - - clientPort := uint16(12345) - vsPort := uint16(80) - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: balancerIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: balancerIPv6.AsSlice(), - }, - // Configure decap addresses - packets to these addresses will be decapsulated - DecapAddresses: []*balancerpb.Addr{ - {Bytes: balancerIPv4.AsSlice()}, - {Bytes: balancerIPv6.AsSlice()}, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - // Configure peers for broadcasting - Peers: []*balancerpb.Addr{ - {Bytes: peer1IPv4.AsSlice()}, - {Bytes: peer2IPv4.AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - // Configure peers for broadcasting - Peers: []*balancerpb.Addr{ - {Bytes: peer1IPv6.AsSlice()}, - {Bytes: peer2IPv6.AsSlice()}, - }, - }, - }, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Create an original TCP packet that will be embedded in ICMP errors - originalTCPLayers := utils.MakeTCPPacket( - vsIPv4, - vsPort, - clientIPv4, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - originalTCPPacket := xpacket.LayersToPacket(t, originalTCPLayers...) - - originalTCPv6Layers := utils.MakeTCPPacket( - vsIPv6, - vsPort, - clientIPv6, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - originalTCPv6Packet := xpacket.LayersToPacket(t, originalTCPv6Layers...) - - t.Run( - "Case1_IPv4_Decap_NoIcmpIdent_ShouldNotBroadcast", - func(t *testing.T) { - // Create a tunneled ICMP packet with normal ident (not ICMP_BROADCAST_IDENT) - // The outer destination is the balancer address (will trigger decap) - icmpLayers := utils.MakeTunneledICMPv4DestUnreachable( - peer1IPv4, // tunnel src (from another balancer) - balancerIPv4, // tunnel dst (this balancer - will trigger decap) - clientIPv4, // inner ICMP src - vsIPv4, // inner ICMP dst - originalTCPPacket, - 0x1234, // normal ident (not ICMP_BROADCAST_IDENT) - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Expected: packet should NOT be broadcasted (came from another balancer) - require.Equal( - t, - 0, - len(result.Output), - "Case 1: decap (from peer) should NOT broadcast", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 1: packet should be dropped", - ) - }, - ) - - t.Run( - "Case2_IPv4_Decap_WithIcmpIdent_ShouldNotBroadcast", - func(t *testing.T) { - // Create a tunneled ICMP packet with ICMP_BROADCAST_IDENT - // This simulates a packet that was already broadcasted by another balancer - icmpLayers := utils.MakeTunneledICMPv4DestUnreachable( - peer1IPv4, // tunnel src (from another balancer) - balancerIPv4, // tunnel dst (this balancer - will trigger decap) - clientIPv4, // inner ICMP src - vsIPv4, // inner ICMP dst - originalTCPPacket, - utils.ICMPBroadcastIdent, // magic ident indicating already broadcasted - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Expected: packet should NOT be broadcasted (already was by another balancer) - require.Equal( - t, - 0, - len(result.Output), - "Case 2: decap + icmp_ident should NOT broadcast", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 2: packet should be dropped", - ) - }, - ) - - t.Run( - "Case3_IPv4_NoDecap_WithIcmpIdent_ShouldBroadcast", - func(t *testing.T) { - // Create a non-tunneled ICMP packet with ICMP_BROADCAST_IDENT - // Since there's no decap, the ident check is skipped - icmpLayers := utils.MakeICMPv4DestUnreachableWithIdent( - clientIPv4, - vsIPv4, - originalTCPPacket, - utils.ICMPBroadcastIdent, // has magic ident but no decap - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Expected: packet should be broadcasted (no decap, so ident is ignored) - require.Equal( - t, - 2, - len(result.Output), - "Case 3: no decap + icmp_ident should broadcast to 2 peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 3: original packet should be dropped", - ) - - // Verify both broadcasted packets are properly tunneled with ICMP_BROADCAST_IDENT - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[0], - net.IP(peer1IPv4.AsSlice()), - ) - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[1], - net.IP(peer2IPv4.AsSlice()), - ) - }, - ) - - t.Run("Case4_IPv4_NoDecap_NoIcmpIdent_ShouldBroadcast", func(t *testing.T) { - // Create a non-tunneled ICMP packet with normal ident - icmpLayers := utils.MakeICMPv4DestUnreachableWithIdent( - clientIPv4, - vsIPv4, - originalTCPPacket, - 0x5678, // normal ident - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Expected: packet should be broadcasted (normal case) - require.Equal( - t, - 2, - len(result.Output), - "Case 4: no decap + no icmp_ident should broadcast to 2 peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 4: original packet should be dropped", - ) - - // Verify both broadcasted packets are properly tunneled with ICMP_BROADCAST_IDENT - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[0], - net.IP(peer1IPv4.AsSlice()), - ) - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[1], - net.IP(peer2IPv4.AsSlice()), - ) - }) - - // IPv6 test cases - t.Run( - "Case1_IPv6_Decap_NoIcmpIdent_ShouldNotBroadcast", - func(t *testing.T) { - icmpLayers := utils.MakeTunneledICMPv6DestUnreachable( - peer1IPv6, - balancerIPv6, - clientIPv6, - vsIPv6, - originalTCPv6Packet, - 0x1234, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - require.Equal( - t, - 0, - len(result.Output), - "Case 1 IPv6: decap (from peer) should NOT broadcast", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 1 IPv6: packet should be dropped", - ) - }, - ) - - t.Run( - "Case2_IPv6_Decap_WithIcmpIdent_ShouldNotBroadcast", - func(t *testing.T) { - icmpLayers := utils.MakeTunneledICMPv6DestUnreachable( - peer1IPv6, - balancerIPv6, - clientIPv6, - vsIPv6, - originalTCPv6Packet, - utils.ICMPBroadcastIdent, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - require.Equal( - t, - 0, - len(result.Output), - "Case 2 IPv6: decap + icmp_ident should NOT broadcast", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 2 IPv6: packet should be dropped", - ) - }, - ) - - t.Run( - "Case3_IPv6_NoDecap_WithIcmpIdent_ShouldBroadcast", - func(t *testing.T) { - icmpLayers := utils.MakeICMPv6DestUnreachableWithIdent( - clientIPv6, - vsIPv6, - originalTCPv6Packet, - utils.ICMPBroadcastIdent, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - require.Equal( - t, - 2, - len(result.Output), - "Case 3 IPv6: no decap + icmp_ident should broadcast to 2 peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 3 IPv6: original packet should be dropped", - ) - }, - ) - - t.Run("Case4_IPv6_NoDecap_NoIcmpIdent_ShouldBroadcast", func(t *testing.T) { - icmpLayers := utils.MakeICMPv6DestUnreachableWithIdent( - clientIPv6, - vsIPv6, - originalTCPv6Packet, - 0x5678, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - require.Equal( - t, - 2, - len(result.Output), - "Case 4 IPv6: no decap + no icmp_ident should broadcast to 2 peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Case 4 IPv6: original packet should be dropped", - ) - - // Verify both broadcasted packets are properly tunneled with ICMP_BROADCAST_IDENT - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[0], - net.IP(peer1IPv6.AsSlice()), - ) - utils.VerifyBroadcastedICMPPacket( - t, - result.Output[1], - net.IP(peer2IPv6.AsSlice()), - ) - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Test: Two-Balancer ICMP Broadcast Integration -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPBroadcastTwoBalancers(t *testing.T) { - // Setup: Two balancers where Balancer1 broadcasts to Balancer2 - // Balancer1 has no session, so it broadcasts - // Balancer2 has a session, so it forwards to real - // Balancer2 also has Balancer1 as peer, but should NOT re-broadcast - // because the packet has ICMP_BROADCAST_IDENT marker - - vsIPv4 := netip.MustParseAddr("10.1.1.1") - realIPv4 := netip.MustParseAddr("10.2.2.2") - clientIPv4 := netip.MustParseAddr("10.0.1.1") - balancer1IPv4 := netip.MustParseAddr("5.5.5.5") - balancer2IPv4 := netip.MustParseAddr("5.5.5.6") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - realIPv6 := netip.MustParseAddr("2001:db8:2::2") - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - balancer1IPv6 := netip.MustParseAddr("fe80::5") - balancer2IPv6 := netip.MustParseAddr("fe80::6") - - clientPort := uint16(12345) - vsPort := uint16(80) - - // Configure Balancer1 - has Balancer2 as peer, no session - config1 := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: balancer1IPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: balancer1IPv6.AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: balancer1IPv4.AsSlice()}, - {Bytes: balancer1IPv6.AsSlice()}, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - // Balancer1 has Balancer2 as peer - Peers: []*balancerpb.Addr{ - {Bytes: balancer2IPv4.AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - // Balancer1 has Balancer2 as IPv6 peer - Peers: []*balancerpb.Addr{ - {Bytes: balancer2IPv6.AsSlice()}, - }, - }, - }, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Configure Balancer2 - can decap packets from Balancer1 - // Has Balancer1 as peer to verify it doesn't re-broadcast - config2 := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: balancer2IPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: balancer2IPv6.AsSlice(), - }, - DecapAddresses: []*balancerpb.Addr{ - {Bytes: balancer2IPv4.AsSlice()}, - {Bytes: balancer2IPv6.AsSlice()}, - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - // Balancer2 has Balancer1 as peer - // This verifies that Balancer2 doesn't re-broadcast the packet - // back to Balancer1 (because it has ICMP_BROADCAST_IDENT marker) - Peers: []*balancerpb.Addr{ - {Bytes: balancer1IPv4.AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - // Balancer2 has Balancer1 as IPv6 peer - // This verifies that Balancer2 doesn't re-broadcast the packet - // back to Balancer1 (because it has ICMP_BROADCAST_IDENT marker) - Peers: []*balancerpb.Addr{ - {Bytes: balancer1IPv6.AsSlice()}, - }, - }, - }, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup Balancer1 - setup1, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config1, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer setup1.Free() - - // Setup Balancer2 - setup2, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config2, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer setup2.Free() - - // Enable all reals on both balancers - utils.EnableAllReals(t, setup1) - utils.EnableAllReals(t, setup2) - - t.Run("IPv4", func(t *testing.T) { - // Step 1: Create a session on Balancer2 by sending a TCP SYN packet - tcpLayers := utils.MakeTCPPacket( - clientIPv4, - clientPort, - vsIPv4, - vsPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - result, err := setup2.Mock.HandlePackets(tcpPacket) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "Balancer2 should forward TCP SYN", - ) - - // Step 2: Create an ICMP error packet for the response - // The response would come from VS IP to client IP - responsePacket := utils.MakeTCPPacket( - vsIPv4, - vsPort, - clientIPv4, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - responsePacketData := xpacket.LayersToPacket(t, responsePacket...) - - // Step 3: Send ICMP error to Balancer1 (which has no session) - icmpLayers := utils.MakeICMPv4DestUnreachable( - clientIPv4, - vsIPv4, - responsePacketData, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err = setup1.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Balancer1 should broadcast to Balancer2 (1 output packet) - require.Equal( - t, - 1, - len(result.Output), - "Balancer1 should broadcast ICMP to Balancer2", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Balancer1 should drop original packet", - ) - - // Verify the broadcasted packet is tunneled to Balancer2 - broadcastedPacket := result.Output[0] - require.True( - t, - broadcastedPacket.IsTunneled, - "packet should be tunneled", - ) - require.Equal( - t, - net.IP(balancer2IPv4.AsSlice()), - broadcastedPacket.DstIP, - "packet should be sent to Balancer2", - ) - - // Step 4: Send the broadcasted packet to Balancer2 - // Balancer2 should: - // 1. Decap the packet - // 2. See it has ICMP_BROADCAST_IDENT marker and decap=true - // 3. Forward to real (because it has a session) - // 4. NOT re-broadcast to Balancer1 (because of the marker) - broadcastedGoPacket := xpacket.ParseEtherPacket( - broadcastedPacket.RawData, - ) - result, err = setup2.Mock.HandlePackets(broadcastedGoPacket) - require.NoError(t, err) - - // Balancer2 should forward the ICMP error to the real server - // and NOT re-broadcast it - require.Equal( - t, - 1, - len(result.Output), - "Balancer2 should forward ICMP to real (not re-broadcast)", - ) - require.Empty(t, result.Drop, "Balancer2 should not drop the packet") - - // Verify the packet is tunneled to the real server (not to Balancer1) - forwardedPacket := result.Output[0] - require.True( - t, - forwardedPacket.IsTunneled, - "packet should be tunneled to real", - ) - require.Equal( - t, - net.IP(realIPv4.AsSlice()), - forwardedPacket.DstIP, - "packet should be sent to real server, not back to Balancer1", - ) - }) - - t.Run("IPv6", func(t *testing.T) { - // Step 1: Create a session on Balancer2 by sending a TCP SYN packet - tcpLayers := utils.MakeTCPPacket( - clientIPv6, - clientPort, - vsIPv6, - vsPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - result, err := setup2.Mock.HandlePackets(tcpPacket) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "Balancer2 should forward TCP SYN", - ) - - // Step 2: Create an ICMPv6 error packet for the response - // The response would come from VS IP to client IP - responsePacket := utils.MakeTCPPacket( - vsIPv6, - vsPort, - clientIPv6, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - responsePacketData := xpacket.LayersToPacket(t, responsePacket...) - - // Step 3: Send ICMPv6 error to Balancer1 (which has no session) - icmpLayers := utils.MakeICMPv6DestUnreachableWithIdent( - clientIPv6, - vsIPv6, - responsePacketData, - 0x1234, // normal ident (not ICMP_BROADCAST_IDENT) - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err = setup1.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Balancer1 should broadcast to Balancer2 (1 output packet) - require.Equal( - t, - 1, - len(result.Output), - "Balancer1 should broadcast ICMPv6 to Balancer2", - ) - require.Equal( - t, - 1, - len(result.Drop), - "Balancer1 should drop original packet", - ) - - // Verify the broadcasted packet is tunneled to Balancer2 - broadcastedPacket := result.Output[0] - require.True( - t, - broadcastedPacket.IsTunneled, - "packet should be tunneled", - ) - require.Equal( - t, - net.IP(balancer2IPv6.AsSlice()), - broadcastedPacket.DstIP, - "packet should be sent to Balancer2", - ) - - // Step 4: Send the broadcasted packet to Balancer2 - // Balancer2 should: - // 1. Decap the packet - // 2. See it has ICMP_BROADCAST_IDENT marker and decap=true - // 3. Forward to real (because it has a session) - // 4. NOT re-broadcast to Balancer1 (because of the marker) - broadcastedGoPacket := xpacket.ParseEtherPacket( - broadcastedPacket.RawData, - ) - result, err = setup2.Mock.HandlePackets(broadcastedGoPacket) - require.NoError(t, err) - - // Balancer2 should forward the ICMPv6 error to the real server - // and NOT re-broadcast it - require.Equal( - t, - 1, - len(result.Output), - "Balancer2 should forward ICMPv6 to real (not re-broadcast)", - ) - require.Empty(t, result.Drop, "Balancer2 should not drop the packet") - - // Verify the packet is tunneled to the real server (not to Balancer1) - forwardedPacket := result.Output[0] - require.True( - t, - forwardedPacket.IsTunneled, - "packet should be tunneled to real", - ) - require.Equal( - t, - net.IP(realIPv6.AsSlice()), - forwardedPacket.DstIP, - "packet should be sent to real server, not back to Balancer1", - ) - }) -} diff --git a/modules/balancer/tests/go/icmp_test.go b/modules/balancer/tests/go/icmp_test.go deleted file mode 100644 index 13fd31fe1..000000000 --- a/modules/balancer/tests/go/icmp_test.go +++ /dev/null @@ -1,1315 +0,0 @@ -package balancer_test - -// TestICMP is a comprehensive test suite for ICMP packet handling in the balancer module that covers: -// -// # ICMP Echo Request/Reply -// - IPv4 and IPv6 echo request handling -// - Proper echo reply generation -// - IP address swapping in responses -// - TTL/HopLimit reset to 64 -// -// # ICMP Echo to Non-Virtual Service -// - Dropping echo requests to non-configured VS IPs -// - Proper response to valid VS IPs -// - IPv4 and IPv6 validation -// -// # ICMP Error Packet Forwarding -// - Forwarding ICMP errors when session exists -// - Tunneling ICMP errors to real servers -// - IPv4 Destination Unreachable handling -// - IPv6 Destination Unreachable handling -// -// # ICMP Error Packet Dropping -// - Dropping ICMP errors for unknown virtual services -// - Dropping ICMP errors when no session exists -// - Broadcasting ICMP errors to peers when no session found -// -// The test validates: -// - Correct ICMP packet type and code -// - Proper IP address handling -// - Session-based ICMP error forwarding -// - Peer broadcasting for unknown sessions -// - Packet tunneling to real servers - -import ( - "net" - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Echo Request/Reply for IPv4 and IPv6 -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPEchoRequest(t *testing.T) { - // Define test addresses - vsIPv4 := netip.MustParseAddr("10.1.1.1") - clientIPv4 := netip.MustParseAddr("10.0.1.1") - vsIPv6 := netip.MustParseAddr("2001:db8::1") - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - - // Create balancer configuration - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("IPv4", func(t *testing.T) { - // Create ICMP Echo Request - packetLayers := utils.MakeICMPv4EchoRequest(clientIPv4, vsIPv4, 1234, 1) - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Send packet - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "should have one output packet") - require.Empty(t, result.Drop, "should not drop packet") - - // Parse response - responsePacket := gopacket.NewPacket( - result.Output[0].RawData, - layers.LayerTypeEthernet, - gopacket.Default, - ) - - // Verify it's an ICMP Echo Reply - icmpLayer := responsePacket.Layer(layers.LayerTypeICMPv4) - require.NotNil(t, icmpLayer, "response should have ICMPv4 layer") - - icmp := icmpLayer.(*layers.ICMPv4) - assert.Equal( - t, - uint8(layers.ICMPv4TypeEchoReply), - uint8(icmp.TypeCode.Type()), - "should be Echo Reply", - ) - assert.Equal( - t, - uint8(0), - uint8(icmp.TypeCode.Code()), - "code should be 0", - ) - assert.Equal(t, uint16(1234), icmp.Id, "ID should match request") - assert.Equal(t, uint16(1), icmp.Seq, "sequence should match request") - - // Verify IP addresses are swapped - ipLayer := responsePacket.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "response should have IPv4 layer") - - ip := ipLayer.(*layers.IPv4) - assert.Equal( - t, - net.IP(vsIPv4.AsSlice()), - ip.SrcIP, - "src IP should be VS IP", - ) - assert.Equal( - t, - net.IP(clientIPv4.AsSlice()), - ip.DstIP, - "dst IP should be client IP", - ) - assert.Equal(t, uint8(64), ip.TTL, "TTL should be reset to 64") - }) - - t.Run("IPv6", func(t *testing.T) { - // Create ICMPv6 Echo Request - packetLayers := utils.MakeICMPv6EchoRequest(clientIPv6, vsIPv6, 5678, 2) - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Send packet - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "should have one output packet") - require.Empty(t, result.Drop, "should not drop packet") - - // Parse response - responsePacket := gopacket.NewPacket( - result.Output[0].RawData, - layers.LayerTypeEthernet, - gopacket.Default, - ) - - // Verify it's an ICMPv6 Echo Reply - icmpLayer := responsePacket.Layer(layers.LayerTypeICMPv6) - require.NotNil(t, icmpLayer, "response should have ICMPv6 layer") - - icmp := icmpLayer.(*layers.ICMPv6) - assert.Equal( - t, - uint8(layers.ICMPv6TypeEchoReply), - uint8(icmp.TypeCode.Type()), - "should be Echo Reply", - ) - assert.Equal( - t, - uint8(0), - uint8(icmp.TypeCode.Code()), - "code should be 0", - ) - - // Verify IP addresses are swapped - ipLayer := responsePacket.Layer(layers.LayerTypeIPv6) - require.NotNil(t, ipLayer, "response should have IPv6 layer") - - ip := ipLayer.(*layers.IPv6) - assert.Equal( - t, - net.IP(vsIPv6.AsSlice()), - ip.SrcIP, - "src IP should be VS IP", - ) - assert.Equal( - t, - net.IP(clientIPv6.AsSlice()), - ip.DstIP, - "dst IP should be client IP", - ) - assert.Equal( - t, - uint8(64), - ip.HopLimit, - "hop limit should be reset to 64", - ) - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Echo Request to non-virtual service IP should be dropped -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPEchoRequestToNonVirtualService(t *testing.T) { - vsIPv4 := netip.MustParseAddr("10.1.1.1") - nonVsIPv4 := netip.MustParseAddr("10.99.99.99") // Not configured as VS - clientIPv4 := netip.MustParseAddr("10.0.1.1") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - nonVsIPv6 := netip.MustParseAddr("2001:db8:99::99") // Not configured as VS - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: 80, - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("IPv4_NonVS_ShouldDrop", func(t *testing.T) { - // Create ICMP Echo Request to non-VS IP - packetLayers := utils.MakeICMPv4EchoRequest( - clientIPv4, - nonVsIPv4, - 1234, - 1, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Send packet - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - - // Packet should be dropped, not responded to - require.Empty(t, result.Output, "should not respond to non-VS IP") - require.Equal(t, 1, len(result.Drop), "should drop packet") - }) - - t.Run("IPv6_NonVS_ShouldDrop", func(t *testing.T) { - // Create ICMPv6 Echo Request to non-VS IP - packetLayers := utils.MakeICMPv6EchoRequest( - clientIPv6, - nonVsIPv6, - 5678, - 2, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Send packet - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - - // Packet should be dropped, not responded to - require.Empty(t, result.Output, "should not respond to non-VS IP") - require.Equal(t, 1, len(result.Drop), "should drop packet") - }) - - t.Run("IPv4_ValidVS_ShouldRespond", func(t *testing.T) { - // Create ICMP Echo Request to valid VS IP - packetLayers := utils.MakeICMPv4EchoRequest(clientIPv4, vsIPv4, 1234, 1) - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Send packet - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - - // Should respond to valid VS IP - require.Equal(t, 1, len(result.Output), "should respond to VS IP") - require.Empty(t, result.Drop, "should not drop packet") - - // Verify it's an ICMP Echo Reply - responsePacket := gopacket.NewPacket( - result.Output[0].RawData, - layers.LayerTypeEthernet, - gopacket.Default, - ) - icmpLayer := responsePacket.Layer(layers.LayerTypeICMPv4) - require.NotNil(t, icmpLayer, "response should have ICMPv4 layer") - icmp := icmpLayer.(*layers.ICMPv4) - assert.Equal( - t, - uint8(layers.ICMPv4TypeEchoReply), - uint8(icmp.TypeCode.Type()), - "should be Echo Reply", - ) - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Error packet forwarding when session exists -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPErrorWithExistingSession(t *testing.T) { - vsIPv4 := netip.MustParseAddr("10.1.1.1") - realIPv4 := netip.MustParseAddr("10.2.2.2") - clientIPv4 := netip.MustParseAddr("10.0.1.1") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - realIPv6 := netip.MustParseAddr("2001:db8:2::2") - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - - clientPort := uint16(12345) - vsPort := uint16(80) - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv4.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realIPv6.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - t.Run("IPv4", func(t *testing.T) { - // First, create a session by sending a TCP SYN packet - tcpLayers := utils.MakeTCPPacket( - clientIPv4, - clientPort, - vsIPv4, - vsPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - result, err := ts.Mock.HandlePackets(tcpPacket) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "TCP packet should be forwarded", - ) - - // Now simulate the real server's response packet (which would trigger an ICMP error) - // The real server responds with src=vsIP (as configured), dst=clientIP - responsePacket := utils.MakeTCPPacket( - vsIPv4, - vsPort, - clientIPv4, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - responsePacketData := xpacket.LayersToPacket(t, responsePacket...) - - // Now send an ICMP Destination Unreachable error containing the response packet - // The ICMP error comes from the client network to the VS IP (balancer) - // because the response packet had src=vsIP - icmpLayers := utils.MakeICMPv4DestUnreachable( - clientIPv4, - vsIPv4, - responsePacketData, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err = ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // The ICMP error should be forwarded to the real server (tunneled) - require.Equal( - t, - 1, - len(result.Output), - "ICMP error should be forwarded", - ) - require.Empty(t, result.Drop, "ICMP error should not be dropped") - - // Verify the packet is tunneled - outputPacket := result.Output[0] - assert.True( - t, - outputPacket.IsTunneled, - "ICMP error should be tunneled to real", - ) - assert.Equal( - t, - net.IP(realIPv4.AsSlice()), - outputPacket.DstIP, - "should be sent to real server", - ) - }) - - t.Run("IPv6", func(t *testing.T) { - // First, create a session by sending a TCP SYN packet - tcpLayers := utils.MakeTCPPacket( - clientIPv6, - clientPort, - vsIPv6, - vsPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - result, err := ts.Mock.HandlePackets(tcpPacket) - require.NoError(t, err) - require.Equal( - t, - 1, - len(result.Output), - "TCP packet should be forwarded", - ) - - // Now simulate the real server's response packet (which would trigger an ICMPv6 error) - // The real server responds with src=vsIP (as configured), dst=clientIP - responsePacket := utils.MakeTCPPacket( - vsIPv6, - vsPort, - clientIPv6, - clientPort, - &layers.TCP{SYN: true, ACK: true}, - ) - responsePacketData := xpacket.LayersToPacket(t, responsePacket...) - - // Now send an ICMPv6 Destination Unreachable error containing the response packet - // The ICMPv6 error comes from the client network to the VS IP (balancer) - // because the response packet had src=vsIP - icmpLayers := utils.MakeICMPv6DestUnreachable( - clientIPv6, - vsIPv6, - responsePacketData, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err = ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // The ICMPv6 error should be forwarded to the real server (tunneled) - require.Equal( - t, - 1, - len(result.Output), - "ICMPv6 error should be forwarded", - ) - require.Empty(t, result.Drop, "ICMPv6 error should not be dropped") - - // Verify the packet is tunneled - outputPacket := result.Output[0] - assert.True( - t, - outputPacket.IsTunneled, - "ICMPv6 error should be tunneled to real", - ) - assert.Equal( - t, - net.IP(realIPv6.AsSlice()), - outputPacket.DstIP, - "should be sent to real server", - ) - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Error packet drop when VS not found -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPErrorWithUnknownVS(t *testing.T) { - vsIPv4 := netip.MustParseAddr("10.1.1.1") - unknownVsIPv4 := netip.MustParseAddr("10.99.99.99") // Not configured - clientIPv4 := netip.MustParseAddr("10.0.1.1") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - unknownVsIPv6 := netip.MustParseAddr("2001:db8:99::99") // Not configured - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - - clientPort := uint16(12345) - vsPort := uint16(80) - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("IPv4", func(t *testing.T) { - // Create a TCP packet to an unknown VS - tcpLayers := utils.MakeTCPPacket( - unknownVsIPv4, - vsPort, - clientIPv4, - clientPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - // Create ICMP error for the unknown VS - icmpLayers := utils.MakeICMPv4DestUnreachable( - clientIPv4, - unknownVsIPv4, - tcpPacket, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // The ICMP error should be dropped because VS is not found - require.Empty(t, result.Output, "ICMP error should not be forwarded") - require.Equal(t, 1, len(result.Drop), "ICMP error should be dropped") - }) - - t.Run("IPv6", func(t *testing.T) { - // Create a TCP packet to an unknown VS - tcpLayers := utils.MakeTCPPacket( - unknownVsIPv6, - vsPort, - clientIPv6, - clientPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - // Create ICMPv6 error for the unknown VS - icmpLayers := utils.MakeICMPv6DestUnreachable( - clientIPv6, - unknownVsIPv6, - tcpPacket, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // The ICMPv6 error should be dropped because VS is not found - require.Empty(t, result.Output, "ICMPv6 error should not be forwarded") - require.Equal(t, 1, len(result.Drop), "ICMPv6 error should be dropped") - }) -} - -//////////////////////////////////////////////////////////////////////////////// -// Test: ICMP Error packet drop when session not found -//////////////////////////////////////////////////////////////////////////////// - -func TestICMPErrorWithNoSession(t *testing.T) { - // In this test packet must be broadcasted to peers - vsIPv4 := netip.MustParseAddr("10.1.1.1") - clientIPv4 := netip.MustParseAddr("10.0.1.1") - - vsIPv6 := netip.MustParseAddr("2001:db8::1") - clientIPv6 := netip.MustParseAddr("2001:db8:1::1") - - clientPort := uint16(12345) - vsPort := uint16(80) - - peer1 := netip.MustParseAddr("10.12.11.13") - peer2 := netip.MustParseAddr("fe80::11") - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv4.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.2.2.2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: peer1.AsSlice()}, - {Bytes: peer2.AsSlice()}, - }, - }, - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIPv6.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff::"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8:2::2"). - AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{ - {Bytes: peer1.AsSlice()}, - {Bytes: peer2.AsSlice()}, - }, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - value := 16 * datasize.MB - return &value - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("IPv4", func(t *testing.T) { - // Create a TCP packet (but don't send it to create a session) - tcpLayers := utils.MakeTCPPacket( - vsIPv4, - vsPort, - clientIPv4, - clientPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - // Create ICMP error for a non-existent session - icmpLayers := utils.MakeICMPv4DestUnreachable( - clientIPv4, - vsIPv4, - tcpPacket, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Since there's no session, the packet should be broadcasted to peers - require.Equal( - t, - 2, - len(result.Output), - "ICMP error clone must be broadcasted to both peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "The original packet must be dropped", - ) - }) - - t.Run("IPv6", func(t *testing.T) { - // Create a TCP packet (but don't send it to create a session) - tcpLayers := utils.MakeTCPPacket( - vsIPv6, - vsPort, - clientIPv6, - clientPort, - &layers.TCP{SYN: true}, - ) - tcpPacket := xpacket.LayersToPacket(t, tcpLayers...) - - // Create ICMPv6 error for a non-existent session - icmpLayers := utils.MakeICMPv6DestUnreachable( - clientIPv6, - vsIPv6, - tcpPacket, - ) - icmpPacket := xpacket.LayersToPacket(t, icmpLayers...) - - result, err := ts.Mock.HandlePackets(icmpPacket) - require.NoError(t, err) - - // Since there's no session, the packet should be broadcasted to peers - require.Equal( - t, - 2, - len(result.Output), - "ICMPv6 error clone must be broadcasted to both peers", - ) - require.Equal( - t, - 1, - len(result.Drop), - "The original packet must be dropped", - ) - }) -} diff --git a/modules/balancer/tests/go/packet_test.go b/modules/balancer/tests/go/packet_test.go deleted file mode 100644 index 246a2b72c..000000000 --- a/modules/balancer/tests/go/packet_test.go +++ /dev/null @@ -1,577 +0,0 @@ -package balancer_test - -// TestPacketProcessing is a comprehensive test suite for packet processing in the balancer module that covers: -// -// # Packet Encapsulation -// - Basic encapsulation without GRE or MSS fixing -// - IPv4 and IPv6 virtual services -// - TCP and UDP protocols -// - IPv4 and IPv6 real servers -// -// # GRE Tunneling -// - GRE encapsulation for IPv4 and IPv6 -// - Proper tunnel type identification -// - All protocol and IP version combinations -// -// # MSS Fixing -// - MSS option insertion when missing -// - MSS option update when present -// - MSS clamping to maximum value (1220) -// - Default MSS value (536) when no MSS option present -// -// # Combined Features -// - GRE tunneling with MSS fixing -// - All combinations of features working together -// -// The test validates: -// - Correct packet encapsulation -// - ToS/TrafficClass preservation -// - Protocol consistency -// - Tunnel type correctness -// - MSS option handling - -import ( - "fmt" - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "github.com/yanet-platform/yanet2/tests/functional/framework" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Test addresses -var ( - clientIPv4 = netip.MustParseAddr("10.0.1.1") - clientIPv6 = netip.MustParseAddr("ffff::1") - - balancerSrcIPv4 = netip.MustParseAddr("5.5.5.5") - balancerSrcIPv6 = netip.MustParseAddr("fe80::5") -) - -// createPacketTestConfig creates a balancer configuration with all combinations of: -// - VS IP version (IPv4, IPv6) -// - Protocol (TCP, UDP) -// - GRE enabled/disabled -// - FixMSS enabled/disabled -// - Real IP version (IPv4, IPv6) -func createPacketTestConfig() *balancerpb.BalancerConfig { - var virtualServices []*balancerpb.VirtualService - counter := 1 - - for _, vsIPVersion := range []int{4, 6} { - for _, proto := range []balancerpb.TransportProto{ - balancerpb.TransportProto_TCP, - balancerpb.TransportProto_UDP, - } { - for _, greEnabled := range []bool{false, true} { - for _, fixMssEnabled := range []bool{false, true} { - for _, realIPVersion := range []int{4, 6} { - // Create VS address - var vsAddr netip.Addr - var allowedSrc *balancerpb.Net - if vsIPVersion == 4 { - vsAddr = netip.MustParseAddr( - fmt.Sprintf("10.12.1.%d", counter), - ) - allowedSrc = &balancerpb.Net{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.1.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.0"). - AsSlice(), - }, - } - } else { - vsAddr = netip.MustParseAddr(fmt.Sprintf("2001:db8::%d", counter)) - allowedSrc = &balancerpb.Net{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff::0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff::").AsSlice(), - }, - } - } - - // Create real address - var realAddr netip.Addr - if realIPVersion == 4 { - realAddr = netip.MustParseAddr("10.1.1.1") - } else { - realAddr = netip.MustParseAddr("fe80::1") - } - - vs := &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsAddr.AsSlice(), - }, - Port: 8080, - Proto: proto, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - {Nets: []*balancerpb.Net{allowedSrc}}, - }, - Flags: &balancerpb.VsFlags{ - Gre: greEnabled, - FixMss: fixMssEnabled, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - } - - virtualServices = append(virtualServices, vs) - counter++ - } - } - } - } - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: balancerSrcIPv4.AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: balancerSrcIPv6.AsSlice(), - }, - Vs: virtualServices, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 10, - TcpSyn: 10, - TcpFin: 10, - Tcp: 10, - Udp: 10, - Default: 10, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(100); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// VsSelector helps select a specific virtual service based on its characteristics -type VsSelector struct { - VsIPVersion int // 4 or 6 - Proto balancerpb.TransportProto - Gre bool - FixMSS bool - RealIPVersion int // 4 or 6 -} - -// findMatchingVS finds a virtual service that matches the selector criteria -func findMatchingVS( - config *balancerpb.BalancerConfig, - selector VsSelector, -) *balancerpb.VirtualService { - for _, vs := range config.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - vsIsIPv4 := vsAddr.Is4() - vsIsIPv6 := vsAddr.Is6() - - if (vsIsIPv4 && selector.VsIPVersion == 4) || - (vsIsIPv6 && selector.VsIPVersion == 6) { - if vs.Id.Proto == selector.Proto { - if vs.Flags.FixMss == selector.FixMSS && - vs.Flags.Gre == selector.Gre { - if len(vs.Reals) > 0 { - real := vs.Reals[0] - realAddr, _ := netip.AddrFromSlice(real.Id.Ip.Bytes) - if (realAddr.Is4() && selector.RealIPVersion == 4) || - (realAddr.Is6() && selector.RealIPVersion == 6) { - return vs - } - } - } - } - } - } - return nil -} - -// sendPacketToVS sends a packet to a virtual service and returns the result -func sendPacketToVS( - t *testing.T, - ts *utils.TestSetup, - vs *balancerpb.VirtualService, - mss *uint16, -) *framework.PacketInfo { - t.Helper() - - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - clientAddr := clientIPv4 - if vsAddr.Is6() { - clientAddr = clientIPv6 - } - clientPort := uint16(40441) - vsPort := uint16(vs.Id.Port) - - var tcp *layers.TCP - if vs.Id.Proto == balancerpb.TransportProto_TCP { - tcp = &layers.TCP{SYN: true} - } - - packetLayers := utils.MakePacketLayers( - clientAddr, - clientPort, - vsAddr, - vsPort, - tcp, - ) - - packet := xpacket.LayersToPacket(t, packetLayers...) - - // Add MSS option if requested and TCP - if tcp != nil && mss != nil { - modifiedPacket, err := insertOrUpdateMSS(packet, *mss) - require.NoError(t, err, "failed to insert/update MSS") - packet = *modifiedPacket - } - - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err, "failed to handle packet") - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - - if len(result.Output) > 0 { - resultPacket := result.Output[0] - utils.ValidatePacket(t, ts.Balancer.Config(), packet, resultPacket) - return resultPacket - } - - return nil -} - -// insertOrUpdateMSS inserts or updates the MSS option in a TCP packet -func insertOrUpdateMSS( - p gopacket.Packet, - newMSS uint16, -) (*gopacket.Packet, error) { - return utils.InsertOrUpdateMSS(p, newMSS) -} - -// TestPacketProcessing is the main test function -func TestPacketProcessing(t *testing.T) { - config := createPacketTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(256*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := datasize.MB * 128 - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Test basic encapsulation - t.Run("Encapsulation", func(t *testing.T) { - testEncapsulation(t, ts) - }) - - // Test GRE tunneling - t.Run("GRE_Tunneling", func(t *testing.T) { - testGRETunneling(t, ts) - }) - - // Test MSS fixing - t.Run("MSS_Fixing", func(t *testing.T) { - testMSSFixing(t, ts) - }) - - // Test GRE + MSS fixing - t.Run("GRE_MSS_Combined", func(t *testing.T) { - testGREMSSCombined(t, ts) - }) -} - -// testEncapsulation tests basic packet encapsulation without GRE or MSS fixing -func testEncapsulation(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - - for _, proto := range []balancerpb.TransportProto{ - balancerpb.TransportProto_TCP, - balancerpb.TransportProto_UDP, - } { - for _, vsIPVersion := range []int{4, 6} { - for _, realIPVersion := range []int{4, 6} { - selector := VsSelector{ - VsIPVersion: vsIPVersion, - Proto: proto, - RealIPVersion: realIPVersion, - Gre: false, - FixMSS: false, - } - - vs := findMatchingVS(config, selector) - require.NotNil( - t, - vs, - "failed to find VS for selector: %+v", - selector, - ) - - t.Logf( - "Testing encapsulation: vsIP=v%d, realIP=v%d, proto=%s", - selector.VsIPVersion, - selector.RealIPVersion, - selector.Proto.String(), - ) - - result := sendPacketToVS(t, ts, vs, nil) - assert.NotNil(t, result, "expected result packet") - assert.True(t, result.IsTunneled, "packet should be tunneled") - } - } - } -} - -// testGRETunneling tests GRE tunneling functionality -func testGRETunneling(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - - for _, proto := range []balancerpb.TransportProto{ - balancerpb.TransportProto_TCP, - balancerpb.TransportProto_UDP, - } { - for _, vsIPVersion := range []int{4, 6} { - for _, realIPVersion := range []int{4, 6} { - selector := VsSelector{ - VsIPVersion: vsIPVersion, - Proto: proto, - RealIPVersion: realIPVersion, - Gre: true, - FixMSS: false, - } - - vs := findMatchingVS(config, selector) - require.NotNil( - t, - vs, - "failed to find VS for selector: %+v", - selector, - ) - - t.Logf( - "Testing GRE: vsIP=v%d, realIP=v%d, proto=%s", - selector.VsIPVersion, - selector.RealIPVersion, - selector.Proto.String(), - ) - - result := sendPacketToVS(t, ts, vs, nil) - assert.NotNil(t, result, "expected result packet") - assert.True(t, result.IsTunneled, "packet should be tunneled") - - // Verify GRE tunnel type - expectedTunnelType := "gre-ip4" - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if vsAddr.Is6() { - expectedTunnelType = "gre-ip6" - } - assert.Equal( - t, - expectedTunnelType, - result.TunnelType, - "tunnel type should be GRE", - ) - } - } - } -} - -// testMSSFixing tests MSS fixing functionality -func testMSSFixing(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - - // Test with different MSS values - for _, mssValue := range []uint16{0, 500, 1200, 1400, 1460} { - for _, realIPVersion := range []int{4, 6} { - selector := VsSelector{ - VsIPVersion: 6, // Use IPv6 VS - Proto: balancerpb.TransportProto_TCP, - RealIPVersion: realIPVersion, - Gre: false, - FixMSS: true, - } - - vs := findMatchingVS(config, selector) - require.NotNil( - t, - vs, - "failed to find VS for selector: %+v", - selector, - ) - - t.Logf( - "Testing MSS fixing: vsIP=v%d, realIP=v%d, mss=%d", - selector.VsIPVersion, - selector.RealIPVersion, - mssValue, - ) - - var mssPtr *uint16 - if mssValue > 0 { - mssPtr = &mssValue - } - - result := sendPacketToVS(t, ts, vs, mssPtr) - assert.NotNil(t, result, "expected result packet") - - // Verify MSS was fixed - if mssValue > 0 { - // MSS should be clamped to min(original, 1220) - expectedMSS := min(mssValue, 1220) - actualMSS := extractMSS(t, result) - assert.Equal( - t, - expectedMSS, - actualMSS, - "MSS should be fixed to %d", - expectedMSS, - ) - } else { - // No MSS option in original, should insert default (536) - actualMSS := extractMSS(t, result) - assert.Equal(t, uint16(536), actualMSS, "MSS should be default 536") - } - } - } -} - -// testGREMSSCombined tests GRE tunneling combined with MSS fixing -func testGREMSSCombined(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - - for _, mssValue := range []uint16{0, 500, 1200, 1400} { - for _, realIPVersion := range []int{4, 6} { - selector := VsSelector{ - VsIPVersion: 6, - Proto: balancerpb.TransportProto_TCP, - RealIPVersion: realIPVersion, - Gre: true, - FixMSS: true, - } - - vs := findMatchingVS(config, selector) - require.NotNil( - t, - vs, - "failed to find VS for selector: %+v", - selector, - ) - - t.Logf( - "Testing GRE+MSS: vsIP=v%d, realIP=v%d, mss=%d", - selector.VsIPVersion, - selector.RealIPVersion, - mssValue, - ) - - var mssPtr *uint16 - if mssValue > 0 { - mssPtr = &mssValue - } - - result := sendPacketToVS(t, ts, vs, mssPtr) - assert.NotNil(t, result, "expected result packet") - assert.True(t, result.IsTunneled, "packet should be tunneled") - - // Verify GRE tunnel type - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - expectedTunnelType := "gre-ip4" - if vsAddr.Is6() { - expectedTunnelType = "gre-ip6" - } - assert.Equal( - t, - expectedTunnelType, - result.TunnelType, - "tunnel type should be GRE", - ) - - // Verify MSS was fixed - if mssValue > 0 { - expectedMSS := min(mssValue, 1220) - actualMSS := extractMSS(t, result) - assert.Equal( - t, - expectedMSS, - actualMSS, - "MSS should be fixed to %d", - expectedMSS, - ) - } else { - actualMSS := extractMSS(t, result) - assert.Equal(t, uint16(536), actualMSS, "MSS should be default 536") - } - } - } -} - -// extractMSS extracts the MSS value from a packet -func extractMSS(t *testing.T, packet *framework.PacketInfo) uint16 { - t.Helper() - - p := gopacket.NewPacket( - packet.RawData, - layers.LayerTypeEthernet, - gopacket.Default, - ) - - mss, err := xpacket.PacketMSS(p) - require.NoError(t, err, "failed to extract MSS from packet") - - return mss -} diff --git a/modules/balancer/tests/go/scheduling_test.go b/modules/balancer/tests/go/scheduling_test.go deleted file mode 100644 index b0f7754f0..000000000 --- a/modules/balancer/tests/go/scheduling_test.go +++ /dev/null @@ -1,1558 +0,0 @@ -package balancer_test - -// TestScheduling is a comprehensive test suite for the balancer module that covers: -// -// # Session Management -// - TCP session establishment and persistence -// - UDP session establishment and persistence -// - Session table overflow handling -// - Session timeout configuration -// -// # Scheduling Algorithms -// - SOURCE_HASH: consistent hashing based on client IP+port -// - ROUND_ROBIN: sequential distribution across reals -// - Both with and without One Packet Scheduling (OPS) mode -// -// # Weight Distribution -// - Equal weight distribution (1:1:1) -// - Weighted distribution (1:2:3) -// - Weight updates and redistribution -// - Statistical validation of weight-based traffic distribution -// -// # Real Server Management -// - Enabling/disabling real servers -// - Handling disabled reals with existing sessions -// - Removing reals from configuration -// - Real server state transitions -// -// # API Outputs -// - Config(): Configuration retrieval and validation -// - Info(): Runtime information (active sessions, VS info) -// - Stats(): Packet processing statistics (L4, ICMP, common) -// - Graph(): Topology visualization (VS-to-real relationships) -// -// # State Restoration -// - Creating new agent from existing shared memory -// - Preserving configuration across agent restarts -// - Maintaining session state after restoration -// - Verifying all functionality after restoration -// -// The test uses 8 virtual services with different configurations: -// - VS1: TCP + SOURCE_HASH (session-based) -// - VS2: UDP + SOURCE_HASH (session-based) -// - VS3: TCP + ROUND_ROBIN (session-based) -// - VS4: UDP + ROUND_ROBIN (session-based) -// - VS5: TCP + SOURCE_HASH + OPS (no session) -// - VS6: TCP + ROUND_ROBIN + OPS (no session) -// - VS7: TCP + SOURCE_HASH + OPS + Weighted (1:2:3) -// - VS8: TCP + ROUND_ROBIN + OPS + Weighted (1:2:3) -// -// Each VS has 3 real servers for comprehensive distribution testing. - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/logging" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "github.com/yanet-platform/yanet2/tests/functional/framework" - "go.uber.org/zap/zapcore" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Virtual Service configurations for different test scenarios -var ( - // VS1: TCP + SOURCE_HASH (session-based) - vs1IP = netip.MustParseAddr("10.0.1.1") - vs1Port = uint16(80) - - // VS2: UDP + SOURCE_HASH (session-based) - vs2IP = netip.MustParseAddr("10.0.2.1") - vs2Port = uint16(80) - - // VS3: TCP + ROUND_ROBIN (session-based) - vs3IP = netip.MustParseAddr("10.0.3.1") - vs3Port = uint16(80) - - // VS4: UDP + ROUND_ROBIN (session-based) - vs4IP = netip.MustParseAddr("10.0.4.1") - vs4Port = uint16(80) - - // VS5: TCP + SOURCE_HASH + OPS (no session) - vs5IP = netip.MustParseAddr("10.0.5.1") - vs5Port = uint16(80) - - // VS6: TCP + ROUND_ROBIN + OPS (no session) - vs6IP = netip.MustParseAddr("10.0.6.1") - vs6Port = uint16(80) - - // VS7: TCP + SOURCE_HASH + OPS + Weighted (1:2:3) - vs7IP = netip.MustParseAddr("10.0.7.1") - vs7Port = uint16(80) - - // VS8: TCP + ROUND_ROBIN + OPS + Weighted (1:2:3) - vs8IP = netip.MustParseAddr("10.0.8.1") - vs8Port = uint16(80) - - // VS9: TCP + SOURCE_HASH + PureL3 (port must be 0 for PureL3) - vs9IP = netip.MustParseAddr("10.0.9.1") - vs9Port = uint16(0) - - // VS10: TCP + ROUND_ROBIN + PureL3 (port must be 0 for PureL3) - vs10IP = netip.MustParseAddr("10.0.10.1") - vs10Port = uint16(0) - - // VS11: UDP + SOURCE_HASH + PureL3 (port must be 0 for PureL3) - vs11IP = netip.MustParseAddr("10.0.11.1") - vs11Port = uint16(0) - - // VS12: UDP + ROUND_ROBIN + PureL3 (port must be 0 for PureL3) - vs12IP = netip.MustParseAddr("10.0.12.1") - vs12Port = uint16(0) - - // Real servers (3 per VS, same IPs for simplicity) - real1IP = netip.MustParseAddr("192.168.1.1") - real2IP = netip.MustParseAddr("192.168.1.2") - real3IP = netip.MustParseAddr("192.168.1.3") - - // Client base IP - clientBaseIP = netip.MustParseAddr("3.3.3.1") - - // Source address for balancer - balancerSrcV4 = netip.MustParseAddr("5.5.5.5") - balancerSrcV6 = netip.MustParseAddr("fe80::5") -) - -// createReal creates a Real configuration -func createReal(ip netip.Addr, weight uint32) *balancerpb.Real { - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: 0, - }, - Weight: weight, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255").AsSlice(), - }, - } -} - -// createVirtualService creates a VirtualService configuration -func createVirtualService( - ip netip.Addr, - port uint16, - proto balancerpb.TransportProto, - scheduler balancerpb.VsScheduler, - ops bool, - reals []*balancerpb.Real, -) *balancerpb.VirtualService { - return createVirtualServiceWithFlags( - ip, - port, - proto, - scheduler, - ops, - false, - reals, - ) -} - -// createVirtualServiceWithFlags creates a VirtualService configuration with custom flags -func createVirtualServiceWithFlags( - ip netip.Addr, - port uint16, - proto balancerpb.TransportProto, - scheduler balancerpb.VsScheduler, - ops bool, - pureL3 bool, - reals []*balancerpb.Real, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: proto, - }, - Scheduler: scheduler, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: ops, - PureL3: pureL3, - Wlc: false, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - } -} - -// createSchedulingTestConfig creates the balancer configuration with all 8 virtual services -func createSchedulingTestConfig() *balancerpb.BalancerConfig { - // Equal weight reals (weight = 1) - equalReals := []*balancerpb.Real{ - createReal(real1IP, 1), - createReal(real2IP, 1), - createReal(real3IP, 1), - } - - // Weighted reals (1:2:3) - weightedReals := []*balancerpb.Real{ - createReal(real1IP, 1), - createReal(real2IP, 2), - createReal(real3IP, 3), - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{Bytes: balancerSrcV4.AsSlice()}, - SourceAddressV6: &balancerpb.Addr{Bytes: balancerSrcV6.AsSlice()}, - Vs: []*balancerpb.VirtualService{ - // VS1: TCP + SOURCE_HASH (session-based) - createVirtualService( - vs1IP, - vs1Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_SOURCE_HASH, - false, - equalReals, - ), - // VS2: UDP + SOURCE_HASH (session-based) - createVirtualService( - vs2IP, - vs2Port, - balancerpb.TransportProto_UDP, - balancerpb.VsScheduler_SOURCE_HASH, - false, - equalReals, - ), - // VS3: TCP + ROUND_ROBIN (session-based) - createVirtualService( - vs3IP, - vs3Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_ROUND_ROBIN, - false, - equalReals, - ), - // VS4: UDP + ROUND_ROBIN (session-based) - createVirtualService( - vs4IP, - vs4Port, - balancerpb.TransportProto_UDP, - balancerpb.VsScheduler_ROUND_ROBIN, - false, - equalReals, - ), - // VS5: TCP + SOURCE_HASH + OPS (no session) - createVirtualService( - vs5IP, - vs5Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_SOURCE_HASH, - true, - equalReals, - ), - // VS6: TCP + ROUND_ROBIN + OPS (no session) - createVirtualService( - vs6IP, - vs6Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_ROUND_ROBIN, - true, - equalReals, - ), - // VS7: TCP + SOURCE_HASH + OPS + Weighted - createVirtualService( - vs7IP, - vs7Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_SOURCE_HASH, - true, - weightedReals, - ), - // VS8: TCP + ROUND_ROBIN + OPS + Weighted - createVirtualService( - vs8IP, - vs8Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_ROUND_ROBIN, - true, - weightedReals, - ), - // VS9: TCP + SOURCE_HASH + PureL3 - createVirtualServiceWithFlags( - vs9IP, - vs9Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_SOURCE_HASH, - false, - true, - equalReals, - ), - // VS10: TCP + ROUND_ROBIN + PureL3 - createVirtualServiceWithFlags( - vs10IP, - vs10Port, - balancerpb.TransportProto_TCP, - balancerpb.VsScheduler_ROUND_ROBIN, - false, - true, - equalReals, - ), - // VS11: UDP + SOURCE_HASH + PureL3 - createVirtualServiceWithFlags( - vs11IP, - vs11Port, - balancerpb.TransportProto_UDP, - balancerpb.VsScheduler_SOURCE_HASH, - false, - true, - equalReals, - ), - // VS12: UDP + ROUND_ROBIN + PureL3 - createVirtualServiceWithFlags( - vs12IP, - vs12Port, - balancerpb.TransportProto_UDP, - balancerpb.VsScheduler_ROUND_ROBIN, - false, - true, - equalReals, - ), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// generateClientIP generates a unique client IP based on index -func generateClientIP(index int) netip.Addr { - // Start from 3.3.3.1 and increment - base := clientBaseIP.As4() - base[3] = byte((int(base[3]) + index) % 256) - if index >= 256 { - base[2] = byte((int(base[2]) + index/256) % 256) - } - return netip.AddrFrom4(base) -} - -// TestScheduling is the main test function that tests all scheduling scenarios -func TestScheduling(t *testing.T) { - // Create balancer configuration with all 8 virtual services - config := createSchedulingTestConfig() - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(128*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 64 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Run all scheduling checks - t.Run("InitialChecks", func(t *testing.T) { - runSchedulingChecks(t, ts, "initial") - }) - - // State restoration test: create new balancer agent and verify it contains the previous balancer - t.Run("StateRestoration", func(t *testing.T) { - testStateRestoration(t, ts) - }) -} - -// runSchedulingChecks runs all scheduling-related subtests -func runSchedulingChecks(t *testing.T, ts *utils.TestSetup, phase string) { - t.Helper() - - // TCP Session Establishment (VS1) - t.Run("TCP_SessionEstablishment", func(t *testing.T) { - testTCPSessionEstablishment(t, ts) - }) - - // UDP Session Establishment (VS2) - t.Run("UDP_SessionEstablishment", func(t *testing.T) { - testUDPSessionEstablishment(t, ts) - }) - - // OPS Mode - No Session Created (VS5) - t.Run("OPS_NoSessionCreated", func(t *testing.T) { - testOPSNoSessionCreated(t, ts) - }) - - // Source Hash - Same Client Same Real (VS1, VS5) - t.Run("SourceHash_SameClientSameReal", func(t *testing.T) { - testSourceHashSameClientSameReal(t, ts) - }) - - // Source Hash with OPS - Same Client Same Real (VS5) - t.Run("SourceHash_OPS_SameClientSameReal", func(t *testing.T) { - testSourceHashOPSSameClientSameReal(t, ts) - }) - - // Round Robin Distribution (VS6 - OPS mode for independent scheduling) - t.Run("RoundRobin_Distribution", func(t *testing.T) { - testRoundRobinDistribution(t, ts) - }) - - // Weight Distribution - Source Hash (VS7) - t.Run("WeightDistribution_SourceHash", func(t *testing.T) { - testWeightDistributionSourceHash(t, ts) - }) - - // Weight Distribution - Round Robin (VS8) - t.Run("WeightDistribution_RoundRobin", func(t *testing.T) { - testWeightDistributionRoundRobin(t, ts) - }) - - // Weight Distribution After Update (VS7, VS8) - t.Run("WeightDistribution_AfterUpdate", func(t *testing.T) { - testWeightDistributionAfterUpdate(t, ts) - }) - - // Disabled Reals - No New Packets (VS1) - t.Run("DisabledReals_NoNewPackets", func(t *testing.T) { - testDisabledRealsNoNewPackets(t, ts) - }) - - // API Output Tests - t.Run("Config_Output", func(t *testing.T) { - testConfigOutput(t, ts) - }) - - t.Run("Info_Output", func(t *testing.T) { - testInfoOutput(t, ts) - }) - - t.Run("Stats_Output", func(t *testing.T) { - testStatsOutput(t, ts) - }) - - t.Run("Graph_Output", func(t *testing.T) { - testGraphOutput(t, ts) - }) - - // PureL3 Tests - t.Run("PureL3_SourceHash_PortIndependence", func(t *testing.T) { - testPureL3SourceHashPortIndependence(t, ts) - }) - - t.Run("PureL3_RoundRobin_Distribution", func(t *testing.T) { - testPureL3RoundRobinDistribution(t, ts) - }) - - t.Run("PureL3_SessionCreation", func(t *testing.T) { - testPureL3SessionCreation(t, ts) - }) - - t.Run("PureL3_UDP_SourceHash", func(t *testing.T) { - testPureL3UDPSourceHash(t, ts) - }) -} - -// testTCPSessionEstablishment verifies that TCP sessions are established -// and packets from the same client go to the same real -func testTCPSessionEstablishment(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - clientIP := generateClientIP(100) - clientPort := uint16(10000) - - // Send multiple TCP packets from the same client - var outputPackets []*framework.PacketInfo - for i := range 5 { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs1IP, - vs1Port, - &layers.TCP{SYN: i == 0, ACK: i > 0}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real - realIP, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True( - t, - allSame, - "all packets from same client should go to same real", - ) - assert.True(t, realIP.IsValid(), "real IP should be valid") - - // Verify session was created - info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - assert.GreaterOrEqual( - t, - info.ActiveSessions, - uint64(1), - "should have at least one active session", - ) -} - -// testUDPSessionEstablishment verifies that UDP sessions are established -// and packets from the same client go to the same real -func testUDPSessionEstablishment(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - clientIP := generateClientIP(200) - clientPort := uint16(20000) - - // Send multiple UDP packets from the same client - var outputPackets []*framework.PacketInfo - for range 5 { - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - vs2IP, - vs2Port, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - require.Empty(t, result.Drop, "expected no dropped packets") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real - realIP, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True( - t, - allSame, - "all packets from same client should go to same real", - ) - assert.True(t, realIP.IsValid(), "real IP should be valid") -} - -// testOPSNoSessionCreated verifies that OPS mode does not create sessions -func testOPSNoSessionCreated(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Get initial session count - initialInfo, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - initialSessions := initialInfo.ActiveSessions - - // Send packets to VS5 (OPS mode) from a new client - clientIP := generateClientIP(300) - clientPort := uint16(30000) - - for i := 0; i < 5; i++ { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs5IP, - vs5Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - } - - // Verify no new sessions were created - finalInfo, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - assert.Equal( - t, - initialSessions, - finalInfo.ActiveSessions, - "OPS mode should not create new sessions", - ) -} - -// testSourceHashSameClientSameReal verifies that source_hash schedules -// packets from the same client to the same real -func testSourceHashSameClientSameReal(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Test with VS1 (TCP + SOURCE_HASH, session-based) - clientIP := generateClientIP(400) - clientPort := uint16(40000) - - var outputPackets []*framework.PacketInfo - for i := 0; i < 10; i++ { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs1IP, - vs1Port, - &layers.TCP{SYN: i == 0, ACK: i > 0}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real - _, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True( - t, - allSame, - "source_hash should send all packets from same client to same real", - ) -} - -// testSourceHashOPSSameClientSameReal verifies that source_hash with OPS -// still schedules based on hash (same IP+port -> same real) -func testSourceHashOPSSameClientSameReal(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Test with VS5 (TCP + SOURCE_HASH + OPS) - clientIP := generateClientIP(500) - clientPort := uint16(50000) - - var outputPackets []*framework.PacketInfo - for range 10 { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs5IP, - vs5Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real (hash-based, not session-based) - _, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True( - t, - allSame, - "source_hash with OPS should send all packets from same client to same real based on hash", - ) -} - -// testRoundRobinDistribution verifies that round_robin distributes packets across reals -func testRoundRobinDistribution(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Test with VS6 (TCP + ROUND_ROBIN + OPS) - OPS mode for independent scheduling - var outputPackets []*framework.PacketInfo - - // Send packets from different clients to trigger round-robin - for i := range 30 { - clientIP := generateClientIP(600 + i) - clientPort := uint16(60000 + i) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs6IP, - vs6Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify packets are distributed across multiple reals - distributed := utils.PacketsDistributedAcrossReals(outputPackets) - assert.True( - t, - distributed, - "round_robin should distribute packets across multiple reals", - ) - - // Count packets per real - counts := utils.CountPacketsPerReal(outputPackets) - assert.GreaterOrEqual( - t, - len(counts), - 2, - "packets should go to at least 2 different reals", - ) -} - -// testWeightDistributionSourceHash verifies that packets are distributed -// proportionally to weights with source_hash scheduler -func testWeightDistributionSourceHash(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Test with VS7 (TCP + SOURCE_HASH + OPS + Weighted 1:2:3) - var outputPackets []*framework.PacketInfo - - // Send many packets from different clients - for i := range 600 { - clientIP := generateClientIP(700 + i) - clientPort := uint16(1000 + (i % 60000)) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs7IP, - vs7Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify weight distribution (1:2:3 = ~16.7%:~33.3%:~50%) - counts := utils.CountPacketsPerReal(outputPackets) - expectedWeights := map[netip.Addr]uint32{ - real1IP: 1, - real2IP: 2, - real3IP: 3, - } - - // Use 15% tolerance for statistical variance - utils.ValidateWeightDistribution(t, counts, expectedWeights, 0.15) -} - -// testWeightDistributionRoundRobin verifies that packets are distributed -// proportionally to weights with round_robin scheduler -func testWeightDistributionRoundRobin(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Test with VS8 (TCP + ROUND_ROBIN + OPS + Weighted 1:2:3) - var outputPackets []*framework.PacketInfo - - // Send many packets from different clients - for i := 0; i < 600; i++ { - clientIP := generateClientIP(800 + i) - clientPort := uint16(2000 + (i % 60000)) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs8IP, - vs8Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify weight distribution (1:2:3 = ~16.7%:~33.3%:~50%) - counts := utils.CountPacketsPerReal(outputPackets) - expectedWeights := map[netip.Addr]uint32{ - real1IP: 1, - real2IP: 2, - real3IP: 3, - } - - // Use 15% tolerance for statistical variance - utils.ValidateWeightDistribution(t, counts, expectedWeights, 0.15) -} - -// testWeightDistributionAfterUpdate verifies that weight distribution -// holds after real weight update -func testWeightDistributionAfterUpdate(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Update weights for VS7 to 3:2:1 (reverse of original 1:2:3) - config := ts.Balancer.Config() - var vs7 *balancerpb.VirtualService - for _, vs := range config.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if vsAddr == vs7IP { - vs7 = vs - break - } - } - require.NotNil(t, vs7, "VS7 should exist") - - // Update weights - newWeight1 := uint32(3) - newWeight2 := uint32(2) - newWeight3 := uint32(1) - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs7.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1IP.AsSlice()}, - Port: 0, - }, - }, - Weight: &newWeight1, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs7.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real2IP.AsSlice()}, - Port: 0, - }, - }, - Weight: &newWeight2, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs7.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real3IP.AsSlice()}, - Port: 0, - }, - }, - Weight: &newWeight3, - }, - } - - _, err := ts.Balancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to update real weights") - - // Send packets and verify new distribution - var outputPackets []*framework.PacketInfo - for i := range 600 { - clientIP := generateClientIP(900 + i) - clientPort := uint16(3000 + (i % 60000)) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs7IP, - vs7Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify new weight distribution (3:2:1 = ~50%:~33.3%:~16.7%) - counts := utils.CountPacketsPerReal(outputPackets) - expectedWeights := map[netip.Addr]uint32{ - real1IP: 3, - real2IP: 2, - real3IP: 1, - } - - // Use 15% tolerance for statistical variance - utils.ValidateWeightDistribution(t, counts, expectedWeights, 0.15) -} - -// testDisabledRealsNoNewPackets verifies that disabled reals do not accept new packets -func testDisabledRealsNoNewPackets(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Get VS1 configuration - config := ts.Balancer.Config() - var vs1 *balancerpb.VirtualService - for _, vs := range config.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if vsAddr == vs1IP { - vs1 = vs - break - } - } - require.NotNil(t, vs1, "VS1 should exist") - - // Disable real1 - enableFalse := false - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs1.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - } - - _, err := ts.Balancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to disable real") - - // Send packets from new clients (use different client IPs to avoid existing sessions) - var outputPackets []*framework.PacketInfo - for i := 0; i < 100; i++ { - clientIP := generateClientIP(1000 + i) - clientPort := uint16(4000 + i) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs1IP, - vs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify no packets went to the disabled real - counts := utils.CountPacketsPerReal(outputPackets) - disabledRealCount := counts[real1IP] - assert.Equal( - t, - 0, - disabledRealCount, - "disabled real should not receive any new packets", - ) - - // Re-enable real1 for subsequent tests - enableTrue := true - reEnableUpdates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: vs1.Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1IP.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - } - _, err = ts.Balancer.UpdateReals(reEnableUpdates, false) - require.NoError(t, err, "failed to re-enable real") -} - -// testConfigOutput verifies the Config() API output -func testConfigOutput(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - config := ts.Balancer.Config() - require.NotNil(t, config, "config should not be nil") - require.NotNil( - t, - config.PacketHandler, - "packet handler config should not be nil", - ) - - // Verify virtual services - assert.Equal( - t, - 12, - len(config.PacketHandler.Vs), - "should have 12 virtual services", - ) - - // Verify state config - require.NotNil(t, config.State, "state config should not be nil") - assert.NotNil( - t, - config.State.SessionTableCapacity, - "session table capacity should be set", - ) - assert.NotNil( - t, - config.State.SessionTableMaxLoadFactor, - "max load factor should be set", - ) - - // Verify sessions timeouts - require.NotNil( - t, - config.PacketHandler.SessionsTimeouts, - "sessions timeouts should not be nil", - ) - assert.Equal( - t, - uint32(60), - config.PacketHandler.SessionsTimeouts.Tcp, - "TCP timeout should be 60", - ) - assert.Equal( - t, - uint32(60), - config.PacketHandler.SessionsTimeouts.Udp, - "UDP timeout should be 60", - ) - - t.Logf("Config verified: %d virtual services, table_capacity=%d", - len(config.PacketHandler.Vs), *config.State.SessionTableCapacity) -} - -// testInfoOutput verifies the Info() API output -func testInfoOutput(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - info, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to get balancer info") - require.NotNil(t, info, "info should not be nil") - - // Verify info fields - assert.GreaterOrEqual( - t, - info.ActiveSessions, - uint64(0), - "active sessions should be non-negative", - ) - require.NotNil(t, info.Vs, "virtual services info should not be nil") - - // Verify VS info - for i, vsInfo := range info.Vs { - require.NotNil(t, vsInfo, "VS info %d should not be nil", i) - assert.GreaterOrEqual( - t, - vsInfo.ActiveSessions, - uint64(0), - "VS %d active sessions should be non-negative", - i, - ) - } - - t.Logf("Info verified: active_sessions=%d, vs_count=%d", - info.ActiveSessions, len(info.Vs)) -} - -// testStatsOutput verifies the Stats() API output -func testStatsOutput(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - - stats, err := ts.Balancer.Stats(statsRef) - require.NoError(t, err, "failed to get balancer stats") - require.NotNil(t, stats, "stats should not be nil") - - // Verify common stats - require.NotNil(t, stats.Common, "common stats should not be nil") - assert.GreaterOrEqual( - t, - stats.Common.IncomingPackets, - uint64(0), - "incoming packets should be non-negative", - ) - assert.GreaterOrEqual( - t, - stats.Common.IncomingBytes, - uint64(0), - "incoming bytes should be non-negative", - ) - assert.GreaterOrEqual( - t, - stats.Common.OutgoingPackets, - uint64(0), - "outgoing packets should be non-negative", - ) - assert.GreaterOrEqual( - t, - stats.Common.OutgoingBytes, - uint64(0), - "outgoing bytes should be non-negative", - ) - - // Verify L4 stats - require.NotNil(t, stats.L4, "L4 stats should not be nil") - assert.GreaterOrEqual( - t, - stats.L4.IncomingPackets, - uint64(0), - "L4 incoming packets should be non-negative", - ) - assert.GreaterOrEqual( - t, - stats.L4.OutgoingPackets, - uint64(0), - "L4 outgoing packets should be non-negative", - ) - - // Verify ICMP stats - require.NotNil(t, stats.Icmpv4, "ICMPv4 stats should not be nil") - require.NotNil(t, stats.Icmpv6, "ICMPv6 stats should not be nil") - - t.Logf( - "Stats verified: incoming_packets=%d, incoming_bytes=%d, outgoing_packets=%d, outgoing_bytes=%d", - stats.Common.IncomingPackets, - stats.Common.IncomingBytes, - stats.Common.OutgoingPackets, - stats.Common.OutgoingBytes, - ) -} - -// testGraphOutput verifies the Graph() API output -func testGraphOutput(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - graph := ts.Balancer.Graph() - require.NotNil(t, graph, "graph should not be nil") - require.NotNil( - t, - graph.VirtualServices, - "virtual services should not be nil", - ) - - // Verify number of virtual services - assert.Equal( - t, - 12, - len(graph.VirtualServices), - "should have 12 virtual services in graph", - ) - - // Verify each virtual service - for i, vs := range graph.VirtualServices { - require.NotNil(t, vs, "virtual service %d should not be nil", i) - require.NotNil(t, vs.Identifier, "VS %d should have identifier", i) - require.NotNil(t, vs.Reals, "VS %d should have reals", i) - assert.Equal(t, 3, len(vs.Reals), "VS %d should have 3 reals", i) - - // Verify each real - for j, real := range vs.Reals { - require.NotNil(t, real, "real %d of VS %d should not be nil", j, i) - require.NotNil( - t, - real.Identifier, - "real %d of VS %d should have identifier", - j, - i, - ) - assert.GreaterOrEqual( - t, - real.Weight, - uint32(0), - "real %d of VS %d weight should be non-negative", - j, - i, - ) - assert.GreaterOrEqual( - t, - real.EffectiveWeight, - uint32(0), - "real %d of VS %d effective weight should be non-negative", - j, - i, - ) - } - } - - t.Logf( - "Graph verified: %d virtual services with 3 reals each", - len(graph.VirtualServices), - ) -} - -// testStateRestoration verifies that creating a new balancer agent -// restores the previous balancer state from shared memory -func testStateRestoration(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Create a logger for the new agent - logLevel := zapcore.InfoLevel - logger, _, err := logging.Init(&logging.Config{ - Level: logLevel, - }) - require.NoError(t, err, "failed to create logger") - - // Create a new balancer agent using the same shared memory - // This should restore the existing balancer from shared memory - newAgent, err := balancer.NewBalancerAgent( - ts.Mock.SharedMemory(), - 4*datasize.MB, - logger, - ) - require.NoError(t, err, "failed to create new balancer agent") - - // Verify the balancer exists in the new agent - managers := newAgent.Managers() - assert.Contains( - t, - managers, - utils.BalancerName, - "new agent should contain the existing balancer", - ) - - // Get the balancer manager from the new agent - newBalancer, err := newAgent.BalancerManager(utils.BalancerName) - require.NoError(t, err, "failed to get balancer from new agent") - - // Update test setup to use new agent and balancer - ts.Agent = newAgent - ts.Balancer = newBalancer - - // Run all scheduling checks again to verify they work after state restoration - t.Run("AfterRestoreChecks", func(t *testing.T) { - // Re-enable all reals (in case any were disabled) - utils.EnableAllReals(t, ts) - - // Run a subset of checks to verify state restoration - t.Run("TCP_SessionEstablishment", func(t *testing.T) { - testTCPSessionEstablishmentAfterRestore(t, ts) - }) - - t.Run("SourceHash_Consistency", func(t *testing.T) { - testSourceHashConsistencyAfterRestore(t, ts) - }) - - t.Run("RoundRobin_Distribution", func(t *testing.T) { - testRoundRobinDistributionAfterRestore(t, ts) - }) - }) -} - -// testTCPSessionEstablishmentAfterRestore verifies TCP session establishment after state restoration -func testTCPSessionEstablishmentAfterRestore( - t *testing.T, - ts *utils.TestSetup, -) { - t.Helper() - - clientIP := generateClientIP(1100) - clientPort := uint16(11000) - - var outputPackets []*framework.PacketInfo - for i := range 5 { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs1IP, - vs1Port, - &layers.TCP{SYN: i == 0, ACK: i > 0}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real - _, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True( - t, - allSame, - "all packets from same client should go to same real after restore", - ) -} - -// testSourceHashConsistencyAfterRestore verifies source hash consistency after state restoration -func testSourceHashConsistencyAfterRestore(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - clientIP := generateClientIP(1200) - clientPort := uint16(12000) - - var outputPackets []*framework.PacketInfo - for range 10 { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs5IP, - vs5Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify all packets went to the same real (hash-based) - _, allSame := utils.AllPacketsToSameReal(outputPackets) - assert.True(t, allSame, "source_hash should be consistent after restore") -} - -// testRoundRobinDistributionAfterRestore verifies round robin distribution after state restoration -func testRoundRobinDistributionAfterRestore(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - var outputPackets []*framework.PacketInfo - for i := range 30 { - clientIP := generateClientIP(1300 + i) - clientPort := uint16(13000 + i) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs6IP, - vs6Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify packets are distributed across multiple reals - distributed := utils.PacketsDistributedAcrossReals(outputPackets) - assert.True( - t, - distributed, - "round_robin should distribute packets after restore", - ) -} - -// testPureL3SourceHashPortIndependence verifies that PureL3 mode accepts packets -// on any destination port and schedules based on destination port -func testPureL3SourceHashPortIndependence(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - clientIP := generateClientIP(1400) - clientPort := uint16(14000) - - // Test 1: Packets to the same destination port should go to the same real - var sameDstPortPackets []*framework.PacketInfo - dstPort1 := uint16(8080) - for i := 0; i < 5; i++ { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs9IP, - dstPort1, // Same destination port - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - sameDstPortPackets = append(sameDstPortPackets, result.Output[0]) - } - - // Verify all packets to same dst port went to the same real - realIP1, allSame := utils.AllPacketsToSameReal(sameDstPortPackets) - assert.True( - t, - allSame, - "PureL3 SOURCE_HASH should send all packets to same dst port to same real", - ) - assert.True(t, realIP1.IsValid(), "real IP should be valid") - - // Test 2: Packets to different destination ports can go to different reals - var differentDstPortPackets []*framework.PacketInfo - for i := 0; i < 10; i++ { - dstPort := uint16(9000 + i) // Different destination ports - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs9IP, - dstPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - differentDstPortPackets = append( - differentDstPortPackets, - result.Output[0], - ) - } - - // Verify packets to different dst ports can be distributed - counts := utils.CountPacketsPerReal(differentDstPortPackets) - t.Logf("PureL3 distribution across different dst ports: %v", counts) - // We don't assert distribution here as it depends on hash function, - // but we verify that packets were accepted on different ports -} - -// testPureL3RoundRobinDistribution verifies that PureL3 mode with ROUND_ROBIN -// distributes packets across reals -func testPureL3RoundRobinDistribution(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - var outputPackets []*framework.PacketInfo - - // Send packets from different client IPs - for i := range 30 { - clientIP := generateClientIP(1500 + i) - clientPort := uint16(15000) - - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs10IP, - vs10Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - outputPackets = append(outputPackets, result.Output[0]) - } - - // Verify packets are distributed across multiple reals - distributed := utils.PacketsDistributedAcrossReals(outputPackets) - assert.True( - t, - distributed, - "PureL3 ROUND_ROBIN should distribute packets across multiple reals", - ) - - // Count packets per real - counts := utils.CountPacketsPerReal(outputPackets) - assert.GreaterOrEqual( - t, - len(counts), - 2, - "packets should go to at least 2 different reals", - ) -} - -// testPureL3SessionCreation verifies that sessions are created correctly in PureL3 mode -// Sessions should be based on client IP + client port + dst port combination -func testPureL3SessionCreation(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - // Get initial session count - initialInfo, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - initialSessions := initialInfo.ActiveSessions - - clientIP := generateClientIP(1600) - clientPort := uint16(16000) - dstPort := uint16(8080) - - // Send multiple packets with same client IP, client port, and dst port - for i := 0; i < 5; i++ { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vs9IP, - dstPort, // Same dst port - &layers.TCP{SYN: i == 0, ACK: i > 0}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - } - - // Verify sessions were created - finalInfo, err := ts.Balancer.Info(ts.Mock.CurrentTime()) - require.NoError(t, err) - - // In PureL3 mode, sessions should be created based on the flow - assert.Greater( - t, - finalInfo.ActiveSessions, - initialSessions, - "PureL3 mode should create sessions", - ) -} - -// testPureL3UDPSourceHash verifies that PureL3 mode works with UDP and SOURCE_HASH -// Packets to the same destination port should go to the same real -func testPureL3UDPSourceHash(t *testing.T, ts *utils.TestSetup) { - t.Helper() - - clientIP := generateClientIP(1700) - clientPort := uint16(17000) - - // Test 1: Packets to the same destination port should go to the same real - var sameDstPortPackets []*framework.PacketInfo - dstPort1 := uint16(5353) - for i := 0; i < 5; i++ { - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - vs11IP, - dstPort1, // Same destination port - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - sameDstPortPackets = append(sameDstPortPackets, result.Output[0]) - } - - // Verify all packets to same dst port went to the same real - realIP, allSame := utils.AllPacketsToSameReal(sameDstPortPackets) - assert.True( - t, - allSame, - "PureL3 SOURCE_HASH with UDP should send all packets to same dst port to same real", - ) - assert.True(t, realIP.IsValid(), "real IP should be valid") - - // Test 2: Verify PureL3 accepts packets on different destination ports - var differentDstPortPackets []*framework.PacketInfo - for i := 0; i < 10; i++ { - dstPort := uint16(6000 + i) // Different destination ports - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - vs11IP, - dstPort, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output), "expected 1 output packet") - differentDstPortPackets = append( - differentDstPortPackets, - result.Output[0], - ) - } - - // Verify packets were accepted on different ports - counts := utils.CountPacketsPerReal(differentDstPortPackets) - t.Logf("PureL3 UDP distribution across different dst ports: %v", counts) -} diff --git a/modules/balancer/tests/go/st_basic_test.go b/modules/balancer/tests/go/st_basic_test.go deleted file mode 100644 index a9a37fb74..000000000 --- a/modules/balancer/tests/go/st_basic_test.go +++ /dev/null @@ -1,1656 +0,0 @@ -package balancer_test - -// TestSessionTableManual validates session table behavior with comprehensive testing of: -// -// # Session Creation and Persistence -// - Creating multiple sessions with random client IPs and ports -// - Verifying sessions remain active when refreshed within timeout period -// - Testing session table resizing with active sessions -// -// # Session Timeout Validation -// - Sending packets to existing sessions to refresh their timeout -// - Advancing time and verifying sessions expire after configured timeout -// - Testing that new sessions can be created alongside existing ones -// -// # Load Testing and Capacity Management -// - Creating 256 sessions and verifying at least 80% acceptance rate -// - Testing dynamic session table resizing (256 → 1024 → 300 capacity) -// - Verifying session persistence across table resize operations -// - Validating expired sessions are properly removed while active ones persist -// -// TestSessionTimeouts validates different session timeout types work correctly: -// -// # Configuration -// - Two virtual services: TCP (1.1.1.1:80) and UDP (2.2.2.2:53) -// - Different timeout values: UDP=30s, TCP=60s, TCP_SYN=20s, TCP_SYN_ACK=25s, TCP_FIN=15s -// -// # Timeout Validation Tests -// - UDP Session Timeout: Verifies UDP sessions expire at 30 seconds -// - TCP SYN Timeout: Verifies TCP SYN sessions expire at 20 seconds -// - TCP SYN-ACK Timeout: Verifies timeout switches to 25s after SYN-ACK packet -// - TCP Basic Timeout: Verifies timeout switches from SYN (20s) to TCP (60s) after regular packet -// - TCP FIN Timeout: Verifies timeout switches to 15s after FIN packet -// -// # Validation Pattern -// Each test follows the pattern: -// - Send packet(s) to create session -// - Verify session exists -// - Advance time by (timeout - 1) seconds -// - Verify session still persists -// - Advance time by 1 second (reaching exact timeout) -// - Verify session has expired - -import ( - "math/rand" - "net/netip" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -// sessionKey represents a unique session identifier -type sessionKey struct { - ip netip.Addr - port uint16 -} - -// checkActiveSessions verifies active sessions match expected sessions -func checkActiveSessions( - t *testing.T, - ts *utils.TestSetup, - currentTime time.Time, - expectedSessions []sessionKey, - vsIP netip.Addr, - vsPort uint16, - realAddr netip.Addr, -) { - t.Helper() - - expectedCount := uint64(len(expectedSessions)) - - // Get sessions info - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - assert.Equal( - t, - int(expectedCount), - len(sessions), - "sessions list should have %d entries", - expectedCount, - ) - - // Create a map of expected sessions for validation - expectedSessionsMap := make(map[sessionKey]bool) - for _, session := range expectedSessions { - expectedSessionsMap[session] = true - } - - // Verify each session has correct properties - for i, session := range sessions { - // Verify VS identifier - vsAddr, _ := netip.AddrFromSlice(session.RealId.Vs.Addr.Bytes) - assert.Equal( - t, - vsIP, - vsAddr, - "session %d: VS IP should match", - i, - ) - assert.Equal( - t, - uint32(vsPort), - session.RealId.Vs.Port, - "session %d: VS port should match", - i, - ) - - // Verify Real identifier - realIP, _ := netip.AddrFromSlice(session.RealId.Real.Ip.Bytes) - assert.Equal( - t, - realAddr, - realIP, - "session %d: Real IP should match", - i, - ) - - // Verify client IP and port match one of our expected sessions - clientAddr, _ := netip.AddrFromSlice(session.ClientAddr.Bytes) - clientKey := sessionKey{ - ip: clientAddr, - port: uint16(session.ClientPort), - } - assert.True( - t, - expectedSessionsMap[clientKey], - "session %d: client %v:%d should be in expected sessions", - i, - clientAddr, - session.ClientPort, - ) - - // Delete session to not match same session twice. - // In this way, we check sessions are unique - delete(expectedSessionsMap, clientKey) - } - - assert.Empty( - t, - expectedSessionsMap, - "%d expected sessions not found", - len(expectedSessionsMap), - ) - - // Get info to verify VS and Real active sessions - info, err := ts.Balancer.Info(currentTime) - require.NoError(t, err) - - // Verify module active sessions - assert.Equal( - t, - expectedCount, - info.ActiveSessions, - "module should have %d active sessions", - expectedCount, - ) - - // Verify VS active sessions - require.Equal(t, 1, len(info.Vs), "should have exactly one VS") - assert.Equal( - t, - expectedCount, - info.Vs[0].ActiveSessions, - "VS should have %d active sessions", - expectedCount, - ) - - // Verify Real active sessions - require.Equal(t, 1, len(info.Vs[0].Reals), "should have exactly one Real") - assert.Equal( - t, - expectedCount, - info.Vs[0].Reals[0].ActiveSessions, - "Real should have %d active sessions", - expectedCount, - ) -} - -// TestSessionTableManual tests session table behavior with timeouts and resizing. -// It creates sessions, verifies they stay active after resizing or when refreshed -// within timeout, and tests that new sessions can be created alongside existing ones. -func TestSessionTableManual(t *testing.T) { - vsIP := netip.MustParseAddr("1.1.1.1") - vsPort := uint16(80) - realAddr := netip.MustParseAddr("2.2.2.2") - - sessionTimeout := 60 // in seconds - initialCapacity := 16 - maxLoadFactor := 0.5 - - // Configure balancer with single VS and single real - moduleConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIP.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2.2.2.2").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: uint32(sessionTimeout), - TcpSyn: uint32(sessionTimeout), - TcpFin: uint32(sessionTimeout), - Tcp: uint32(sessionTimeout), - Udp: uint32(sessionTimeout), - Default: uint32(sessionTimeout), - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(initialCapacity); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(maxLoadFactor); return &v }(), - RefreshPeriod: durationpb.New( - 0, - ), // do not update in background - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(128*datasize.MB, 4*datasize.MB), - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 32 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - mock := ts.Mock - - // Set initial time - mock.SetCurrentTime(time.Unix(0, 0)) - - rng := rand.New(rand.NewSource(42)) - - // Helper to generate random client IP in 10.x.x.x range - randomClientIP := func() netip.Addr { - return netip.AddrFrom4([4]byte{ - 10, - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - }) - } - - // Helper to generate random port - randomPort := func() uint16 { - return uint16(1024 + rng.Intn(64511)) // 1024-65535 - } - - // Track session keys (srcIP, srcPort) - sessions := make([]sessionKey, 0, 10) - - // Phase 1: Create 10 random sessions with TCP SYN packets - t.Run("Phase1_Create_10_Sessions", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, 10) - for range 10 { - srcIP := randomClientIP() - srcPort := randomPort() - sessions = append(sessions, sessionKey{ip: srcIP, port: srcPort}) - - packetLayers := utils.MakeTCPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify all packets are in output - assert.Equal( - t, - 10, - len(result.Output), - "all 10 packets should be in output", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Verify each output packet is properly encapsulated - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } - - t.Logf( - "Created 10 sessions successfully, all packets properly encapsulated", - ) - }) - - // Advance time by 30 seconds - t.Run("Advance_Time_30s", func(t *testing.T) { - newTime := mock.AdvanceTime(30 * time.Second) - t.Logf("Advanced time to %v (30s elapsed)", newTime) - }) - - // Phase 2: Send TCP non-SYN packets to the same sessions - t.Run("Phase2_Send_NonSYN_To_Same_Sessions", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, 10) - for _, session := range sessions { - packetLayers := utils.MakeTCPPacket( - session.ip, - session.port, - vsIP, - vsPort, - &layers.TCP{}, // No SYN flag - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify packets are not dropped (sessions still valid) - assert.Equal( - t, - 10, - len(result.Output), - "all packets should be in output", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Verify each output packet is properly encapsulated - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } - - t.Logf( - "Sent non-SYN packets to 10 sessions, all accepted and properly encapsulated", - ) - }) - - // Resize session table and get active sessions - t.Run("Resize_And_Verify_Active_Sessions", func(t *testing.T) { - currentTime := mock.CurrentTime() - - // Sync active sessions and resize table on demand - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Check active sessions using helper function - checkActiveSessions( - t, - ts, - currentTime, - sessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified 10 active sessions for VS and Real with correct identifiers", - ) - }) - - // Advance time by 40 seconds (total 70s from start, 40s from last packet) - t.Run("Advance_Time_40s", func(t *testing.T) { - newTime := mock.AdvanceTime(40 * time.Second) - t.Logf( - "Advanced time to %v (40s elapsed, sessions at 40s age)", - newTime, - ) - }) - - // Phase 3: Send packets to the same sessions again - t.Run("Phase3_Send_Packets_Again", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, 10) - for _, session := range sessions { - packetLayers := utils.MakeTCPPacket( - session.ip, - session.port, - vsIP, - vsPort, - &layers.TCP{}, // No SYN flag - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify packets are not dropped (sessions still valid at 40s age) - assert.Equal( - t, - 10, - len(result.Output), - "all packets should be in output", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Verify each output packet is properly encapsulated - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } - - // Verify sessions are still active - currentTime := mock.CurrentTime() - err = ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - assert.Equal( - t, - 10, - len(sessions), - "should still have 10 active sessions", - ) - - t.Logf( - "Sent packets to 10 sessions again, all accepted, properly encapsulated, and sessions refreshed", - ) - }) - - // Advance time by 30 seconds - t.Run("Advance_Time_30s_Again", func(t *testing.T) { - newTime := mock.AdvanceTime(30 * time.Second) - t.Logf("Advanced time to %v (30s elapsed)", newTime) - }) - - // Phase 4: Create 10 new sessions and send packets to old sessions - t.Run("Phase4_Create_New_And_Refresh_Old_Sessions", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, 20) - - // Create 10 new sessions with TCP SYN packets - newSessions := make([]sessionKey, 0, 10) - for range 10 { - srcIP := randomClientIP() - srcPort := randomPort() - newSessions = append( - newSessions, - sessionKey{ip: srcIP, port: srcPort}, - ) - - packetLayers := utils.MakeTCPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - // Send packets to old sessions to refresh them - for _, session := range sessions { - packetLayers := utils.MakeTCPPacket( - session.ip, - session.port, - vsIP, - vsPort, - &layers.TCP{}, // No SYN flag - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify all packets are in output - assert.Equal( - t, - 20, - len(result.Output), - "all 20 packets should be in output", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Verify each output packet is properly encapsulated - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } - - // Append new sessions to the sessions list for final verification - sessions = append(sessions, newSessions...) - - t.Logf( - "Created 10 new sessions and refreshed 10 old sessions, all packets properly encapsulated", - ) - }) - - // Advance time and verify all sessions are active - t.Run("Verify_All_20_Sessions_Active", func(t *testing.T) { - currentTime := mock.CurrentTime() - - // Sync active sessions - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Check active sessions using helper function - checkActiveSessions( - t, - ts, - currentTime, - sessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified all 20 sessions are active (10 new + 10 old refreshed) with correct client IPs and ports", - ) - }) - - // Keep only new sessions for expiration testing - newSessions := sessions[10:] // Last 10 sessions are the new ones - - // Phase 5: Test session expiration - t.Run("Phase5_Advance_30s_Send_To_New_Sessions_Only", func(t *testing.T) { - // Advance time by 30 seconds (old sessions at 60s age, should expire) - newTime := mock.AdvanceTime(30 * time.Second) - t.Logf( - "Advanced time to %v (30s elapsed, old sessions at 60s age)", - newTime, - ) - - // Send packets only to new sessions - packets := make([]gopacket.Packet, 0, 10) - for _, session := range newSessions { - packetLayers := utils.MakeTCPPacket( - session.ip, - session.port, - vsIP, - vsPort, - &layers.TCP{}, // No SYN flag - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify all packets are in output - assert.Equal( - t, - 10, - len(result.Output), - "all 10 packets should be in output", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Verify each output packet is properly encapsulated - for i, outPacket := range result.Output { - utils.ValidatePacket(t, ts.Balancer.Config(), packets[i], outPacket) - } - - // Check active sessions immediately after sending packets - currentTime := mock.CurrentTime() - checkActiveSessions( - t, - ts, - currentTime, - sessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Sent packets to 10 new sessions only, verified all sessions still active", - ) - }) - - t.Run( - "Phase5_Advance_30s_Check_Only_New_Sessions_Active", - func(t *testing.T) { - // Advance time by 30 seconds (new sessions at 30s age) - newTime := mock.AdvanceTime(30 * time.Second) - t.Logf( - "Advanced time to %v (30s elapsed, new sessions at 30s age)", - newTime, - ) - - currentTime := mock.CurrentTime() - - // Check active sessions WITHOUT resizing - checkActiveSessions( - t, - ts, - currentTime, - newSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf("Verified only 10 new sessions are active at 30s age") - }, - ) - - t.Run( - "Phase5_Resize_And_Verify_New_Sessions_Still_Active", - func(t *testing.T) { - currentTime := mock.CurrentTime() - - // Resize and sync active sessions - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Check active sessions after resize - checkActiveSessions( - t, - ts, - currentTime, - newSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf("Verified 10 new sessions still active after resize") - }, - ) - - t.Run( - "Phase5_Advance_29s_Verify_Sessions_Still_Active", - func(t *testing.T) { - // Advance time by 29 seconds (new sessions at 59s age, still valid) - newTime := mock.AdvanceTime(29 * time.Second) - t.Logf( - "Advanced time to %v (29s elapsed, new sessions at 59s age)", - newTime, - ) - - currentTime := mock.CurrentTime() - - // Check active sessions - checkActiveSessions( - t, - ts, - currentTime, - newSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified 10 new sessions still active at 59s age (< 60s timeout)", - ) - }, - ) - - t.Run("Phase5_Advance_1s_Verify_Sessions_Expired", func(t *testing.T) { - // Advance time by 1 second (new sessions at 60s age, should expire) - newTime := mock.AdvanceTime(1 * time.Second) - t.Logf( - "Advanced time to %v (1s elapsed, new sessions at 60s age - expired)", - newTime, - ) - - currentTime := mock.CurrentTime() - - // Check active sessions - should be 0 - checkActiveSessions( - t, - ts, - currentTime, - []sessionKey{}, - vsIP, - vsPort, - realAddr, - ) - - t.Logf("Verified all sessions expired after 60s timeout") - }) - - // Phase 6: Comprehensive session table test with 256 sessions, load testing, and resizing - var phase6OldSessions []sessionKey - var phase6NewSessions []sessionKey - - t.Run("Phase6_Manual_Resize_Session_Table", func(t *testing.T) { - now := mock.CurrentTime() - newCapacity := uint64(256) - newMaxLoadFactor := float32(0.5) - - config := ts.Balancer.Config() - config.State.SessionTableCapacity = &newCapacity - config.State.SessionTableMaxLoadFactor = &newMaxLoadFactor - - _, err := ts.Balancer.Update(config, now) - require.NoError(t, err, "failed to update config") - - t.Logf("Resized session table to capacity 256") - }) - - t.Run("Phase6_Send_256_Sessions_Check_80_Percent", func(t *testing.T) { - // Generate 256 unique packets and store session keys - packets := make([]gopacket.Packet, 0, 256) - sessionToPacket := make(map[sessionKey]gopacket.Packet) - - for range 256 { - srcIP := randomClientIP() - srcPort := randomPort() - session := sessionKey{ip: srcIP, port: srcPort} - - packetLayers := utils.MakeTCPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - pkt := xpacket.LayersToPacket(t, packetLayers...) - packets = append(packets, pkt) - sessionToPacket[session] = pkt - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Verify at least 80% accepted (205 out of 256) - acceptedCount := len(result.Output) - require.GreaterOrEqual( - t, - acceptedCount, - 205, - "at least 80%% of packets should be accepted", - ) - - t.Logf( - "Sent 256 packets, %d accepted (%.1f%%)", - acceptedCount, - float64(acceptedCount)*100/256, - ) - - // Extract accepted session keys from output packets - acceptedSessions := make([]sessionKey, 0, acceptedCount) - for _, outPacket := range result.Output { - // Get inner packet to extract session key - if outPacket.InnerPacket == nil { - t.Fatal("output packet has no inner packet") - } - - innerIP, ok := netip.AddrFromSlice(outPacket.InnerPacket.SrcIP) - if !ok { - t.Fatalf( - "failed to parse inner packet source IP: %v", - outPacket.InnerPacket.SrcIP, - ) - } - - port := outPacket.SrcPort - session := sessionKey{ip: innerIP, port: port} - - // Find the original packet for validation - originalPacket, ok := sessionToPacket[session] - if !ok { - t.Errorf( - "could not find original packet for session %v:%d", - innerIP, - port, - ) - } - - // Validate the packet - utils.ValidatePacket( - t, - ts.Balancer.Config(), - originalPacket, - outPacket, - ) - - acceptedSessions = append(acceptedSessions, session) - } - - // Store for later phases - phase6OldSessions = acceptedSessions - - t.Logf("Extracted %d accepted session keys", len(acceptedSessions)) - }) - - t.Run("Phase6_Check_Accepted_Sessions_Active", func(t *testing.T) { - currentTime := mock.CurrentTime() - - // Sync active sessions - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Check only accepted sessions are active - checkActiveSessions( - t, - ts, - currentTime, - phase6OldSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified %d accepted sessions are active", - len(phase6OldSessions), - ) - }) - - t.Run("Phase6_Advance_59s_Resize_To_1024", func(t *testing.T) { - // Advance time by 59 seconds (old sessions at 59s age) - newTime := mock.AdvanceTime(59 * time.Second) - t.Logf( - "Advanced time to %v (59s elapsed, old sessions at 59s age)", - newTime, - ) - - // Resize table to 1024 - now := mock.CurrentTime() - newCapacity := uint64(1024) - newMaxLoadFactor := float32(0.5) - - config := ts.Balancer.Config() - config.State.SessionTableCapacity = &newCapacity - config.State.SessionTableMaxLoadFactor = &newMaxLoadFactor - - _, err := ts.Balancer.Update(config, now) - require.NoError(t, err, "failed to resize session table to 1024") - - t.Logf("Resized session table to 1024") - }) - - t.Run("Phase6_Send_256_More_Sessions", func(t *testing.T) { - // Generate 256 new unique packets and store session keys - packets := make([]gopacket.Packet, 0, 256) - sessionToPacket := make(map[sessionKey]gopacket.Packet) - - for range 256 { - srcIP := randomClientIP() - srcPort := randomPort() - session := sessionKey{ip: srcIP, port: srcPort} - - packetLayers := utils.MakeTCPPacket( - srcIP, - srcPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - pkt := xpacket.LayersToPacket(t, packetLayers...) - packets = append(packets, pkt) - sessionToPacket[session] = pkt - } - - result, err := mock.HandlePackets(packets...) - require.NoError(t, err) - - // Extract session keys from output packets - newSessions := make([]sessionKey, 0, 256) - for _, outPacket := range result.Output { - // Get inner packet to extract session key - if outPacket.InnerPacket == nil { - t.Fatal("output packet has no inner packet") - } - - innerIP, ok := netip.AddrFromSlice(outPacket.InnerPacket.SrcIP) - if !ok { - t.Fatalf( - "failed to parse inner packet source IP: %v", - outPacket.InnerPacket.SrcIP, - ) - } - - port := outPacket.SrcPort - session := sessionKey{ip: innerIP, port: port} - - // Find the original packet for validation - originalPacket, ok := sessionToPacket[session] - if !ok { - t.Errorf( - "could not find original packet for session %v:%d", - innerIP, - port, - ) - } - - // Validate the packet - utils.ValidatePacket( - t, - ts.Balancer.Config(), - originalPacket, - outPacket, - ) - - newSessions = append(newSessions, session) - } - - // Store for later phases - phase6NewSessions = newSessions - - t.Logf("Sent new sessions, all accepted") - }) - - t.Run("Phase6_Check_All_Sessions_Active", func(t *testing.T) { - currentTime := mock.CurrentTime() - - // Sync active sessions - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Verify both old and new sessions are active - allSessions := append( - []sessionKey{}, - phase6OldSessions..., - ) - allSessions = append(allSessions, phase6NewSessions...) - - checkActiveSessions( - t, - ts, - currentTime, - allSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified all %d sessions are active (%d old + %d new)", - len(allSessions), - len(phase6OldSessions), - len(phase6NewSessions), - ) - }) - - t.Run("Phase6_Advance_1s_Check_Old_Expired", func(t *testing.T) { - // Advance time by 1 second (old sessions now at 60s age - expired) - newTime := mock.AdvanceTime(1 * time.Second) - t.Logf( - "Advanced time to %v (1s elapsed, old sessions at 60s age - expired)", - newTime, - ) - - currentTime := mock.CurrentTime() - - // Only new sessions should be active - checkActiveSessions( - t, - ts, - currentTime, - phase6NewSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Verified old sessions expired, %d new sessions still active", - len(phase6NewSessions), - ) - }) - - t.Run("Phase6_Resize_To_300_Check_New_Active", func(t *testing.T) { - now := mock.CurrentTime() - - // Resize table back to 300 - newCapacity := uint64(300) - newMaxLoadFactor := float32(0.5) - - config := ts.Balancer.Config() - config.State.SessionTableCapacity = &newCapacity - config.State.SessionTableMaxLoadFactor = &newMaxLoadFactor - - _, err := ts.Balancer.Update(config, now) - require.NoError(t, err, "failed to resize session table to 300") - - require.LessOrEqual( - t, - uint64(300), - *ts.Balancer.Config().State.SessionTableCapacity, - ) - require.GreaterOrEqual( - t, - uint64(512), - *ts.Balancer.Config().State.SessionTableCapacity, - ) - - currentTime := mock.CurrentTime() - - // Verify new sessions still active after resize - checkActiveSessions( - t, - ts, - currentTime, - phase6NewSessions, - vsIP, - vsPort, - realAddr, - ) - - t.Logf( - "Resized table to 300, verified %d new sessions still active and old sessions are expired", - len(phase6NewSessions), - ) - }) -} - -// checkSessionsForVS verifies active sessions for a specific virtual service -func checkSessionsForVS( - t *testing.T, - ts *utils.TestSetup, - currentTime time.Time, - expectedCount int, - vsIP netip.Addr, - vsPort uint16, -) { - t.Helper() - - // Get sessions info - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - - // Count sessions for this VS - vsSessionCount := 0 - for _, session := range sessions { - sessionVsAddr, _ := netip.AddrFromSlice(session.RealId.Vs.Addr.Bytes) - sessionVsPort := uint16(session.RealId.Vs.Port) - if sessionVsAddr == vsIP && sessionVsPort == vsPort { - vsSessionCount++ - } - } - - assert.Equal( - t, - expectedCount, - vsSessionCount, - "VS %v:%d should have %d sessions", - vsIP, - vsPort, - expectedCount, - ) - - // Get info to verify VS active sessions - info, err := ts.Balancer.Info(currentTime) - require.NoError(t, err) - - // Find the VS in info and verify its active sessions - for _, vsInfo := range info.Vs { - vsAddr, _ := netip.AddrFromSlice(vsInfo.Id.Addr.Bytes) - vsInfoPort := uint16(vsInfo.Id.Port) - if vsAddr == vsIP && vsInfoPort == vsPort { - assert.Equal( - t, - uint64(expectedCount), - vsInfo.ActiveSessions, - "VS %v:%d active sessions should match", - vsIP, - vsPort, - ) - // Also verify Real active sessions sum - totalRealSessions := uint64(0) - for _, realInfo := range vsInfo.Reals { - totalRealSessions += realInfo.ActiveSessions - } - assert.Equal( - t, - uint64(expectedCount), - totalRealSessions, - "VS %v:%d total real sessions should match", - vsIP, - vsPort, - ) - return - } - } - - if expectedCount > 0 { - t.Errorf("VS %v:%d not found in info", vsIP, vsPort) - } -} - -// TestSessionTimeouts verifies that different session timeout types work correctly. -// It creates two virtual services (TCP and UDP) with different timeout configurations -// and validates that sessions expire at the correct time based on their type: -// - UDP sessions use UDP timeout (30s) -// - TCP sessions use different timeouts based on packet flags: -// - TCP_SYN timeout (20s) for SYN packets -// - TCP_SYN_ACK timeout (25s) after SYN-ACK packets -// - TCP_FIN timeout (15s) after FIN packets -// - TCP timeout (60s) for established connections -// -// Each test verifies the session persists at timeout-1 and expires at timeout. -func TestSessionTimeouts(t *testing.T) { - tcpVsIP := netip.MustParseAddr("1.1.1.1") - tcpVsPort := uint16(80) - tcpRealAddr := netip.MustParseAddr("10.2.2.2") - - udpVsIP := netip.MustParseAddr("2.2.2.2") - udpVsPort := uint16(5353) - udpRealAddr := netip.MustParseAddr("10.3.3.3") - - // Different timeout values to verify correct timeout is applied - udpTimeout := 30 // seconds - tcpTimeout := 60 // seconds - tcpSynTimeout := 20 // seconds - tcpSynAckTimeout := 25 // seconds - tcpFinTimeout := 15 // seconds - - // Configure balancer with two virtual services (TCP and UDP) - moduleConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // TCP Virtual Service - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: tcpVsIP.AsSlice(), - }, - Port: uint32(tcpVsPort), - Proto: balancerpb.TransportProto_TCP, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: tcpRealAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: tcpRealAddr.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - // UDP Virtual Service - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: udpVsIP.AsSlice(), - }, - Port: uint32(udpVsPort), - Proto: balancerpb.TransportProto_UDP, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: udpRealAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: udpRealAddr.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: uint32(tcpSynAckTimeout), - TcpSyn: uint32(tcpSynTimeout), - TcpFin: uint32(tcpFinTimeout), - Tcp: uint32(tcpTimeout), - Udp: uint32(udpTimeout), - Default: uint32(tcpTimeout), - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(64); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New( - 0, - ), // do not update in background - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(128*datasize.MB, 4*datasize.MB), - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 32 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - mock := ts.Mock - - // Set initial time - mock.SetCurrentTime(time.Unix(0, 0)) - - // Test 1: UDP Session Timeout - t.Run("UDP_Session_Timeout", func(t *testing.T) { - clientIP := netip.MustParseAddr("10.1.1.1") - clientPort := uint16(5000) - - // Send first UDP packet - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - udpVsIP, - udpVsPort, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "first UDP packet should be accepted", - ) - - // Send second UDP packet to ensure session is created - result, err = mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "second UDP packet should be accepted", - ) - - // Sync sessions - currentTime := mock.CurrentTime() - err = ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Verify session exists - checkSessionsForVS(t, ts, currentTime, 1, udpVsIP, udpVsPort) - t.Logf("UDP session created successfully") - - // Advance time by timeout-1 (29 seconds) - mock.AdvanceTime(time.Duration(udpTimeout-1) * time.Second) - currentTime = mock.CurrentTime() - - // Verify session still exists - checkSessionsForVS(t, ts, currentTime, 1, udpVsIP, udpVsPort) - t.Logf("UDP session persists at %d seconds", udpTimeout-1) - - // Advance time by 1 second (total 30 seconds) - mock.AdvanceTime(1 * time.Second) - currentTime = mock.CurrentTime() - - // Verify session is gone - checkSessionsForVS(t, ts, currentTime, 0, udpVsIP, udpVsPort) - t.Logf("UDP session expired at %d seconds", udpTimeout) - }) - - // Reset time for next test - mock.SetCurrentTime(time.Unix(1000, 0)) - - // Test 2: TCP SYN Session Timeout - t.Run("TCP_SYN_Session_Timeout", func(t *testing.T) { - clientIP := netip.MustParseAddr("10.1.2.1") - clientPort := uint16(6000) - - // Send TCP SYN packet - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP SYN packet should be accepted", - ) - - currentTime := mock.CurrentTime() - err = ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Verify session exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session created successfully") - - mock.AdvanceTime(time.Duration(tcpSynTimeout-1) * time.Second) - currentTime = mock.CurrentTime() - - // Verify session still exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session persists at %d seconds", tcpTimeout-1) - - // Advance time by 1 second (total 60 seconds) - mock.AdvanceTime(1 * time.Second) - currentTime = mock.CurrentTime() - - // Verify session is gone - checkSessionsForVS(t, ts, currentTime, 0, tcpVsIP, tcpVsPort) - t.Logf("TCP session expired at %d seconds", tcpTimeout) - }) - - // Reset time for next test - mock.SetCurrentTime(time.Unix(2000, 0)) - - // Test 3: TCP SYN-ACK Timeout - t.Run("TCP_SYN_ACK_Timeout", func(t *testing.T) { - clientIP := netip.MustParseAddr("10.1.3.1") - clientPort := uint16(7000) - - // Send TCP SYN packet - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP SYN packet should be accepted", - ) - - // Send TCP SYN-ACK packet from same client - packetLayers = utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{SYN: true, ACK: true}, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP SYN-ACK packet should be accepted", - ) - - // Sync sessions - currentTime := mock.CurrentTime() - - // Verify session exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP SYN-ACK session created successfully") - - // Advance time by SYN-ACK timeout-1 (24 seconds) - mock.AdvanceTime(time.Duration(tcpSynAckTimeout-1) * time.Second) - currentTime = mock.CurrentTime() - - // Verify session still exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP SYN-ACK session persists at %d seconds", tcpSynAckTimeout-1) - - // Advance time by 1 second (total 25 seconds) - mock.AdvanceTime(1 * time.Second) - currentTime = mock.CurrentTime() - - // Verify session is gone - checkSessionsForVS(t, ts, currentTime, 0, tcpVsIP, tcpVsPort) - t.Logf("TCP SYN-ACK session expired at %d seconds", tcpSynAckTimeout) - }) - - // Reset time for next test - mock.SetCurrentTime(time.Unix(3000, 0)) - - // Test 4: TCP SYN + Basic Packet Timeout - t.Run("TCP_SYN_Then_Basic_Packet_Timeout", func(t *testing.T) { - clientIP := netip.MustParseAddr("10.1.4.1") - clientPort := uint16(8000) - - // Send TCP SYN packet - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP SYN packet should be accepted", - ) - - // Send regular TCP packet (no flags) - this should switch timeout to TCP timeout - packetLayers = utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{}, // No flags - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP basic packet should be accepted", - ) - - // Sync sessions - currentTime := mock.CurrentTime() - - // Verify session exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->basic) created successfully") - - // Advance time by TCP timeout-1 (59 seconds) - // This verifies timeout switched from TCP_SYN (20s) to TCP (60s) - mock.AdvanceTime(time.Duration(tcpTimeout-1) * time.Second) - currentTime = mock.CurrentTime() - - // Verify session still exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->basic) persists at %d seconds", tcpTimeout-1) - - // Advance time by 1 second (total 60 seconds) - mock.AdvanceTime(1 * time.Second) - currentTime = mock.CurrentTime() - - // Verify session is gone - checkSessionsForVS(t, ts, currentTime, 0, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->basic) expired at %d seconds", tcpTimeout) - }) - - // Reset time for next test - mock.SetCurrentTime(time.Unix(4000, 0)) - - // Test 5: TCP SYN + FIN Packet Timeout - t.Run("TCP_SYN_Then_FIN_Packet_Timeout", func(t *testing.T) { - clientIP := netip.MustParseAddr("10.1.5.1") - clientPort := uint16(9000) - - // Send TCP SYN packet - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP SYN packet should be accepted", - ) - - // Send TCP FIN packet - this should switch timeout to TCP_FIN timeout - packetLayers = utils.MakeTCPPacket( - clientIP, - clientPort, - tcpVsIP, - tcpVsPort, - &layers.TCP{FIN: true}, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "TCP FIN packet should be accepted", - ) - - // Sync sessions - currentTime := mock.CurrentTime() - - // Verify session exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->FIN) created successfully") - - // Advance time by FIN timeout-1 (14 seconds) - mock.AdvanceTime(time.Duration(tcpFinTimeout-1) * time.Second) - currentTime = mock.CurrentTime() - - // Verify session still exists - checkSessionsForVS(t, ts, currentTime, 1, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->FIN) persists at %d seconds", tcpFinTimeout-1) - - // Advance time by 1 second (total 15 seconds) - mock.AdvanceTime(1 * time.Second) - currentTime = mock.CurrentTime() - - // Verify session is gone - checkSessionsForVS(t, ts, currentTime, 0, tcpVsIP, tcpVsPort) - t.Logf("TCP session (SYN->FIN) expired at %d seconds", tcpFinTimeout) - }) -} diff --git a/modules/balancer/tests/go/st_mt_test.go b/modules/balancer/tests/go/st_mt_test.go deleted file mode 100644 index 7d5aab0c4..000000000 --- a/modules/balancer/tests/go/st_mt_test.go +++ /dev/null @@ -1,1195 +0,0 @@ -package balancer_test - -import ( - "fmt" - "maps" - "math" - "math/rand" - "net" - "net/netip" - "sync" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "github.com/yanet-platform/yanet2/tests/functional/framework" - "google.golang.org/protobuf/types/known/durationpb" -) - -//////////////////////////////////////////////////////////////////////////////// - -type fullSessionKey struct { - clientIP netip.Addr - clientPort uint16 - vsIP netip.Addr - vsPort uint16 - proto balancerpb.TransportProto -} - -func (session *fullSessionKey) String() string { - return fmt.Sprintf( - "%v:%v->%v:%v/%v", - session.clientIP, session.clientPort, - session.vsIP, session.vsPort, session.proto, - ) -} - -func (session *fullSessionKey) Vs() vsKey { - return vsKey{ip: session.vsIP, port: session.vsPort, proto: session.proto} -} - -func fullSessionKeyFromTunPacket( - packet *framework.PacketInfo, -) (*fullSessionKey, error) { - if !packet.IsTunneled { - return nil, fmt.Errorf("packet not tunneled") - } - if packet.InnerPacket == nil { - return nil, fmt.Errorf("no inner packet") - } - proto, err := func() (balancerpb.TransportProto, error) { - proto, ok := packet.GetTransportProtocol() - if !ok { - return 0, fmt.Errorf("no transport protocol in inner packet") - } - switch proto { - case layers.IPProtocolTCP: - return balancerpb.TransportProto_TCP, nil - case layers.IPProtocolUDP: - return balancerpb.TransportProto_UDP, nil - default: - return 0, fmt.Errorf( - "incorrect inner packet protocol: %v, but protocol should be Tcp or Udp", - proto, - ) - } - }() - if err != nil { - return nil, err - } - srcIP, ok := netip.AddrFromSlice(packet.InnerPacket.SrcIP) - if !ok { - return nil, fmt.Errorf( - "invalid inner packet src IP: %v", - packet.InnerPacket.SrcIP, - ) - } - dstIP, ok := netip.AddrFromSlice(packet.InnerPacket.DstIP) - if !ok { - return nil, fmt.Errorf( - "invalid inner packet dst IP: %v", - packet.InnerPacket.DstIP, - ) - } - key := fullSessionKey{ - clientIP: srcIP, - clientPort: packet.SrcPort, - vsIP: dstIP, - vsPort: packet.DstPort, - proto: proto, - } - return &key, nil -} - -func fullSessionKeyFromInputPacket( - packet *framework.PacketInfo, -) (*fullSessionKey, error) { - proto, err := func() (balancerpb.TransportProto, error) { - proto, ok := packet.GetTransportProtocol() - if !ok { - return 0, fmt.Errorf("no transport protocol in inner packet") - } - switch proto { - case layers.IPProtocolTCP: - return balancerpb.TransportProto_TCP, nil - case layers.IPProtocolUDP: - return balancerpb.TransportProto_UDP, nil - default: - return 0, fmt.Errorf( - "incorrect inner packet protocol: %v, but protocol should be Tcp or Udp", - proto, - ) - } - }() - if err != nil { - return nil, err - } - srcIP, ok := netip.AddrFromSlice(packet.SrcIP) - if !ok { - return nil, fmt.Errorf( - "invalid src IP: %v", - packet.SrcIP, - ) - } - dstIP, ok := netip.AddrFromSlice(packet.DstIP) - if !ok { - return nil, fmt.Errorf( - "invalid dst IP: %v", - packet.DstIP, - ) - } - key := fullSessionKey{ - clientIP: srcIP, - clientPort: packet.SrcPort, - vsIP: dstIP, - vsPort: packet.DstPort, - proto: proto, - } - return &key, nil -} - -// workerState holds per-worker state -type workerState struct { - id int - rng *rand.Rand - sessions []fullSessionKey - sessionReals map[fullSessionKey]netip.Addr // mapping from session to selected real - stats workerStats -} - -func aggregateWorkerStates(states []workerState) workerState { - aggregate := workerState{ - id: -1, - rng: nil, - sessions: []fullSessionKey{}, - sessionReals: map[fullSessionKey]netip.Addr{}, - stats: workerStats{}, - } - - // Aggregate all sessions and session-to-real mappings - for _, state := range states { - aggregate.sessions = append(aggregate.sessions, state.sessions...) - maps.Copy(aggregate.sessionReals, state.sessionReals) - - // Aggregate statistics - aggregate.stats.totalPackets += state.stats.totalPackets - aggregate.stats.outputPackets += state.stats.outputPackets - aggregate.stats.droppedPackets += state.stats.droppedPackets - aggregate.stats.sessions += state.stats.sessions - } - - return aggregate -} - -// workerStats tracks statistics for a worker -type workerStats struct { - totalPackets int - outputPackets int - droppedPackets int - sessions int -} - -// multithreadTestConfig holds test configuration -type multithreadTestConfig struct { - numWorkers int - batchesPerWorker int - packetsPerBatch int - extendSessionTablePeriod time.Duration -} - -// vsSimple holds simplified VS info for packet generation -type vsSimple struct { - ip netip.Addr - port uint16 - proto balancerpb.TransportProto -} - -// vsConfigWithWeights holds VS configuration with real weights -type vsConfigWithWeights struct { - ip netip.Addr - port uint16 - proto balancerpb.TransportProto - scheduler balancerpb.VsScheduler - gre bool - fixMss bool - reals []realConfigWithWeight -} - -// realConfigWithWeight holds real configuration with weight -type realConfigWithWeight struct { - ip netip.Addr - weight uint32 -} - -//////////////////////////////////////////////////////////////////////////////// - -// randomClientIP generates a random client IP based on VS IP version -func randomClientIP(rng *rand.Rand, vsIP netip.Addr) netip.Addr { - if vsIP.Is4() { - return netip.AddrFrom4([4]byte{ - byte(10), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - }) - } - // IPv6 - return netip.MustParseAddr( - fmt.Sprintf("2001:db8::%x:%x", rng.Intn(65536), rng.Intn(65536)), - ) -} - -// randomPort generates a random port -func randomPort(rng *rand.Rand) uint16 { - return uint16(32768 + rng.Intn(64511)) -} - -//////////////////////////////////////////////////////////////////////////////// - -// generateVSConfigs creates 5 virtual services with random real weights -func generateVSConfigs() []vsConfigWithWeights { - rng := rand.New(rand.NewSource(42)) - - configs := []vsConfigWithWeights{ - // VS1: TCP IPv4, RR scheduler, 10 IPv4 reals - { - ip: netip.MustParseAddr("10.1.1.1"), - port: 80, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS2: UDP IPv4, RR scheduler, 10 IPv4 reals - { - ip: netip.MustParseAddr("10.1.2.1"), - port: 5353, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS3: TCP IPv6, RR scheduler, 10 IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::1"), - port: 443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: true, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS4: UDP IPv6, RR scheduler, 10 IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::2"), - port: 8080, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS5: TCP IPv4, RR scheduler, 10 mixed IPv4/IPv6 reals - { - ip: netip.MustParseAddr("10.1.3.1"), - port: 8443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: true, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - } - - // Generate real IPs and random weights - for i := range configs { - for j := range configs[i].reals { - var realIP netip.Addr - switch i { - case 0: - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.1.%d", j+1)) - case 1: - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.2.%d", j+1)) - case 2: - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:2::%x", j+1)) - case 3: - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:3::%x", j+1)) - case 4: - // Mixed IPv4/IPv6 - if j < 5 { - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.4.%d", j+1)) - } else { - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:4::%x", j-4)) - } - } - - configs[i].reals[j] = realConfigWithWeight{ - ip: realIP, - weight: uint32(rng.Intn(10) + 1), // Random weight 1-10 - } - } - } - - return configs -} - -// buildModuleConfig creates balancer module config from VS configs -func buildModuleConfig( - vsConfigs []vsConfigWithWeights, - sessionTimeout int, - capacity uint64, - maxLoadFactor float32, -) *balancerpb.BalancerConfig { - virtualServices := make([]*balancerpb.VirtualService, 0, len(vsConfigs)) - - for _, vsConf := range vsConfigs { - // Build allowed sources based on VS IP version - var allowedSrcs []*balancerpb.AllowedSources - if vsConf.ip.Is4() { - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0").AsSlice(), - }, - }}, - }, - } - } else { - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff::").AsSlice(), - }, - }}, - }, - } - } - - // Build reals - reals := make([]*balancerpb.Real, 0, len(vsConf.reals)) - for _, realConf := range vsConf.reals { - var srcMask []byte - if realConf.ip.Is4() { - srcMask = netip.MustParseAddr("255.255.255.255").AsSlice() - } else { - srcMask = netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff").AsSlice() - } - - reals = append(reals, &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realConf.ip.AsSlice(), - }, - Port: 0, - }, - Weight: realConf.weight, - SrcAddr: &balancerpb.Addr{ - Bytes: realConf.ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: srcMask, - }, - }) - } - - virtualServices = append(virtualServices, &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsConf.ip.AsSlice(), - }, - Port: uint32(vsConf.port), - Proto: vsConf.proto, - }, - AllowedSrcs: allowedSrcs, - Scheduler: vsConf.scheduler, - Flags: &balancerpb.VsFlags{ - Gre: vsConf.gre, - FixMss: vsConf.fixMss, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - }) - } - - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: virtualServices, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: uint32(sessionTimeout), - TcpSyn: uint32(sessionTimeout), - TcpFin: uint32(sessionTimeout), - Tcp: uint32(sessionTimeout), - Udp: uint32(sessionTimeout), - Default: uint32(sessionTimeout), - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: &capacity, - SessionTableMaxLoadFactor: &maxLoadFactor, - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// makeSimplePacketLayers creates packet layers with minimal payload to avoid protocol detection issues -func makeSimplePacketLayers( - srcIP netip.Addr, - srcPort uint16, - dstIP netip.Addr, - dstPort uint16, - isTCP bool, - tcpFlags *layers.TCP, -) []gopacket.SerializableLayer { - // Ensure both addresses are the same IP version - if srcIP.Is4() != dstIP.Is4() { - panic(fmt.Sprintf("IP version mismatch: src=%v dst=%v", srcIP, dstIP)) - } - - src := net.IP(srcIP.AsSlice()) - dst := net.IP(dstIP.AsSlice()) - - var ip gopacket.NetworkLayer - ethernetType := layers.EthernetTypeIPv6 - if srcIP.Is4() { - ethernetType = layers.EthernetTypeIPv4 - if isTCP { - ip = &layers.IPv4{ - Version: 4, - IHL: 5, - TTL: 64, - Protocol: layers.IPProtocolTCP, - SrcIP: src, - DstIP: dst, - } - } else { - ip = &layers.IPv4{ - Version: 4, - IHL: 5, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: src, - DstIP: dst, - } - } - } else { - if isTCP { - ip = &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolTCP, - HopLimit: 64, - SrcIP: src, - DstIP: dst, - } - } else { - ip = &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolUDP, - HopLimit: 64, - SrcIP: src, - DstIP: dst, - } - } - } - - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: ethernetType, - } - - var result []gopacket.SerializableLayer - result = append(result, eth, ip.(gopacket.SerializableLayer)) - - if isTCP { - tcp := tcpFlags - tcp.SrcPort = layers.TCPPort(srcPort) - tcp.DstPort = layers.TCPPort(dstPort) - _ = tcp.SetNetworkLayerForChecksum(ip) - result = append(result, tcp) - } else { - udp := &layers.UDP{ - SrcPort: layers.UDPPort(srcPort), - DstPort: layers.UDPPort(dstPort), - } - _ = udp.SetNetworkLayerForChecksum(ip) - result = append(result, udp) - } - - // Use empty payload to avoid protocol detection issues - result = append(result, gopacket.Payload([]byte{})) - return result -} - -// generateNewSessionPacket creates a packet for a new session -func generateNewSessionPacket( - state *workerState, - virtualServices []vsSimple, -) (gopacket.Packet, *fullSessionKey, error) { - // Randomly select a virtual service - vs := virtualServices[state.rng.Intn(len(virtualServices))] - - // Generate random client IP based on VS IP version - clientIP := randomClientIP(state.rng, vs.ip) - clientPort := randomPort(state.rng) - - session := &fullSessionKey{ - clientIP: clientIP, - clientPort: clientPort, - vsIP: vs.ip, - vsPort: vs.port, - proto: vs.proto, - } - - var packetLayers []gopacket.SerializableLayer - if vs.proto == balancerpb.TransportProto_TCP { - packetLayers = makeSimplePacketLayers( - clientIP, clientPort, vs.ip, vs.port, - true, &layers.TCP{SYN: true}, - ) - } else { - packetLayers = makeSimplePacketLayers( - clientIP, clientPort, vs.ip, vs.port, - false, nil, - ) - } - - packet, err := xpacket.LayersToPacketChecked(packetLayers...) - if err != nil { - return nil, nil, err - } - return packet, session, nil -} - -// generateExistingSessionPacket creates a packet for an existing session -func generateExistingSessionPacket( - state *workerState, -) (gopacket.Packet, *fullSessionKey, error) { - // Randomly select an existing session - idx := state.rng.Intn(len(state.sessions)) - session := &state.sessions[idx] - - // Create packet for session - var packetLayers []gopacket.SerializableLayer - if session.proto == balancerpb.TransportProto_TCP { - // TCP - packetLayers = makeSimplePacketLayers( - session.clientIP, session.clientPort, - session.vsIP, session.vsPort, - true, &layers.TCP{}, - ) - } else { - // UDP - packetLayers = makeSimplePacketLayers( - session.clientIP, session.clientPort, - session.vsIP, session.vsPort, - false, nil, - ) - } - - packet, err := xpacket.LayersToPacketChecked(packetLayers...) - if err != nil { - return nil, nil, err - } - return packet, session, nil -} - -//////////////////////////////////////////////////////////////////////////////// - -// workerRoutine sends packets and validates sessions -func workerRoutine( - workerID int, - config *multithreadTestConfig, - mock *mock.YanetMock, - virtualServices []vsSimple, - wg *sync.WaitGroup, - errors chan error, - resultState *workerState, -) { - defer wg.Done() - - state := &workerState{ - id: workerID, - rng: rand.New(rand.NewSource(int64(workerID + 1000))), - sessions: []fullSessionKey{}, - sessionReals: map[fullSessionKey]netip.Addr{}, - stats: workerStats{}, - } - - for batch := range config.batchesPerWorker { - outputActiveSessions := map[fullSessionKey]bool{} - packets := make([]gopacket.Packet, 0, config.packetsPerBatch) - - sendError := func(format string, a ...any) { - errors <- fmt.Errorf("worker %d: batch %d: %w", workerID, batch, fmt.Errorf(format, a...)) - } - - for range config.packetsPerBatch { - if state.rng.Intn(10) < 5 || len(state.sessions) == 0 { - // new session - packet, _, err := generateNewSessionPacket( - state, - virtualServices, - ) - if err != nil { - sendError("failed to generate new session packet: %w", err) - continue - } - packets = append(packets, packet) - } else { - packet, key, err := generateExistingSessionPacket(state) - if err != nil { - sendError("failed to generate existing session packet: %w", err) - continue - } - packets = append(packets, packet) - outputActiveSessions[*key] = false - } - } - - result, err := mock.HandlePacketsOnWorker(workerID, packets...) - if err != nil { - sendError("failed to handle packets: %w", err) - continue - } - output, drop := result.Output, result.Drop - for _, outPkt := range output { - sessionKey, err := fullSessionKeyFromTunPacket(outPkt) - if err != nil { - sendError("failed to get session key for out packet: %w", err) - continue - } - realIP, ok := netip.AddrFromSlice(outPkt.DstIP) - if !ok { - sendError( - "failed to get real ip for out packet (dstIP=%v)", - outPkt.DstIP, - ) - continue - } - if expectedRealIP, ok := state.sessionReals[*sessionKey]; ok { - if expectedRealIP != realIP { - sendError( - "real ip mismatch for session %v: expected=%v, got=%v", - sessionKey, - expectedRealIP, - realIP, - ) - continue - } - outputActiveSessions[*sessionKey] = true - } else { // created new session - state.sessionReals[*sessionKey] = realIP - state.sessions = append(state.sessions, *sessionKey) - } - } - for _, dropPkt := range drop { - key, err := fullSessionKeyFromInputPacket(dropPkt) - if err != nil { - sendError( - "failed to get session key from dropped packet: %w", - err, - ) - continue - } - if _, ok := outputActiveSessions[*key]; ok { - expectedReal := state.sessionReals[*key] - sendError("dropped active session %v [real %v]", - key, expectedReal, - ) - } - } - for sessionKey, touched := range outputActiveSessions { - if !touched { - sendError( - "active session %v not in output [real %v]", - sessionKey, - state.sessionReals[sessionKey], - ) - } - } - if config.packetsPerBatch != len(drop)+len(output) { - sendError( - "summary packet mismatch: expected=%d, got=%d", - config.packetsPerBatch, - len(drop)+len(output), - ) - } - state.stats.droppedPackets += len(drop) - state.stats.outputPackets += len(output) - state.stats.totalPackets += config.packetsPerBatch - } - - state.stats.sessions = len(state.sessions) - - *resultState = *state -} - -//////////////////////////////////////////////////////////////////////////////// - -// validateWeightDistribution checks packet distribution matches real weights -func validateWeightDistribution( - t *testing.T, - vsConfigs []vsConfigWithWeights, - aggregate *workerState, -) { - realSessionCount := make(map[netip.Addr]int) - for _, session := range aggregate.sessions { - realIP := aggregate.sessionReals[session] - realSessionCount[realIP] += 1 - } - - for _, vsConfig := range vsConfigs { - // Calculate total weight and total packets - totalWeight := uint32(0) - totalSessions := 0 - - for idx := range vsConfig.reals { - real := &vsConfig.reals[idx] - totalSessions += realSessionCount[real.ip] - totalWeight += real.weight - } - - for idx := range vsConfig.reals { - real := &vsConfig.reals[idx] - sessions := realSessionCount[real.ip] - expectedRatio := float64(real.weight) / float64(totalWeight) - actualRatio := float64(sessions) / float64(totalSessions) - - deviation := math.Abs(actualRatio-expectedRatio) / expectedRatio - - assert.Less( - t, - deviation, - 0.3, - "VS %s:%d/%s Real %s: sessions distribution deviates too much from weight: expected=%.2f, actual=%.2f, deviation=%.2f", - vsConfig.ip, - vsConfig.port, - vsConfig.proto, - real.ip, - expectedRatio, - actualRatio, - deviation, - ) - } - } -} - -// validateCounters checks stats counters match info counters -func validateCounters( - t *testing.T, - balancer *balancer.BalancerManager, - mock *mock.YanetMock, - aggregate *workerState, -) { - // Count sessions per real from aggregate - realSessionCount := make(map[netip.Addr]int) - vsSessionCount := make(map[vsKey]int) - for _, session := range aggregate.sessions { - realIP := aggregate.sessionReals[session] - realSessionCount[realIP] += 1 - vsSessionCount[session.Vs()] += 1 - } - - currentTime := mock.CurrentTime() - - // Get state info (from stats) - stateInfo, err := balancer.Info(currentTime) - require.NoError(t, err) - - // Get config stats (from Info) - ref := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - configStats, err := balancer.Stats(ref) - require.NoError(t, err) - - // Validate VS counters match - require.Equal(t, len(stateInfo.Vs), len(configStats.Vs), - "VS count mismatch between state and config") - - summaryOverflowCnt := uint64(0) - for i := range stateInfo.Vs { - vsState := stateInfo.Vs[i] - vsConfig := configStats.Vs[i] - - vsAddr, _ := netip.AddrFromSlice(vsState.Id.Addr.Bytes) - vs := vsKey{ - ip: vsAddr, - port: uint16(vsState.Id.Port), - proto: vsState.Id.Proto, - } - - expectedSessions := vsSessionCount[vs] - - assert.Equal( - t, - uint64(expectedSessions), - vsState.ActiveSessions, - "[VS %s]: active session count mismatch between state and workers", - vs.String(), - ) - - summaryOverflowCnt += vsConfig.Stats.SessionTableOverflow - - } - - // Validate Real counters match - for i := range stateInfo.Vs { - vsState := stateInfo.Vs[i] - vsConfig := configStats.Vs[i] - - require.Equal(t, len(vsState.Reals), len(vsConfig.Reals), - "Real count mismatch between state and config for VS %d", i) - - for j := range vsState.Reals { - realState := vsState.Reals[j] - - realIP, _ := netip.AddrFromSlice(realState.Id.Real.Ip.Bytes) - expectedSessions := realSessionCount[realIP] - - assert.Equal(t, uint64(expectedSessions), realState.ActiveSessions, - "[real %s]: active session count mismatch", realIP) - } - } - - // Validate invariants - assert.Equal( - t, - configStats.Common.IncomingPackets, - uint64(aggregate.stats.totalPackets), - ) - assert.Equal( - t, - configStats.Common.OutgoingPackets, - uint64(aggregate.stats.outputPackets), - ) - assert.Equal( - t, - configStats.L4.IncomingPackets, - uint64(aggregate.stats.totalPackets), - ) - assert.Equal( - t, - configStats.L4.OutgoingPackets, - uint64(aggregate.stats.outputPackets), - ) - assert.Equal( - t, - configStats.L4.SelectRealFailed, - uint64(aggregate.stats.droppedPackets), - ) - assert.Equal(t, summaryOverflowCnt, uint64(aggregate.stats.droppedPackets)) -} - -// validateFinalSessions checks session count -func validateFinalSessions( - t *testing.T, - balancer *balancer.BalancerManager, - currentTime time.Time, - aggregate *workerState, -) { - // Get sessions info from balancer - sessionsInfo, err := balancer.Sessions(currentTime) - require.NoError(t, err, "failed to get sessions info") - - // Log session counts - t.Logf("Aggregate tracked sessions: %d", len(aggregate.sessions)) - t.Logf("Balancer active sessions: %d", len(sessionsInfo)) - - // Verify session count matches - assert.Equal(t, len(aggregate.sessions), len(sessionsInfo), - "session count mismatch between aggregate and balancer") - - // Build a map of expected sessions from aggregate - expectedSessions := make(map[fullSessionKey]netip.Addr) - for _, session := range aggregate.sessions { - expectedSessions[session] = aggregate.sessionReals[session] - } - - // Verify each balancer session exists in our tracked sessions - for i, session := range sessionsInfo { - // Build session key from balancer session info - clientAddr, _ := netip.AddrFromSlice(session.ClientAddr.Bytes) - vsAddr, _ := netip.AddrFromSlice(session.RealId.Vs.Addr.Bytes) - realIP, _ := netip.AddrFromSlice(session.RealId.Real.Ip.Bytes) - - sessionKey := fullSessionKey{ - clientIP: clientAddr, - clientPort: uint16(session.ClientPort), - vsIP: vsAddr, - vsPort: uint16(session.RealId.Vs.Port), - proto: session.RealId.Vs.Proto, - } - - // Check if this session was tracked - expectedReal, found := expectedSessions[sessionKey] - if !found { - t.Errorf( - "Session %d: balancer has session %s that was not tracked by workers", - i, - sessionKey.String(), - ) - continue - } - - // Verify the real server matches - if expectedReal != realIP { - t.Errorf( - "Session %d: real server mismatch for %s: expected=%v, got=%v", - i, - sessionKey.String(), - expectedReal, - realIP, - ) - } - - // Remove from expected map (to detect sessions we tracked but balancer doesn't have) - delete(expectedSessions, sessionKey) - } - - // Check for sessions we tracked but balancer doesn't have - if len(expectedSessions) > 0 { - t.Errorf( - "Workers tracked %d sessions that are not in balancer:", - len(expectedSessions), - ) - for sessionKey, realIP := range expectedSessions { - t.Errorf(" - %s -> real %v", sessionKey.String(), realIP) - } - } - - t.Logf("Session validation completed successfully") -} - -//////////////////////////////////////////////////////////////////////////////// - -// extendSessionTableRoutine periodically calls sync to allow session table resizing -func extendSessionTableRoutine( - mock *mock.YanetMock, - balancer *balancer.BalancerManager, - done chan struct{}, - config *multithreadTestConfig, - errors chan error, -) { - ticker := time.NewTicker(config.extendSessionTablePeriod) - defer ticker.Stop() - - for { - select { - case <-done: - return - case <-ticker.C: - err := balancer.Refresh( - mock.CurrentTime(), - ) - if err != nil { - errors <- err - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// runMultithreadedTest executes the multithreaded test -func runMultithreadedTest(t *testing.T, config *multithreadTestConfig) { - // Generate VS configurations with random real weights - vsConfigs := generateVSConfigs() - - sessionTimeout := 60 - - // Calculate expected sessions to set initial capacity - // Total packets = numWorkers * batchesPerWorker * packetsPerBatch - // New session probability is 50% - totalPackets := config.numWorkers * config.batchesPerWorker * config.packetsPerBatch - expectedSessions := uint64(totalPackets / 2) - initialCapacity := 3 * expectedSessions / 2 - maxLoadFactor := float32(0.5) - - moduleConfig := buildModuleConfig( - vsConfigs, - sessionTimeout, - initialCapacity, - maxLoadFactor, - ) - - // Setup test - mockConfig := utils.SingleWorkerMockConfig(datasize.MB*512, datasize.MB*4) - mockConfig.Workers = uint64(config.numWorkers) - - setup, err := utils.Make(&utils.TestConfig{ - Mock: mockConfig, - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 256 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer setup.Free() - - mock := setup.Mock - balancer := setup.Balancer - - // Enable all reals - utils.EnableAllReals(t, setup) - - // Set initial time - mock.SetCurrentTime(time.Unix(0, 0)) - - // Create simplified VS list for packet generation - vsSimpleList := make([]vsSimple, 0, len(vsConfigs)) - for _, vsConf := range vsConfigs { - vsSimpleList = append(vsSimpleList, vsSimple{ - ip: vsConf.ip, - port: vsConf.port, - proto: vsConf.proto, - }) - } - - // Create channels and wait groups - errors := make(chan error, config.numWorkers+1) - - var wg sync.WaitGroup - - // Launch worker goroutines - wg.Add(config.numWorkers) - wStates := make([]workerState, config.numWorkers) - for i := 0; i < config.numWorkers; i++ { - go workerRoutine( - i, config, mock, - vsSimpleList, &wg, errors, &wStates[i], - ) - } - - done := make(chan struct{}, 1) - - // Start extend session table routine - go func() { - extendSessionTableRoutine(mock, balancer, done, config, errors) - }() - - // Listen for errors - wg.Wait() - - // Stop extend session table routine - done <- struct{}{} - - close(errors) - - // List for errors - for err := range errors { - t.Error(err) - } - - t.Log("all worker routines completed") - - // Perform final validations - - t.Run("Validate_Workers_Stats", func(t *testing.T) { - for worker := range wStates { - stats := wStates[worker].stats - dropRate := float64( - stats.droppedPackets, - ) / float64( - stats.totalPackets, - ) * 100.0 - t.Logf( - "worker %d: sessions=%d, totalPackets=%d, output=%d, dropped=%d, dropRate=%.2f%%", - worker, - stats.sessions, - stats.totalPackets, - stats.outputPackets, - stats.droppedPackets, - dropRate, - ) - assert.Less( - t, - dropRate, - 20.0, - "worker %d: too big drop rate", - worker, - ) - } - }) - - currentTime := mock.CurrentTime() - - // Aggregate worker states - aggregate := aggregateWorkerStates(wStates) - if len(aggregate.sessions) != len(aggregate.sessionReals) { - panic("invariant violation, maybe bad seed") - } - - t.Run("Validate_Counters", func(t *testing.T) { - validateCounters(t, balancer, mock, &aggregate) - }) - - t.Run("Validate_Final_Sessions", func(t *testing.T) { - validateFinalSessions(t, balancer, currentTime, &aggregate) - }) - - t.Run("Validate_Weight_Distribution", func(t *testing.T) { - validateWeightDistribution(t, vsConfigs, &aggregate) - }) - - // Log final statistics - capacity := balancer.Config().State.SessionTableCapacity - t.Logf("Final session table capacity: %d", *capacity) -} - -//////////////////////////////////////////////////////////////////////////////// - -// TestMultithreadedSessionTable tests session table with multiple workers -func TestMultithreadedSessionTable(t *testing.T) { - testCases := []struct { - name string - numWorkers int - }{ - {"SingleWorker", 1}, - {"TwoWorkers", 2}, - {"FourWorkers", 4}, - {"EightWorkers", 8}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - config := &multithreadTestConfig{ - numWorkers: tc.numWorkers, - batchesPerWorker: 100, - packetsPerBatch: 1024 / tc.numWorkers, - extendSessionTablePeriod: 50 * time.Millisecond, - } - - runMultithreadedTest(t, config) - }) - } -} diff --git a/modules/balancer/tests/go/st_mt_wlc_test.go b/modules/balancer/tests/go/st_mt_wlc_test.go deleted file mode 100644 index 06a7ddbba..000000000 --- a/modules/balancer/tests/go/st_mt_wlc_test.go +++ /dev/null @@ -1,461 +0,0 @@ -package balancer_test - -import ( - "fmt" - "math/rand" - "net/netip" - "sync" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" -) - -//////////////////////////////////////////////////////////////////////////////// - -// generateVSConfigs creates 5 virtual services with random real weights -func generateVSConfigsWithWlc() []vsConfigWithWeights { - rng := rand.New(rand.NewSource(42)) - - configs := []vsConfigWithWeights{ - // VS1: TCP IPv4, RR scheduler, 10 IPv4 reals - { - ip: netip.MustParseAddr("10.1.1.1"), - port: 80, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS2: UDP IPv4, RR scheduler, 10 IPv4 reals - { - ip: netip.MustParseAddr("10.1.2.1"), - port: 5353, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS3: TCP IPv6, RR scheduler, 10 IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::1"), - port: 443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: true, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS4: UDP IPv6, RR scheduler, 10 IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::2"), - port: 8080, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - // VS5: TCP IPv4, RR scheduler, 10 mixed IPv4/IPv6 reals - { - ip: netip.MustParseAddr("10.1.3.1"), - port: 8443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: true, - fixMss: false, - reals: make([]realConfigWithWeight, 10), - }, - } - - // Generate real IPs and random weights - for i := range configs { - for j := range configs[i].reals { - var realIP netip.Addr - switch i { - case 0: - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.1.%d", j+1)) - case 1: - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.2.%d", j+1)) - case 2: - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:2::%x", j+1)) - case 3: - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:3::%x", j+1)) - case 4: - // Mixed IPv4/IPv6 - if j < 5 { - realIP = netip.MustParseAddr(fmt.Sprintf("10.2.4.%d", j+1)) - } else { - realIP = netip.MustParseAddr(fmt.Sprintf("2001:db8:4::%x", j-4)) - } - } - - configs[i].reals[j] = realConfigWithWeight{ - ip: realIP, - weight: uint32(rng.Intn(10) + 1), // Random weight 1-10 - } - } - } - - return configs -} - -//////////////////////////////////////////////////////////////////////////////// - -// workerRoutine sends packets and validates sessions -func workerDisableRealsAwareRoutine( - workerID int, - config *multithreadTestConfig, - mock *mock.YanetMock, - virtualServices []vsSimple, - wg *sync.WaitGroup, - errors chan error, - resultState *workerState, -) { - defer wg.Done() - - realMismatches := 0 - - state := &workerState{ - id: workerID, - rng: rand.New(rand.NewSource(int64(workerID + 1000))), - sessions: []fullSessionKey{}, - sessionReals: map[fullSessionKey]netip.Addr{}, - stats: workerStats{}, - } - - for batch := range config.batchesPerWorker { - outputActiveSessions := map[fullSessionKey]bool{} - packets := make([]gopacket.Packet, 0, config.packetsPerBatch) - - sendError := func(format string, a ...any) { - errors <- fmt.Errorf("worker %d: batch %d: %w", workerID, batch, fmt.Errorf(format, a...)) - } - - for range config.packetsPerBatch { - if state.rng.Intn(10) < 5 || len(state.sessions) == 0 { - // new session - packet, _, err := generateNewSessionPacket( - state, - virtualServices, - ) - if err != nil { - sendError("failed to generate new session packet: %w", err) - continue - } - packets = append(packets, packet) - } else { - packet, key, err := generateExistingSessionPacket(state) - if err != nil { - sendError("failed to generate existing session packet: %w", err) - continue - } - packets = append(packets, packet) - outputActiveSessions[*key] = false - } - } - - result, err := mock.HandlePacketsOnWorker(workerID, packets...) - if err != nil { - sendError("failed to handle packets: %w", err) - continue - } - output, drop := result.Output, result.Drop - for _, outPkt := range output { - sessionKey, err := fullSessionKeyFromTunPacket(outPkt) - if err != nil { - sendError("failed to get session key for out packet: %w", err) - continue - } - realIP, ok := netip.AddrFromSlice(outPkt.DstIP) - if !ok { - sendError( - "failed to get real ip for out packet (dstIP=%v)", - outPkt.DstIP, - ) - continue - } - if expectedRealIP, ok := state.sessionReals[*sessionKey]; ok { - if expectedRealIP != realIP { - realMismatches += 1 - state.sessionReals[*sessionKey] = realIP - } - outputActiveSessions[*sessionKey] = true - } else { // created new session - state.sessionReals[*sessionKey] = realIP - } - } - for _, dropPkt := range drop { - key, err := fullSessionKeyFromInputPacket(dropPkt) - if err != nil { - sendError( - "failed to get session key from dropped packet: %w", - err, - ) - continue - } - if _, ok := outputActiveSessions[*key]; ok { - realMismatches += 1 - } - } - for _, touched := range outputActiveSessions { - if !touched { - realMismatches += 1 - } - } - if config.packetsPerBatch != len(drop)+len(output) { - sendError( - "summary packet mismatch: expected=%d, got=%d", - config.packetsPerBatch, - len(drop)+len(output), - ) - } - state.stats.droppedPackets += len(drop) - state.stats.outputPackets += len(output) - state.stats.totalPackets += config.packetsPerBatch - } - - if realMismatches > 10 { - errors <- fmt.Errorf("worker %d: too many real mismatches: %d (max is 10)", workerID, realMismatches) - } - - state.stats.sessions = len(state.sessions) - - *resultState = *state -} - -//////////////////////////////////////////////////////////////////////////////// - -// refreshRoutine periodically calls sync to allow session table resizing -func refreshRoutine( - mock *mock.YanetMock, - balancer *balancer.BalancerManager, - done chan struct{}, - config *multithreadTestConfig, - errors chan error, -) { - ticker := time.NewTicker(config.extendSessionTablePeriod) - defer ticker.Stop() - - iter := 0 - - for { - select { - case <-done: - return - case <-ticker.C: - if iter <= 5 { - // disable one real for every virtual service - config := balancer.Config() - updates := []*balancerpb.RealUpdate{} - enableTrue := true - enableFalse := false - for _, vs := range config.PacketHandler.Vs { - if iter < 5 { - updates = append(updates, &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Real: vs.Reals[iter].Id, - Vs: vs.Id, - }, - Weight: nil, - Enable: &enableFalse, - }) - } - if iter > 0 { - updates = append(updates, &balancerpb.RealUpdate{ - RealId: &balancerpb.RealIdentifier{ - Real: vs.Reals[iter-1].Id, - Vs: vs.Id, - }, - Weight: nil, - Enable: &enableTrue, - }) - } - } - madeUpdates, err := balancer.UpdateReals(updates, false) - if err != nil { - errors <- err - } - if madeUpdates != len(updates) { - errors <- fmt.Errorf("in refresh routine: expected %d updates, got %d", len(updates), madeUpdates) - } - iter += 1 - } - err := balancer.Refresh( - mock.CurrentTime(), - ) - if err != nil { - errors <- err - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// runMultithreadedWlcTestMultithreadedTest executes the multithreaded test -func runMultithreadedWlcTest(t *testing.T, config *multithreadTestConfig) { - // Generate VS configurations with random real weights - vsConfigs := generateVSConfigsWithWlc() - - sessionTimeout := 60 - - // Calculate expected sessions to set initial capacity - // Total packets = numWorkers * batchesPerWorker * packetsPerBatch - // New session probability is 50% - totalPackets := config.numWorkers * config.batchesPerWorker * config.packetsPerBatch - expectedSessions := uint64(totalPackets / 2) - initialCapacity := 3 * expectedSessions / 2 - maxLoadFactor := float32(0.5) - - moduleConfig := buildModuleConfig( - vsConfigs, - sessionTimeout, - initialCapacity, - maxLoadFactor, - ) - - // Setup test - mockConfig := utils.SingleWorkerMockConfig(datasize.MB*512, datasize.MB*4) - mockConfig.Workers = uint64(config.numWorkers) - - setup, err := utils.Make(&utils.TestConfig{ - Mock: mockConfig, - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 256 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer setup.Free() - - mock := setup.Mock - balancer := setup.Balancer - - // Enable all reals - utils.EnableAllReals(t, setup) - - // Set initial time - mock.SetCurrentTime(time.Unix(0, 0)) - - // Create simplified VS list for packet generation - vsSimpleList := make([]vsSimple, 0, len(vsConfigs)) - for _, vsConf := range vsConfigs { - vsSimpleList = append(vsSimpleList, vsSimple{ - ip: vsConf.ip, - port: vsConf.port, - proto: vsConf.proto, - }) - } - - // Create channels and wait groups - errors := make(chan error, config.numWorkers+1) - - var wg sync.WaitGroup - - // Launch worker goroutines - wg.Add(config.numWorkers) - wStates := make([]workerState, config.numWorkers) - for i := 0; i < config.numWorkers; i++ { - go workerDisableRealsAwareRoutine( - i, config, mock, - vsSimpleList, &wg, errors, &wStates[i], - ) - } - - done := make(chan struct{}) - refreshDone := make(chan struct{}) - - // Start extend session table routine - go func() { - defer close(refreshDone) - refreshRoutine(mock, balancer, done, config, errors) - }() - - // Listen for errors - wg.Wait() - - // Stop extend session table routine and wait until it fully exits - close(done) - <-refreshDone - - close(errors) - - // List for errors - for err := range errors { - t.Error(err) - } - - t.Log("all worker routines completed") - - // Perform final validations - - t.Run("Validate_Workers_Stats", func(t *testing.T) { - for worker := range wStates { - stats := wStates[worker].stats - dropRate := float64( - stats.droppedPackets, - ) / float64( - stats.totalPackets, - ) * 100.0 - t.Logf( - "worker %d: sessions=%d, totalPackets=%d, output=%d, dropped=%d, dropRate=%.2f%%", - worker, - stats.sessions, - stats.totalPackets, - stats.outputPackets, - stats.droppedPackets, - dropRate, - ) - assert.Less( - t, - dropRate, - 20.0, - "worker %d: too big drop rate", - worker, - ) - } - }) - - // Log final statistics - capacity := balancer.Config().State.SessionTableCapacity - t.Logf("Final session table capacity: %d", *capacity) -} - -//////////////////////////////////////////////////////////////////////////////// - -// TestMultithreadedWlcSessionTable tests session table with multiple workers -func TestMultithreadedWlcSessionTable(t *testing.T) { - testCases := []struct { - name string - numWorkers int - }{ - {"SingleWorker", 1}, - {"TwoWorkers", 2}, - {"FourWorkers", 4}, - {"EightWorkers", 8}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - config := &multithreadTestConfig{ - numWorkers: tc.numWorkers, - batchesPerWorker: 100, - packetsPerBatch: 1024 / tc.numWorkers, - extendSessionTablePeriod: 150 * time.Millisecond, - } - - runMultithreadedWlcTest(t, config) - }) - } -} diff --git a/modules/balancer/tests/go/st_stress1_test.go b/modules/balancer/tests/go/st_stress1_test.go deleted file mode 100644 index 98660baf9..000000000 --- a/modules/balancer/tests/go/st_stress1_test.go +++ /dev/null @@ -1,517 +0,0 @@ -package balancer_test - -// TestSessionTableStress1 performs comprehensive stress testing of the session table: -// -// # Test Configuration -// - Single virtual service with one real server -// - Session timeout: 64 seconds -// - Initial capacity: 16 (dynamically resized during test) -// - Max load factor: 0.5 -// -// # Test Scenarios -// Multiple test cases with varying parameters: -// - Small batches with long intervals: 10 packets × 500 batches -// - Small batches with short intervals: 10 packets × 4 batches -// - Large batches with long intervals: 500 packets × 10 batches -// - Large batches with short intervals: 500 packets × 10 batches -// -// # Validation Per Batch -// - Sends random TCP SYN packets from unique client IPs -// - Randomly refreshes 1/3 of existing sessions with non-SYN packets -// - Verifies packet acceptance rate (drop rate < 10%) -// - Validates active session counts match expectations -// - Confirms session table resizes dynamically as needed -// - Advances time and removes expired sessions from tracking -// -// # Session Expiration Testing -// - Tracks session creation and last packet times -// - Removes sessions from tracking after timeout period -// - Verifies balancer's active session counts match tracked sessions -// - Validates VS and Real active session counts are consistent -// -// # Performance Metrics -// - Monitors drop rates across all batches -// - Tracks failed active session insertions -// - Ensures acceptable performance under load (< 10% failure rate) - -import ( - "fmt" - "math/rand" - "net/netip" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -//////////////////////////////////////////////////////////////////////////////// - -type testCase struct { - batchSize int - numBatches int - timeoutBatches int -} - -type stressConfig struct { - ts *utils.TestSetup - rng *rand.Rand - vsIP netip.Addr - vsPort uint16 - realAddr netip.Addr - timeouts int -} - -func executeTestCase(t *testing.T, config *stressConfig, test *testCase) { - // Track sessions by their creation time to calculate expected active sessions - type sessionKey struct { - ip netip.Addr - port uint16 - } - activeSessions := make( - map[sessionKey]time.Time, - ) // maps session to its last packet time - - totalDropped := 0 - totalOutput := 0 - totalPackets := 0 - - advanceTimePerBatch := time.Duration( - config.timeouts/test.timeoutBatches+1, - ) * time.Second - - // Get initial capacity - initialConfig := config.ts.Balancer.Config() - require.NotNil(t, initialConfig.State) - require.NotNil(t, initialConfig.State.SessionTableCapacity) - initialCapacity := *initialConfig.State.SessionTableCapacity - - t.Logf( - "Test batchSize=%d, numBatches=%d, timeoutBatches=%d, advanceTimePerBatch=%.1fs, sessionsTimeouts=%.1fs, sessionTableCapacity=%d", - test.batchSize, - test.numBatches, - test.timeoutBatches, - advanceTimePerBatch.Seconds(), - float32(config.timeouts), - initialCapacity, - ) - - // Verify initial state is valid - - t.Run("Verify_Initial_State", func(t *testing.T) { - currentTime := config.ts.Mock.CurrentTime() - - // Verify sessions - sessions, err := config.ts.Balancer.Sessions(currentTime) - require.NoError(t, err, "failed to get sessions") - assert.Equal(t, 0, len(sessions)) - - // Get info - info, err := config.ts.Balancer.Info(currentTime) - require.NoError(t, err) - - // Verify module state - assert.Equal(t, uint64(0), info.ActiveSessions) - - // Verify virtual services active sessions - require.Equal(t, 1, len(info.Vs), "should have exactly one VS") - assert.Equal(t, uint64(0), info.Vs[0].ActiveSessions) - - // Verify real services active sessions - require.Equal( - t, - 1, - len(info.Vs[0].Reals), - "should have exactly one Real", - ) - assert.Equal(t, uint64(0), info.Vs[0].Reals[0].ActiveSessions) - }) - - // Generate and send packets - - t.Run("Send_Packets", func(t *testing.T) { - logPeriod := (test.numBatches + 9) / 10 - - // It is possible only if - // there is no enough time to - // extend table, which is very rare - // in production. - failedToFindActiveSession := 0 - findActiveSessions := 0 - - for batch := range test.numBatches { - // Generate `batchSize` random TCP SYN packets - // (with possible repetitions) - packets := make( - []gopacket.Packet, - 0, - test.batchSize+len(activeSessions)/2, - ) - for range test.batchSize { - // Generate random source IP in 10.x.x.x range - srcIP := netip.AddrFrom4([4]byte{ - 10, - byte(config.rng.Intn(256)), - byte(config.rng.Intn(256)), - byte(config.rng.Intn(256)), - }) - - // Generate random source port - srcPort := uint16(1024 + config.rng.Intn(64511)) // 1024-65535 - - // Create TCP SYN packet - packetLayers := utils.MakeTCPPacket( - srcIP, - srcPort, - config.vsIP, - config.vsPort, - &layers.TCP{SYN: true}, - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - prolongedSessions := make([]sessionKey, 0, len(activeSessions)/2) - for activeSession := range activeSessions { - if config.rng.Intn(3) == 0 { // prolong with 1/3 probability - packetLayers := utils.MakeTCPPacket( - activeSession.ip, - activeSession.port, - config.vsIP, - config.vsPort, - &layers.TCP{}, - ) - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - prolongedSessions = append(prolongedSessions, activeSession) - } - } - - totalPackets += len(packets) - - if batch%logPeriod == 0 || batch+1 == test.numBatches { - t.Logf( - "Batch %d sending packets to sessions: %d (%d new + %d active)", - batch, - len(packets), - test.batchSize, - len(prolongedSessions), - ) - } - - // Send all packets at once - result, err := config.ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - - currentTime := config.ts.Mock.CurrentTime() - - totalDropped += len(result.Drop) - totalOutput += len(result.Output) - - assert.Equal(t, len(packets), len(result.Drop)+len(result.Output)) - - for _, outPkt := range result.Output { - // Extract source IP and port from output packet - ip, ok := netip.AddrFromSlice(outPkt.InnerPacket.SrcIP) - require.True(t, ok, "incorrect src ip") - key := sessionKey{ - ip: ip, - port: outPkt.SrcPort, - } - activeSessions[key] = currentTime - } - - // Trace active sessions counters - findActiveSessions += len(prolongedSessions) - for _, session := range prolongedSessions { - if currentTime != activeSessions[session] { - failedToFindActiveSession += 1 - } - } - - // Sync active sessions and resize table on demand - err = config.ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Check active sessions are correct - activeSessionsCount := uint64(len(activeSessions)) - - if batch%logPeriod == 0 || batch+1 == test.numBatches { - currentConfig := config.ts.Balancer.Config() - currentCapacity := uint64(0) - if currentConfig.State != nil && - currentConfig.State.SessionTableCapacity != nil { - currentCapacity = *currentConfig.State.SessionTableCapacity - } - - t.Logf( - "Batch %d verified: Sent %d packet sessions (%d new + %d active), Output=%d, Drop=%d, ActiveSessions=%d, SessionTableCapacity=%d", - batch, - len(packets), - len(packets)-len(prolongedSessions), - len(prolongedSessions), - len(result.Output), - len(result.Drop), - activeSessionsCount, - currentCapacity, - ) - } - - // Get info - info, err := config.ts.Balancer.Info(currentTime) - require.NoError(t, err) - - // Verify active sessions for VS - require.Equal(t, 1, len(info.Vs), "should have exactly one VS") - vsActiveSessions := info.Vs[0].ActiveSessions - assert.Equal( - t, - activeSessionsCount, - vsActiveSessions, - "VS active sessions should match expected after batch %d", - batch+1, - ) - - // Verify active sessions for Real - require.Equal( - t, - 1, - len(info.Vs[0].Reals), - "should have exactly one Real", - ) - realActiveSessions := info.Vs[0].Reals[0].ActiveSessions - assert.Equal( - t, - activeSessionsCount, - realActiveSessions, - "Real active sessions should match expected after batch %d", - batch+1, - ) - - // Advance time - newTime := config.ts.Mock.AdvanceTime(advanceTimePerBatch) - - // Remove expired sessions from our tracking - for key, lastPacketTime := range activeSessions { - if newTime.Sub( - lastPacketTime, - ) > time.Duration( - config.timeouts, - )*time.Second { - delete(activeSessions, key) - } - } - } - - assert.Equal(t, totalPackets, totalDropped+totalOutput) - - dropRate := float64(totalDropped) / float64(totalPackets) * 100 - t.Logf( - "Drop rate: %.2f%% (dropped=%d, total=%d)", - dropRate, - totalDropped, - totalPackets, - ) - - // Just ensure some packets not dropped - assert.Less(t, dropRate, 10.0) - - failedToFindRate := float64( - failedToFindActiveSession, - ) / float64( - findActiveSessions, - ) * 100 - t.Logf( - "Failed to insert active sessions rate: %.2f%% (failed=%d, total=%d)", - failedToFindRate, - failedToFindActiveSession, - findActiveSessions, - ) - - // Just ensure not all active session packets - // were dropped - assert.Less(t, failedToFindRate, 10.0) - }) -} - -////////////////////////////////////////////////////////////////////////////// - -// TestSessionTableStress1 sends many random TCP SYN packets to a single -// real of a single VS, calling Refresh() after each batch and verifying -// that active sessions count matches expectations considering session -// expiration based on timeout. -func TestSessionTableStress1(t *testing.T) { - vsIP := netip.MustParseAddr("1.1.1.1") - vsPort := uint16(80) - realAddr := netip.MustParseAddr("2.2.2.2") - - sessionsTimeout := 64 // in seconds - defaultCapacity := 16 - maxLoadFactor := 0.5 - - // Configure balancer with single VS and single real - moduleConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIP.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realAddr.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("4.4.4.4").AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: uint32(sessionsTimeout), - TcpSyn: uint32(sessionsTimeout), - TcpFin: uint32(sessionsTimeout), - Tcp: uint32(sessionsTimeout), - Udp: uint32(sessionsTimeout), - Default: uint32(sessionsTimeout), - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(defaultCapacity); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(maxLoadFactor); return &v }(), - RefreshPeriod: durationpb.New( - 0, - ), // do not update in background - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(128*datasize.MB, 4*datasize.MB), - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 32 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Set time to mock - ts.Mock.SetCurrentTime(time.Unix(0, 0)) - - rng := rand.New(rand.NewSource(123)) - - config := stressConfig{ - ts: ts, - rng: rng, - vsIP: vsIP, - vsPort: vsPort, - realAddr: realAddr, - timeouts: sessionsTimeout, - } - - tests := []testCase{ - {10, 5, 2}, // just a small test - {10, 500, 2}, // emit a few packets with big time intervals - {10, 4, 5}, // emit a few packets with small time intervals - {500, 10, 2}, // emit many packets with small time interval - {500, 10, 5}, // emit many packets with big time interval - } - - // Run test cases - for _, test := range tests { - // Initially resize session table to the 90% of first batch size - newCapacity := uint64(9 * test.batchSize / 10) - if newCapacity < 1 { - newCapacity = 1 - } - - currentConfig := ts.Balancer.Config() - currentConfig.State.SessionTableCapacity = &newCapacity - _, err = ts.Balancer.Update(currentConfig, ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to shrink session table capacity") - - // Run current test case - testName := fmt.Sprintf( - "BatchSize_%d_NumBatches_%d_TimeoutBatches_%d", - test.batchSize, - test.numBatches, - test.timeoutBatches, - ) - t.Run(testName, func(t *testing.T) { - executeTestCase(t, &config, &test) - }) - - // Prepare the next test case - - // Make all sessions expire - ts.Mock.AdvanceTime(time.Duration(sessionsTimeout) * time.Second) - - // Sync active sessions - err = ts.Balancer.Refresh(ts.Mock.CurrentTime()) - require.NoError(t, err, "failed to sync active sessions before shrink") - } -} diff --git a/modules/balancer/tests/go/st_stress2_test.go b/modules/balancer/tests/go/st_stress2_test.go deleted file mode 100644 index 0646249f3..000000000 --- a/modules/balancer/tests/go/st_stress2_test.go +++ /dev/null @@ -1,673 +0,0 @@ -package balancer_test - -// TestSessionTableStress2 validates session table behavior with multiple virtual services: -// -// # Virtual Services Configuration -// Six virtual services with diverse configurations: -// - VS1: TCP IPv4, ROUND_ROBIN, no GRE/FixMSS, 2 IPv4 reals (weights 1:1) -// - VS2: UDP IPv4, SOURCE_HASH, no GRE, 2 IPv4 reals (weights 2:3) -// - VS3: TCP IPv6, ROUND_ROBIN, with GRE, 2 IPv6 reals (weights 1:1) -// - VS4: UDP IPv6, SOURCE_HASH, no GRE, 2 IPv6 reals (weights 1:2) -// - VS5: TCP IPv4, SOURCE_HASH, with GRE, mixed IPv4/IPv6 reals (weights 3:2) -// - VS6: TCP IPv6, ROUND_ROBIN, with FixMSS, 2 IPv4 reals (weights 1:1) -// -// # Test Configuration -// - Session timeout: 60 seconds -// - Initial capacity: 16 (dynamically resized) -// - Max load factor: 0.25 -// -// # Phase 1: Initial Session Creation -// - Sends 16 random packets across all virtual services -// - Tracks which sessions were accepted and their selected real servers -// - Verifies at least 75% acceptance rate -// -// # Phase 2: Iterative Stress Testing (10 iterations) -// Each iteration performs: -// - Step 1: Sync active sessions and resize table on demand -// - Step 2: Send packets to all existing sessions -// * Validates all packets are accepted (no drops) -// * Verifies session-to-real consistency (same client → same real) -// - Step 3: Create N/2 new sessions (where N = current active sessions) -// * Validates all new sessions are accepted -// * Tracks new sessions and their selected reals -// -// # Session Consistency Validation -// - Ensures the same client always reaches the same real server -// - Validates this consistency across packet retransmissions -// - Tests with both TCP (SYN and non-SYN) and UDP packets -// -// # Final Verification -// - Confirms tracked sessions match balancer's session count -// - Validates VS and Real active session counts are consistent -// - Verifies session table capacity and load factor - -import ( - "fmt" - "math/rand" - "net/netip" - "testing" - "time" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -// sessionInfo tracks a session and its selected real -type sessionInfo struct { - clientIP netip.Addr - clientPort uint16 - vsIP netip.Addr - vsPort uint16 - proto balancerpb.TransportProto - realIP netip.Addr // The real server selected for this session -} - -// vsConfig holds configuration for a virtual service -type vsConfig struct { - ip netip.Addr - port uint16 - proto balancerpb.TransportProto - scheduler balancerpb.VsScheduler - gre bool - fixMss bool - reals []realConfig -} - -// realConfig holds configuration for a real server -type realConfig struct { - ip netip.Addr - weight uint32 -} - -// TestSessionTableStress2 tests session table with multiple virtual services, -// sequential resizing, and session consistency validation -func TestSessionTableStress2(t *testing.T) { - sessionTimeout := 60 // in seconds - initialCapacity := 16 - maxLoadFactor := 0.25 - - // Define virtual services configuration - // Mix of TCP/UDP, IPv4/IPv6, ROUND_ROBIN/SOURCE_HASH schedulers, with/without GRE and FixMSS - virtualServicesConfig := []vsConfig{ - // VS1: TCP IPv4, ROUND_ROBIN, no GRE, no FixMSS, IPv4 reals - { - ip: netip.MustParseAddr("10.1.1.1"), - port: 80, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: false, - reals: []realConfig{ - {netip.MustParseAddr("10.2.1.1"), 1}, - {netip.MustParseAddr("10.2.1.2"), 1}, - }, - }, - // VS2: UDP IPv4, SOURCE_HASH, no GRE, IPv4 reals - { - ip: netip.MustParseAddr("10.1.2.1"), - port: 5353, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_SOURCE_HASH, - gre: false, - fixMss: false, - reals: []realConfig{ - {netip.MustParseAddr("10.2.2.1"), 2}, - {netip.MustParseAddr("10.2.2.2"), 3}, - }, - }, - // VS3: TCP IPv6, ROUND_ROBIN, with GRE, IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::1"), - port: 443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: true, - fixMss: false, - reals: []realConfig{ - {netip.MustParseAddr("2001:db8:2::1"), 1}, - {netip.MustParseAddr("2001:db8:2::2"), 1}, - }, - }, - // VS4: UDP IPv6, SOURCE_HASH, IPv6 reals - { - ip: netip.MustParseAddr("2001:db8::2"), - port: 8080, - proto: balancerpb.TransportProto_UDP, - scheduler: balancerpb.VsScheduler_SOURCE_HASH, - gre: false, - fixMss: false, - reals: []realConfig{ - {netip.MustParseAddr("2001:db8:3::1"), 1}, - {netip.MustParseAddr("2001:db8:3::2"), 2}, - }, - }, - // VS5: TCP IPv4, SOURCE_HASH, with GRE, mixed IPv4 and IPv6 reals - { - ip: netip.MustParseAddr("10.1.3.1"), - port: 8443, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_SOURCE_HASH, - gre: true, - fixMss: false, - reals: []realConfig{ - {netip.MustParseAddr("10.2.4.1"), 3}, - {netip.MustParseAddr("2001:db8:4::2"), 2}, - }, - }, - // VS6: TCP IPv6, ROUND_ROBIN, with FixMSS, IPv4 reals - { - ip: netip.MustParseAddr("2001:db8::3"), - port: 9000, - proto: balancerpb.TransportProto_TCP, - scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - gre: false, - fixMss: true, - reals: []realConfig{ - {netip.MustParseAddr("10.2.3.1"), 1}, - {netip.MustParseAddr("10.2.3.2"), 1}, - }, - }, - } - - // Build module config from virtual services configuration - virtualServices := make( - []*balancerpb.VirtualService, - 0, - len(virtualServicesConfig), - ) - for _, vsConf := range virtualServicesConfig { - // Build allowed sources based on VS IP version - var allowedSrcs []*balancerpb.AllowedSources - if vsConf.ip.Is4() { - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0").AsSlice(), - }, - }}, - }, - } - } else { - allowedSrcs = []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("2001:db8::").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("ffff:ffff:ffff:ffff::").AsSlice(), - }, - }}, - }, - } - } - - // Build reals - reals := make([]*balancerpb.Real, 0, len(vsConf.reals)) - for _, realConf := range vsConf.reals { - var srcMask []byte - if realConf.ip.Is4() { - srcMask = netip.MustParseAddr("255.255.255.255").AsSlice() - } else { - srcMask = netip.MustParseAddr("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff").AsSlice() - } - - reals = append(reals, &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: realConf.ip.AsSlice(), - }, - Port: 0, - }, - Weight: realConf.weight, - SrcAddr: &balancerpb.Addr{ - Bytes: realConf.ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: srcMask, - }, - }) - } - - virtualServices = append(virtualServices, &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsConf.ip.AsSlice(), - }, - Port: uint32(vsConf.port), - Proto: vsConf.proto, - }, - AllowedSrcs: allowedSrcs, - Scheduler: vsConf.scheduler, - Flags: &balancerpb.VsFlags{ - Gre: vsConf.gre, - FixMss: vsConf.fixMss, - Ops: false, - PureL3: false, - Wlc: false, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - }) - } - - moduleConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: virtualServices, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: uint32(sessionTimeout), - TcpSyn: uint32(sessionTimeout), - TcpFin: uint32(sessionTimeout), - Tcp: uint32(sessionTimeout), - Udp: uint32(sessionTimeout), - Default: uint32(sessionTimeout), - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(initialCapacity); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(maxLoadFactor); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - // Setup test - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: moduleConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - // Set initial time - ts.Mock.SetCurrentTime(time.Unix(0, 0)) - - rng := rand.New(rand.NewSource(111222)) - - // Build simple list for random packet generation - type vsSimple struct { - ip netip.Addr - port uint16 - proto balancerpb.TransportProto - } - - vsSimpleList := make([]vsSimple, 0, len(virtualServicesConfig)) - for _, vsConf := range virtualServicesConfig { - vsSimpleList = append(vsSimpleList, vsSimple{ - ip: vsConf.ip, - port: vsConf.port, - proto: vsConf.proto, - }) - } - - // Helper to generate random client IP based on VS IP version - randomClientIP := func(vsIP netip.Addr) netip.Addr { - if vsIP.Is4() { - return netip.AddrFrom4([4]byte{ - 10, - byte(rng.Intn(256)), - byte(rng.Intn(256)), - byte(rng.Intn(256)), - }) - } - // IPv6 - return netip.MustParseAddr(fmt.Sprintf("2001:db8::%x", rng.Intn(65536))) - } - - // Helper to generate random port - randomPort := func() uint16 { - return uint16(1024 + rng.Intn(64511)) - } - - // Track active sessions with their selected reals - activeSessions := make(map[string]*sessionInfo) - - // Helper to create session key - sessionKey := func(clientIP netip.Addr, clientPort uint16, vsIP netip.Addr, vsPort uint16) string { - return fmt.Sprintf("%s:%d->%s:%d", clientIP, clientPort, vsIP, vsPort) - } - - makeTCPSynPacket := func(clientIP netip.Addr, clientPort uint16, vsIP netip.Addr, vsPort uint16) gopacket.Packet { - packetLayers := utils.MakeTCPPacket( - clientIP, - clientPort, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - return xpacket.LayersToPacket(t, packetLayers...) - } - - // Phase 1: Send 16 random packets to establish initial sessions - t.Run("Phase1_Create_16_Initial_Sessions", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, 16) - packetToSession := make(map[int]*sessionInfo) - - for i := range 16 { - // Randomly select a virtual service - vs := vsSimpleList[rng.Intn(len(vsSimpleList))] - clientIP := randomClientIP(vs.ip) - clientPort := randomPort() - - session := &sessionInfo{ - clientIP: clientIP, - clientPort: clientPort, - vsIP: vs.ip, - vsPort: vs.port, - proto: vs.proto, - } - packetToSession[i] = session - - var packet gopacket.Packet - if vs.proto == balancerpb.TransportProto_TCP { - packet = makeTCPSynPacket(clientIP, clientPort, vs.ip, vs.port) - } else { - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - vs.ip, - vs.port, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - } - packets = append(packets, packet) - } - - result, err := ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - - t.Logf( - "Sent 16 packets: Output=%d, Drop=%d", - len(result.Output), - len(result.Drop), - ) - - // Track which sessions were accepted and their selected reals - for i, outPacket := range result.Output { - session := packetToSession[i] - - // Extract the real IP from the output packet - realIP, ok := netip.AddrFromSlice(outPacket.DstIP) - require.True(t, ok, "failed to parse real IP") - session.realIP = realIP - - key := sessionKey( - session.clientIP, - session.clientPort, - session.vsIP, - session.vsPort, - ) - activeSessions[key] = session - } - - t.Logf("Created %d initial sessions", len(activeSessions)) - assert.GreaterOrEqual( - t, - len(activeSessions), - 12, - "at least 75%% of packets should be accepted", - ) - }) - - // Phase 2: Perform 10 iterations of the stress test cycle - for iteration := range 10 { - t.Run(fmt.Sprintf("Iteration_%d", iteration+1), func(t *testing.T) { - currentTime := ts.Mock.CurrentTime() - - // Step 1: Sync active sessions and resize table on demand - t.Run("Step1_Sync_And_Resize", func(t *testing.T) { - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - config := ts.Balancer.Config() - capacity := uint64(0) - if config.State != nil && - config.State.SessionTableCapacity != nil { - capacity = *config.State.SessionTableCapacity - } - t.Logf( - "Session table capacity: %d, Active sessions: %d", - capacity, - len(activeSessions), - ) - }) - - // Step 2: Send packets to existing sessions (TCP non-SYN or UDP) - t.Run("Step2_Send_To_Existing_Sessions", func(t *testing.T) { - packets := make([]gopacket.Packet, 0, len(activeSessions)) - sessionList := make([]*sessionInfo, 0, len(activeSessions)) - - for _, session := range activeSessions { - sessionList = append(sessionList, session) - - var packetLayers []gopacket.SerializableLayer - if session.proto == balancerpb.TransportProto_TCP { - packetLayers = utils.MakeTCPPacket( - session.clientIP, - session.clientPort, - session.vsIP, - session.vsPort, - &layers.TCP{}, // No SYN flag - ) - } else { - packetLayers = utils.MakeUDPPacket( - session.clientIP, - session.clientPort, - session.vsIP, - session.vsPort, - ) - } - packets = append( - packets, - xpacket.LayersToPacket(t, packetLayers...), - ) - } - - result, err := ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - - // All packets should be accepted (no drops) - assert.Equal( - t, - len(packets), - len(result.Output), - "all packets to existing sessions should be accepted", - ) - assert.Empty(t, result.Drop, "no packets should be dropped") - - // Validate each packet and verify real consistency - for i, outPacket := range result.Output { - session := sessionList[i] - - // Verify the same real is selected - realIP, ok := netip.AddrFromSlice(outPacket.DstIP) - require.True(t, ok, "failed to parse real IP") - assert.Equal( - t, - session.realIP, - realIP, - "real server should remain consistent for session %s:%d->%s:%d", - session.clientIP, - session.clientPort, - session.vsIP, - session.vsPort, - ) - } - - t.Logf( - "Sent %d packets to existing sessions, all accepted with consistent reals", - len(packets), - ) - }) - - // Step 3: Create N/2 new sessions - t.Run("Step3_Create_New_Sessions", func(t *testing.T) { - N := len(activeSessions) - newSessionCount := N / 2 - if newSessionCount == 0 { - newSessionCount = 1 - } - - packets := make([]gopacket.Packet, 0, newSessionCount) - packetToSession := make(map[int]*sessionInfo) - - for i := 0; i < newSessionCount; i++ { - // Randomly select a virtual service - vs := vsSimpleList[rng.Intn(len(vsSimpleList))] - clientIP := randomClientIP(vs.ip) - clientPort := randomPort() - - session := &sessionInfo{ - clientIP: clientIP, - clientPort: clientPort, - vsIP: vs.ip, - vsPort: vs.port, - proto: vs.proto, - } - packetToSession[i] = session - - var packet gopacket.Packet - if vs.proto == balancerpb.TransportProto_TCP { - packet = makeTCPSynPacket( - clientIP, - clientPort, - vs.ip, - vs.port, - ) - } else { - packetLayers := utils.MakeUDPPacket( - clientIP, - clientPort, - vs.ip, - vs.port, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - } - packets = append(packets, packet) - } - - result, err := ts.Mock.HandlePackets(packets...) - require.NoError(t, err) - - // All new sessions should be accepted - require.Equal( - t, - len(packets), - len(result.Output), - "all new session packets should be accepted", - ) - require.Empty(t, result.Drop, "no packets should be dropped") - - // Track new sessions and their selected reals - for i, outPacket := range result.Output { - session := packetToSession[i] - - // Extract the real IP from the output packet - realIP, ok := netip.AddrFromSlice(outPacket.DstIP) - require.True(t, ok, "failed to parse real IP") - session.realIP = realIP - - key := sessionKey( - session.clientIP, - session.clientPort, - session.vsIP, - session.vsPort, - ) - activeSessions[key] = session - } - - t.Logf( - "Created %d new sessions (N=%d, N/2=%d), total active: %d", - len(result.Output), - N, - newSessionCount, - len(activeSessions), - ) - }) - - // Note: Do NOT advance time as per requirements - }) - } - - // Final verification - t.Run("Final_Verification", func(t *testing.T) { - currentTime := ts.Mock.CurrentTime() - - // Sync one more time - err := ts.Balancer.Refresh(currentTime) - require.NoError(t, err) - - // Get sessions info - sessions, err := ts.Balancer.Sessions(currentTime) - require.NoError(t, err) - - t.Logf( - "Final state: %d active sessions tracked, %d sessions in balancer", - len(activeSessions), - len(sessions), - ) - - // Verify session count matches - assert.Equal(t, len(activeSessions), len(sessions), - "tracked sessions should match balancer sessions") - - // Get info - info, err := ts.Balancer.Info(currentTime) - require.NoError(t, err) - t.Logf("Module active sessions: %d", info.ActiveSessions) - - config := ts.Balancer.Config() - capacity := uint64(0) - if config.State != nil && config.State.SessionTableCapacity != nil { - capacity = *config.State.SessionTableCapacity - } - t.Logf("Session table capacity: %d", capacity) - - // Verify VS and Real active sessions sum up correctly - totalVsSessions := uint64(0) - for _, vsInfo := range info.Vs { - totalVsSessions += vsInfo.ActiveSessions - } - t.Logf("Total VS active sessions: %d", totalVsSessions) - - totalRealSessions := uint64(0) - for _, vsInfo := range info.Vs { - for _, realInfo := range vsInfo.Reals { - totalRealSessions += realInfo.ActiveSessions - } - } - t.Logf("Total Real active sessions: %d", totalRealSessions) - - // The total should match (each session belongs to one VS and one Real) - assert.Equal(t, totalVsSessions, totalRealSessions, - "VS and Real session counts should match") - }) -} diff --git a/modules/balancer/tests/go/st_update_reals_test.go b/modules/balancer/tests/go/st_update_reals_test.go deleted file mode 100644 index a83d69d8a..000000000 --- a/modules/balancer/tests/go/st_update_reals_test.go +++ /dev/null @@ -1 +0,0 @@ -package balancer_test diff --git a/modules/balancer/tests/go/update_test.go b/modules/balancer/tests/go/update_test.go new file mode 100644 index 000000000..5f7491c8f --- /dev/null +++ b/modules/balancer/tests/go/update_test.go @@ -0,0 +1,251 @@ +package balancer + +import ( + "fmt" + "math/rand/v2" + "testing" + + "github.com/c2h5oh/datasize" + "github.com/stretchr/testify/assert" + mock "github.com/yanet-platform/yanet2/mock/go" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" + "google.golang.org/protobuf/types/known/durationpb" +) + +type vsBuildParams struct { + Flags *balancerpb.VsFlags + minRealsCnt int + maxRealsCnt int + minAllowedSrcCnt int + maxAllowedSrcCnt int +} + +func buildVS( + cntNew int, + cntReuse int, + params vsBuildParams, + prevVS []*balancerpb.VirtualService, + rng *rand.Rand, +) ([]*balancerpb.VirtualService, []*balancerpb.VirtualService) { + result := make([]*balancerpb.VirtualService, 0, cntNew+cntReuse) + var reused []*balancerpb.VirtualService + if prevVS != nil && cntReuse > 0 { + reused = utils.SelectVS(cntReuse, prevVS, rng) + for _, vs := range reused { + result = append(result, utils.VSUpdateSomeReals(vs, rng)) + } + } + for range cntNew { + realsCnt := rng.IntN(params.maxRealsCnt-params.minRealsCnt+1) + params.minRealsCnt + allowedSrcCnt := rng.IntN( + params.maxAllowedSrcCnt-params.minAllowedSrcCnt+1, + ) + params.minAllowedSrcCnt + result = append(result, utils.GenerateVS(rng, realsCnt, allowedSrcCnt, params.Flags)) + } + if len(result) != cntNew+cntReuse { + panic( + fmt.Sprintf( + "buildVS: result count mismatch: %d != %d + %d", + len(result), + cntNew, + cntReuse, + ), + ) + } + return result, reused +} + +func buildInitialConfig( + rng *rand.Rand, + cntVS int, + vsParams vsBuildParams, +) *balancerpb.BalancerConfig { + stCapacity := uint64(50000) + stMaxLoadFactor := float32(1.0) + wlcPower := uint64(10) + maxWeight := uint32(100) + wlc := &balancerpb.WlcConfig{ + Power: &wlcPower, + MaxWeight: &maxWeight, + } + vs, _ := buildVS(cntVS, 0, vsParams, nil, rng) + return &balancerpb.BalancerConfig{ + PacketHandler: &balancerpb.PacketHandlerConfig{ + SourceAddressV4: utils.GenerateIPv4Address(rng).AsSlice(), + SourceAddressV6: utils.GenerateIPv6Address(rng).AsSlice(), + DecapAddresses: [][]byte{}, + SessionsTimeouts: &balancerpb.SessionsTimeouts{ + TcpSynAck: 60, + TcpSyn: 60, + TcpFin: 60, + Tcp: 60, + Udp: 60, + }, + Vs: vs, + }, + State: &balancerpb.StateConfig{ + SessionTableCapacity: &stCapacity, + SessionTableMaxLoadFactor: &stMaxLoadFactor, + Wlc: wlc, + RefreshPeriod: &durationpb.Duration{ + Nanos: 25 * 1000 * 1000, + }, + }, + } +} + +func buildTestSetup(balancerConfig *balancerpb.BalancerConfig) (*utils.TestSetup, error) { + testConfig := &utils.TestConfig{ + Mock: &mock.YanetMockConfig{ + AgentsMemory: 256 * datasize.MB, + DpMemory: 128 * datasize.MB, + Workers: 1, + Devices: []mock.YanetMockDeviceConfig{ + { + ID: 0, + Name: "device0", + }, + }, + }, + Balancer: balancerConfig, + AgentMemory: 128 * datasize.MB, + } + return utils.Make(testConfig) +} + +func updateReals(t *testing.T, ts *utils.TestSetup, rng *rand.Rand) { + config := ts.Balancer.Config() + vs := config.PacketHandler.Vs + selected := utils.SelectVS(len(vs)/2, vs, rng) + updates := utils.GenerateRealUpdates(selected, rng) + updated, err := ts.Balancer.UpdateReals(updates, false) + assert.NoError(t, err) + assert.Equal(t, len(updates), updated) +} + +// runUpdateRealsRound sends one packet per VS then repeats real-update + send iters times. +// After each send it verifies that per-VS incoming and outgoing packet counters increased. +func runUpdateRealsRound(t *testing.T, ts *utils.TestSetup, rng *rand.Rand, iters int) { + utils.SendAndValidateMany(t, ts, rng) + for range iters { + updateReals(t, ts, rng) + utils.SendAndValidateMany(t, ts, rng) + } +} + +func stepUpdateVS( + t *testing.T, + ts *utils.TestSetup, + vsParams vsBuildParams, + rng *rand.Rand, + iter int, + realsIters int, +) { + b := ts.Balancer + newCnt := rng.IntN(5) + prevCount := utils.VsCount(b) + reuseCnt := min(rng.IntN(5), prevCount) + t.Logf("iter=%d, UpdateVS: newCnt=%d, reuseCnt=%d", iter, newCnt, reuseCnt) + + newVSList, reusedVS := buildVS(newCnt, reuseCnt, vsParams, b.Config().PacketHandler.Vs, rng) + snapshots := utils.CaptureVsSnapshots(t, ts, reusedVS) + _, err := b.UpdateVS(newVSList) + assert.NoError(t, err, "iter=%d: failed to UpdateVS", iter) + assert.Equal( + t, + newCnt+prevCount, + utils.VsCount(b), + "iter=%d: vs count mismatch after UpdateVS", + iter, + ) + utils.EnableAllReals(t, ts) + utils.VerifyInheritedStats(t, ts, snapshots) + runUpdateRealsRound(t, ts, rng, realsIters) +} + +func stepDeleteVS( + t *testing.T, + ts *utils.TestSetup, + rng *rand.Rand, + iter int, + realsIters int, +) { + b := ts.Balancer + prevCount := utils.VsCount(b) + delCnt := max(rng.IntN(prevCount/4), 1) + t.Logf("iter=%d, DeleteVS: delCnt=%d, prevCount=%d", iter, delCnt, prevCount) + + allVS := b.Config().PacketHandler.Vs + snapshots := utils.CaptureVsSnapshots(t, ts, allVS) + _, err := b.DeleteVS(utils.SelectVS(delCnt, allVS, rng)) + assert.NoError(t, err, "iter=%d: failed to DeleteVS", iter) + assert.Equal( + t, + prevCount-delCnt, + utils.VsCount(b), + "iter=%d: vs count mismatch after DeleteVS", + iter, + ) + utils.VerifyInheritedStats(t, ts, snapshots) + runUpdateRealsRound(t, ts, rng, realsIters) +} + +func stepUpdateConfig( + t *testing.T, + ts *utils.TestSetup, + vsParams vsBuildParams, + rng *rand.Rand, + iter int, + realsIters int, +) { + b := ts.Balancer + newCnt := max(utils.VsCount(b)/2+2-rng.IntN(5), 10) + reuseCnt := max(utils.VsCount(b)/2+2-rng.IntN(5), 0) + t.Logf("iter=%d, Update: newCnt=%d, reuseCnt=%d", iter, newCnt, reuseCnt) + + config := b.Config() + newVSList, reusedVS := buildVS(newCnt, reuseCnt, vsParams, config.PacketHandler.Vs, rng) + config.PacketHandler.Vs = newVSList + snapshots := utils.CaptureVsSnapshots(t, ts, reusedVS) + _, err := b.Update(config, nil) + assert.NoError(t, err, "iter=%d: failed to Update config", iter) + assert.Equal( + t, + newCnt+reuseCnt, + utils.VsCount(b), + "iter=%d: vs count mismatch after Update", + iter, + ) + utils.EnableAllReals(t, ts) + utils.VerifyInheritedStats(t, ts, snapshots) + runUpdateRealsRound(t, ts, rng, realsIters) +} + +func TestUpdateStress(t *testing.T) { + rng := rand.New(rand.NewPCG(uint64(100), uint64(123))) + vsParams := vsBuildParams{ + Flags: &balancerpb.VsFlags{FixMss: true}, + minRealsCnt: 5, + maxRealsCnt: 15, + minAllowedSrcCnt: 1, + maxAllowedSrcCnt: 3, + } + + ts, err := buildTestSetup(buildInitialConfig(rng, 20, vsParams)) + if err != nil { + t.Fatalf("failed to make test setup: %v", err) + } + defer ts.Free() + + utils.EnableAllReals(t, ts) + runUpdateRealsRound(t, ts, rng, 20) + t.Log("initial vs count:", utils.VsCount(ts.Balancer)) + + for iter := range 15 { + stepUpdateVS(t, ts, vsParams, rng, iter, 20) + stepDeleteVS(t, ts, rng, iter, 20) + stepUpdateConfig(t, ts, vsParams, rng, iter, 20) + t.Logf("iter=%d done: vs_count=%d", iter, utils.VsCount(ts.Balancer)) + } +} diff --git a/modules/balancer/tests/go/update_vs_test.go b/modules/balancer/tests/go/update_vs_test.go deleted file mode 100644 index eb4763e31..000000000 --- a/modules/balancer/tests/go/update_vs_test.go +++ /dev/null @@ -1,960 +0,0 @@ -package balancer_test - -// TestUpdateVSAndDeleteVS provides comprehensive testing for UpdateVS and DeleteVS methods, -// verifying configuration updates, WLC index management, ACL reuse filtering, and idempotent operations. - -import ( - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "google.golang.org/protobuf/types/known/durationpb" -) - -// Test addresses -var ( - // Virtual Services - testVs1IP = netip.MustParseAddr("10.1.1.1") - testVs1Port = uint16(80) - testVs2IP = netip.MustParseAddr("10.1.2.1") - testVs2Port = uint16(80) - testVs3IP = netip.MustParseAddr("10.1.3.1") - testVs3Port = uint16(80) - testVs4IP = netip.MustParseAddr("10.1.4.1") - testVs4Port = uint16(80) - - // Real servers - testReal1IP = netip.MustParseAddr("192.168.1.1") - testReal2IP = netip.MustParseAddr("192.168.1.2") - testReal3IP = netip.MustParseAddr("192.168.1.3") - testReal4IP = netip.MustParseAddr("192.168.2.1") - testReal5IP = netip.MustParseAddr("192.168.2.2") - testReal6IP = netip.MustParseAddr("192.168.3.1") - testReal7IP = netip.MustParseAddr("192.168.3.2") - testReal8IP = netip.MustParseAddr("192.168.4.1") - testReal9IP = netip.MustParseAddr("192.168.4.2") -) - -// createTestReal creates a Real configuration -func createTestReal(ip netip.Addr, weight uint32) *balancerpb.Real { - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: 0, - }, - Weight: weight, - SrcAddr: &balancerpb.Addr{ - Bytes: ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255").AsSlice(), - }, - } -} - -// createTestVS creates a VirtualService configuration -func createTestVS( - ip netip.Addr, - port uint16, - wlc bool, - reals []*balancerpb.Real, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: wlc, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - } -} - -// createTestVSWithACL creates a VirtualService with specific ACL configuration -func createTestVSWithACL( - ip netip.Addr, - port uint16, - wlc bool, - reals []*balancerpb.Real, - allowedNets []*balancerpb.Net, -) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: allowedNets, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: wlc, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - } -} - -// createInitialTestConfig creates the initial balancer configuration -func createInitialTestConfig() *balancerpb.BalancerConfig { - return &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - createTestVS(testVs1IP, testVs1Port, true, []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - createTestReal(testReal2IP, 2), - }), - createTestVS(testVs2IP, testVs2Port, false, []*balancerpb.Real{ - createTestReal(testReal3IP, 1), - }), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } -} - -// findVSInConfig finds a VS in config by IP address -func findVSInConfig( - config *balancerpb.BalancerConfig, - vsIP netip.Addr, -) *balancerpb.VirtualService { - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - if addr == vsIP { - return vs - } - } - return nil -} - -// verifyWLCConfig verifies that WLC configuration is correctly set for specified VSs -func verifyWLCConfig( - t *testing.T, - config *balancerpb.BalancerConfig, - expectedWLCVSs []netip.Addr, -) { - t.Helper() - - require.NotNil(t, config.State, "State config should not be nil") - require.NotNil(t, config.State.Wlc, "WLC config should not be nil") - - // Build map of expected WLC VSs - expectedWLC := make(map[netip.Addr]bool) - for _, vsIP := range expectedWLCVSs { - expectedWLC[vsIP] = true - } - - // Verify each VS has correct WLC flag - for _, vs := range config.PacketHandler.Vs { - addr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - require.NotNil(t, vs.Flags, "VS flags should not be nil for %s", addr) - - if expectedWLC[addr] { - assert.True(t, vs.Flags.Wlc, "VS %s should have WLC enabled", addr) - } else { - assert.False(t, vs.Flags.Wlc, "VS %s should have WLC disabled", addr) - } - } -} - -// TestUpdateVSBasicOperations tests basic UpdateVS operations -func TestUpdateVSBasicOperations(t *testing.T) { - config := createInitialTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("AddNewVS", func(t *testing.T) { - // Add VS3 with WLC enabled - newVS := createTestVS(testVs3IP, testVs3Port, true, []*balancerpb.Real{ - createTestReal(testReal4IP, 1), - createTestReal(testReal5IP, 1), - }) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{newVS}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Verify config - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 3, - len(updatedConfig.PacketHandler.Vs), - "should have 3 VS", - ) - - // Verify VS3 exists - vs3 := findVSInConfig(updatedConfig, testVs3IP) - require.NotNil(t, vs3, "VS3 should exist") - assert.True(t, vs3.Flags.Wlc, "VS3 should have WLC enabled") - assert.Equal(t, 2, len(vs3.Reals), "VS3 should have 2 reals") - - // Verify WLC: VS1 and VS3 should have WLC enabled - verifyWLCConfig(t, updatedConfig, []netip.Addr{testVs1IP, testVs3IP}) - }) - - t.Run("UpdateExistingVS", func(t *testing.T) { - // Update VS1: change from WLC=true to WLC=false - updatedVS1 := createTestVS( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal6IP, 1), - }, - ) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS1}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Verify config - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 3, - len(updatedConfig.PacketHandler.Vs), - "should still have 3 VS", - ) - - // Verify VS1 updated - vs1 := findVSInConfig(updatedConfig, testVs1IP) - require.NotNil(t, vs1, "VS1 should exist") - assert.False(t, vs1.Flags.Wlc, "VS1 should have WLC disabled") - assert.Equal(t, 1, len(vs1.Reals), "VS1 should have 1 real") - - // Verify WLC: only VS3 should have WLC enabled now - verifyWLCConfig(t, updatedConfig, []netip.Addr{testVs3IP}) - }) - - t.Run("UpdateMultipleVS", func(t *testing.T) { - // Update VS2 to enable WLC and add VS4 with WLC - updatedVS2 := createTestVS( - testVs2IP, - testVs2Port, - true, - []*balancerpb.Real{ - createTestReal(testReal7IP, 2), - createTestReal(testReal8IP, 1), - }, - ) - newVS4 := createTestVS(testVs4IP, testVs4Port, true, []*balancerpb.Real{ - createTestReal(testReal9IP, 1), - }) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS2, newVS4}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Verify config - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 4, - len(updatedConfig.PacketHandler.Vs), - "should have 4 VS", - ) - - // Verify VS2 updated - vs2 := findVSInConfig(updatedConfig, testVs2IP) - require.NotNil(t, vs2, "VS2 should exist") - assert.True(t, vs2.Flags.Wlc, "VS2 should have WLC enabled") - assert.Equal(t, 2, len(vs2.Reals), "VS2 should have 2 reals") - - // Verify VS4 added - vs4 := findVSInConfig(updatedConfig, testVs4IP) - require.NotNil(t, vs4, "VS4 should exist") - assert.True(t, vs4.Flags.Wlc, "VS4 should have WLC enabled") - - // Verify WLC: VS2, VS3, VS4 should have WLC enabled - verifyWLCConfig( - t, - updatedConfig, - []netip.Addr{testVs2IP, testVs3IP, testVs4IP}, - ) - }) -} - -// TestDeleteVSBasicOperations tests basic DeleteVS operations -func TestDeleteVSBasicOperations(t *testing.T) { - config := createInitialTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("DeleteSingleVS", func(t *testing.T) { - // Delete VS1 - vsToDelete := createTestVS(testVs1IP, testVs1Port, false, nil) - - updateInfo, err := ts.Balancer.DeleteVS( - []*balancerpb.VirtualService{vsToDelete}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Verify ACL reuse list is empty for delete - assert.Empty( - t, - updateInfo.ACLReusedVs, - "ACL reuse list should be empty for delete", - ) - - // Verify config - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 1, - len(updatedConfig.PacketHandler.Vs), - "should have 1 VS", - ) - - // Verify VS1 deleted - vs1 := findVSInConfig(updatedConfig, testVs1IP) - assert.Nil(t, vs1, "VS1 should be deleted") - - // Verify VS2 still exists - vs2 := findVSInConfig(updatedConfig, testVs2IP) - require.NotNil(t, vs2, "VS2 should still exist") - - // Verify WLC: no WLC-enabled VS remaining - verifyWLCConfig(t, updatedConfig, []netip.Addr{}) - }) - - t.Run("DeleteMultipleVS", func(t *testing.T) { - // Re-add VS1 and VS3 - vs1 := createTestVS(testVs1IP, testVs1Port, true, []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }) - vs3 := createTestVS(testVs3IP, testVs3Port, false, []*balancerpb.Real{ - createTestReal(testReal4IP, 1), - }) - - _, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{vs1, vs3}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Delete VS2 and VS3 - vsToDelete := []*balancerpb.VirtualService{ - createTestVS(testVs2IP, testVs2Port, false, nil), - createTestVS(testVs3IP, testVs3Port, false, nil), - } - - updateInfo, err := ts.Balancer.DeleteVS( - vsToDelete, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Verify config - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 1, - len(updatedConfig.PacketHandler.Vs), - "should have 1 VS", - ) - - // Verify only VS1 remains - vs1Found := findVSInConfig(updatedConfig, testVs1IP) - require.NotNil(t, vs1Found, "VS1 should exist") - - // Verify WLC: VS1 should have WLC enabled - verifyWLCConfig(t, updatedConfig, []netip.Addr{testVs1IP}) - }) - - t.Run("IdempotentDelete", func(t *testing.T) { - // Try to delete non-existent VS - vsToDelete := createTestVS(testVs4IP, testVs4Port, false, nil) - - updateInfo, err := ts.Balancer.DeleteVS( - []*balancerpb.VirtualService{vsToDelete}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err, "deleting non-existent VS should not error") - require.NotNil(t, updateInfo) - - // Verify config unchanged - updatedConfig := ts.Balancer.Config() - require.NotNil(t, updatedConfig) - assert.Equal( - t, - 1, - len(updatedConfig.PacketHandler.Vs), - "should still have 1 VS", - ) - }) -} - -// TestUpdateVSAndDeleteVSWorkflow tests complex workflows -func TestUpdateVSAndDeleteVSWorkflow(t *testing.T) { - config := createInitialTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - t.Run("ComplexWorkflow", func(t *testing.T) { - // Step 1: Add VS3 and VS4 - vs3 := createTestVS(testVs3IP, testVs3Port, true, []*balancerpb.Real{ - createTestReal(testReal4IP, 1), - }) - vs4 := createTestVS(testVs4IP, testVs4Port, false, []*balancerpb.Real{ - createTestReal(testReal5IP, 1), - }) - - _, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{vs3, vs4}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - config := ts.Balancer.Config() - assert.Equal(t, 4, len(config.PacketHandler.Vs), "should have 4 VS") - verifyWLCConfig(t, config, []netip.Addr{testVs1IP, testVs3IP}) - - // Step 2: Send packets to VS1 - clientIP := netip.MustParseAddr("3.3.3.1") - packetLayers := utils.MakeTCPPacket( - clientIP, - 1000, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output)) - - // Step 3: Delete VS2 - vsToDelete := createTestVS(testVs2IP, testVs2Port, false, nil) - _, err = ts.Balancer.DeleteVS( - []*balancerpb.VirtualService{vsToDelete}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - config = ts.Balancer.Config() - assert.Equal(t, 3, len(config.PacketHandler.Vs), "should have 3 VS") - verifyWLCConfig(t, config, []netip.Addr{testVs1IP, testVs3IP}) - - // Step 4: Update VS1 to disable WLC - updatedVS1 := createTestVS( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - createTestReal(testReal2IP, 1), - }, - ) - _, err = ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS1}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - config = ts.Balancer.Config() - verifyWLCConfig(t, config, []netip.Addr{testVs3IP}) - - // Step 5: Delete all remaining VS - allVS := []*balancerpb.VirtualService{ - createTestVS(testVs1IP, testVs1Port, false, nil), - createTestVS(testVs3IP, testVs3Port, false, nil), - createTestVS(testVs4IP, testVs4Port, false, nil), - } - _, err = ts.Balancer.DeleteVS(allVS, ts.Mock.CurrentTime()) - require.NoError(t, err) - - config = ts.Balancer.Config() - assert.Equal(t, 0, len(config.PacketHandler.Vs), "should have 0 VS") - verifyWLCConfig(t, config, []netip.Addr{}) - }) -} - -// TestWLCIndexRecalculation tests WLC index management during updates/deletes -func TestWLCIndexRecalculation(t *testing.T) { - config := createInitialTestConfig() - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - t.Run("WLCIndexRecalculationOnUpdate", func(t *testing.T) { - // Initial: VS1 (WLC=true, index 0), VS2 (WLC=false, index 1) - initialConfig := ts.Balancer.Config() - verifyWLCConfig(t, initialConfig, []netip.Addr{testVs1IP}) - - // Add VS3 with WLC=true - vs3 := createTestVS(testVs3IP, testVs3Port, true, []*balancerpb.Real{ - createTestReal(testReal4IP, 1), - }) - _, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{vs3}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Now: VS1 (index 0, WLC), VS2 (index 1, no WLC), VS3 (index 2, WLC) - config := ts.Balancer.Config() - verifyWLCConfig(t, config, []netip.Addr{testVs1IP, testVs3IP}) - - // Update VS2 to enable WLC - updatedVS2 := createTestVS( - testVs2IP, - testVs2Port, - true, - []*balancerpb.Real{ - createTestReal(testReal3IP, 1), - }, - ) - _, err = ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS2}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Now all 3 VS have WLC - config = ts.Balancer.Config() - verifyWLCConfig( - t, - config, - []netip.Addr{testVs1IP, testVs2IP, testVs3IP}, - ) - }) - - t.Run("WLCIndexRecalculationOnDelete", func(t *testing.T) { - // Current: VS1 (index 0, WLC), VS2 (index 1, WLC), VS3 (index 2, WLC) - - // Delete VS2 (middle VS with WLC) - vsToDelete := createTestVS(testVs2IP, testVs2Port, false, nil) - _, err := ts.Balancer.DeleteVS( - []*balancerpb.VirtualService{vsToDelete}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - - // Now: VS1 (index 0, WLC), VS3 (index 1, WLC) - indices shifted - config := ts.Balancer.Config() - verifyWLCConfig(t, config, []netip.Addr{testVs1IP, testVs3IP}) - - // Verify VS3 is now at index 1 - assert.Equal(t, 2, len(config.PacketHandler.Vs)) - vs3Addr, _ := netip.AddrFromSlice( - config.PacketHandler.Vs[1].Id.Addr.Bytes, - ) - assert.Equal(t, testVs3IP, vs3Addr, "VS3 should be at index 1") - }) -} - -// TestACLRebuildVerification tests that ACL filters are properly rebuilt on VS update -func TestACLRebuildVerification(t *testing.T) { - // Create initial config with VS1 allowing only 10.0.0.0/8 - initialConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - createTestVSWithACL( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }, - []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0").AsSlice(), - }, - }}, - ), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(1000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.8); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: initialConfig, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err) - defer ts.Free() - - // Enable all reals - utils.EnableAllReals(t, ts) - - t.Run("InitialACL_AllowsOnly10Network", func(t *testing.T) { - // Packet from 10.0.0.1 should be allowed - allowedClient := netip.MustParseAddr("10.0.0.1") - packetLayers := utils.MakeTCPPacket( - allowedClient, - 1000, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "packet from 10.0.0.1 should be allowed", - ) - assert.Empty(t, result.Drop, "packet should not be dropped") - - // Packet from 192.168.1.1 should be denied - deniedClient := netip.MustParseAddr("192.168.1.100") - packetLayers = utils.MakeTCPPacket( - deniedClient, - 1001, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Empty( - t, - result.Output, - "packet from 192.168.1.100 should be denied", - ) - assert.Equal(t, 1, len(result.Drop), "packet should be dropped") - }) - - t.Run("UpdateACL_AllowsOnly192Network", func(t *testing.T) { - // Update VS1 to allow only 192.168.0.0/16 - updatedVS1 := createTestVSWithACL( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }, - []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("192.168.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.0.0").AsSlice(), - }, - }}, - ) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS1}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Now packet from 10.0.0.1 should be denied - deniedClient := netip.MustParseAddr("10.0.0.1") - packetLayers := utils.MakeTCPPacket( - deniedClient, - 1002, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Empty( - t, - result.Output, - "packet from 10.0.0.1 should now be denied", - ) - assert.Equal(t, 1, len(result.Drop), "packet should be dropped") - - // Packet from 192.168.1.1 should now be allowed - allowedClient := netip.MustParseAddr("192.168.1.100") - packetLayers = utils.MakeTCPPacket( - allowedClient, - 1003, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "packet from 192.168.1.100 should now be allowed", - ) - assert.Empty(t, result.Drop, "packet should not be dropped") - }) - - t.Run("UpdateACL_AllowsAllNetworks", func(t *testing.T) { - // Update VS1 to allow all networks (0.0.0.0/0) - updatedVS1 := createTestVSWithACL( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }, - []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - ) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{updatedVS1}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Both packets should now be allowed - client1 := netip.MustParseAddr("10.0.0.1") - packetLayers := utils.MakeTCPPacket( - client1, - 1004, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "packet from 10.0.0.1 should be allowed", - ) - assert.Empty(t, result.Drop) - - client2 := netip.MustParseAddr("192.168.1.100") - packetLayers = utils.MakeTCPPacket( - client2, - 1005, - testVs1IP, - testVs1Port, - &layers.TCP{SYN: true}, - ) - packet = xpacket.LayersToPacket(t, packetLayers...) - result, err = ts.Mock.HandlePackets(packet) - require.NoError(t, err) - assert.Equal( - t, - 1, - len(result.Output), - "packet from 192.168.1.100 should be allowed", - ) - assert.Empty(t, result.Drop) - }) - - t.Run("VerifyACLReuseReporting", func(t *testing.T) { - // Update VS1 with same ACL - should report ACL reuse - sameACLVS := createTestVSWithACL( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }, - []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - ) - - updateInfo, err := ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{sameACLVS}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Should report ACL reuse for VS1 - assert.NotEmpty( - t, - updateInfo.ACLReusedVs, - "ACL should be reused when unchanged", - ) - - // Update VS1 with different ACL - should NOT report ACL reuse - differentACLVS := createTestVSWithACL( - testVs1IP, - testVs1Port, - false, - []*balancerpb.Real{ - createTestReal(testReal1IP, 1), - }, - []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("172.16.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.240.0.0").AsSlice(), - }, - }}, - ) - - updateInfo, err = ts.Balancer.UpdateVS( - []*balancerpb.VirtualService{differentACLVS}, - ts.Mock.CurrentTime(), - ) - require.NoError(t, err) - require.NotNil(t, updateInfo) - - // Should NOT report ACL reuse for VS1 - assert.Empty( - t, - updateInfo.ACLReusedVs, - "ACL should not be reused when changed", - ) - }) -} diff --git a/modules/balancer/tests/go/utils/config.go b/modules/balancer/tests/go/utils/config.go new file mode 100644 index 000000000..767a3f3cb --- /dev/null +++ b/modules/balancer/tests/go/utils/config.go @@ -0,0 +1,492 @@ +// Package utils provides utils for balancer tests. +package utils + +import ( + "math/rand" + "net/netip" + + "github.com/c2h5oh/datasize" + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "google.golang.org/protobuf/types/known/durationpb" +) + +// --------------------------------------------------------------------------- +// Default values +// --------------------------------------------------------------------------- + +var ( + DefaultSourceV4 = netip.MustParseAddr("100.0.0.1") + DefaultSourceV6 = netip.MustParseAddr("2001:db8::1") + + DefaultTimeouts = &balancerpb.SessionsTimeouts{ + TcpSynAck: 25, + TcpSyn: 20, + TcpFin: 15, + Tcp: 60, + Udp: 30, + } + + DefaultSessionCapacity uint64 = 20_000 + DefaultMaxLoadFactor float32 = 0.5 + DefaultRefreshPeriodSec int + DefaultWlcPower uint64 + DefaultWlcMaxWeight uint32 +) + +// --------------------------------------------------------------------------- +// ConfigBuilder +// --------------------------------------------------------------------------- + +type ConfigBuilder struct { + sourceV4 netip.Addr + sourceV6 netip.Addr + vs []*balancerpb.VirtualService + timeouts *balancerpb.SessionsTimeouts + sessionCapacity uint64 + maxLoadFactor float32 + refreshPeriodSec int + wlcPower uint64 + wlcMaxWeight uint32 +} + +func NewConfigBuilder() *ConfigBuilder { + return &ConfigBuilder{ + sourceV4: DefaultSourceV4, + sourceV6: DefaultSourceV6, + timeouts: DefaultTimeouts, + sessionCapacity: DefaultSessionCapacity, + maxLoadFactor: DefaultMaxLoadFactor, + refreshPeriodSec: DefaultRefreshPeriodSec, + wlcPower: DefaultWlcPower, + wlcMaxWeight: DefaultWlcMaxWeight, + } +} + +func (b *ConfigBuilder) WithSourceV4(ip string) *ConfigBuilder { + b.sourceV4 = netip.MustParseAddr(ip) + return b +} + +func (b *ConfigBuilder) WithSourceV6(ip string) *ConfigBuilder { + b.sourceV6 = netip.MustParseAddr(ip) + return b +} + +func (b *ConfigBuilder) AddVS(vs ...*balancerpb.VirtualService) *ConfigBuilder { + b.vs = append(b.vs, vs...) + return b +} + +func (b *ConfigBuilder) WithTimeouts(t *balancerpb.SessionsTimeouts) *ConfigBuilder { + b.timeouts = t + return b +} + +func (b *ConfigBuilder) WithSessionCapacity(capacity uint64) *ConfigBuilder { + b.sessionCapacity = capacity + return b +} + +func (b *ConfigBuilder) WithMaxLoadFactor(f float32) *ConfigBuilder { + b.maxLoadFactor = f + return b +} + +func (b *ConfigBuilder) WithRefreshPeriod(seconds int) *ConfigBuilder { + b.refreshPeriodSec = seconds + return b +} + +func (b *ConfigBuilder) WithWLC(power uint64, maxWeight uint32) *ConfigBuilder { + b.wlcPower = power + b.wlcMaxWeight = maxWeight + return b +} + +func (b *ConfigBuilder) Build() *balancerpb.BalancerConfig { + capacity := b.sessionCapacity + mlf := b.maxLoadFactor + wlcPower := b.wlcPower + wlcMaxWeight := b.wlcMaxWeight + + return &balancerpb.BalancerConfig{ + PacketHandler: &balancerpb.PacketHandlerConfig{ + Vs: b.vs, + SourceAddressV4: b.sourceV4.AsSlice(), + SourceAddressV6: b.sourceV6.AsSlice(), + DecapAddresses: [][]byte{}, + SessionsTimeouts: b.timeouts, + }, + State: &balancerpb.StateConfig{ + SessionTableCapacity: &capacity, + SessionTableMaxLoadFactor: &mlf, + Wlc: &balancerpb.WlcConfig{ + Power: &wlcPower, + MaxWeight: &wlcMaxWeight, + }, + RefreshPeriod: durationpb.New(0), + }, + } +} + +// --------------------------------------------------------------------------- +// VSBuilder +// --------------------------------------------------------------------------- + +type VSBuilder struct { + addr netip.Addr + port uint16 + proto balancerpb.TransportProto + scheduler balancerpb.VsScheduler + flags balancerpb.VsFlags + reals []*balancerpb.Real + allowed []*balancerpb.AllowedSources + peers [][]byte +} + +func NewTCPVS(addr string, port uint16) *VSBuilder { + return &VSBuilder{ + addr: netip.MustParseAddr(addr), + port: port, + proto: balancerpb.TransportProto_TCP, + } +} + +func NewUDPVS(addr string, port uint16) *VSBuilder { + return &VSBuilder{ + addr: netip.MustParseAddr(addr), + port: port, + proto: balancerpb.TransportProto_UDP, + } +} + +func (b *VSBuilder) WithScheduler(s balancerpb.VsScheduler) *VSBuilder { + b.scheduler = s + return b +} + +func (b *VSBuilder) OPS() *VSBuilder { + b.flags.Ops = true + return b +} + +func (b *VSBuilder) GRE() *VSBuilder { + b.flags.Gre = true + return b +} + +func (b *VSBuilder) FixMSS() *VSBuilder { + b.flags.FixMss = true + return b +} + +func (b *VSBuilder) PureL3() *VSBuilder { + b.flags.PureL3 = true + b.port = 0 + return b +} + +func (b *VSBuilder) WLC() *VSBuilder { + b.scheduler = balancerpb.VsScheduler_WLC + return b +} + +func (b *VSBuilder) AddReal(r ...*balancerpb.Real) *VSBuilder { + b.reals = append(b.reals, r...) + return b +} + +// AllowAll adds a permissive allowed source (0.0.0.0/0 or ::/0 based on VS IP version). +func (b *VSBuilder) AllowAll() *VSBuilder { + if b.addr.Is4() { + b.allowed = append(b.allowed, &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: make([]byte, 4), Mask: make([]byte, 4)}, + }, + }) + } else { + b.allowed = append(b.allowed, &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{ + {Addr: make([]byte, 16), Mask: make([]byte, 16)}, + }, + }) + } + return b +} + +func (b *VSBuilder) AddAllowedSrc(src *balancerpb.AllowedSources) *VSBuilder { + b.allowed = append(b.allowed, src) + return b +} + +func (b *VSBuilder) AddPeers(peers ...string) *VSBuilder { + for _, p := range peers { + addr := netip.MustParseAddr(p) + b.peers = append(b.peers, addr.AsSlice()) + } + return b +} + +func (b *VSBuilder) Build() *balancerpb.VirtualService { + return &balancerpb.VirtualService{ + Id: &balancerpb.VsIdentifier{ + Addr: b.addr.AsSlice(), + Port: uint32(b.port), + Proto: b.proto, + }, + Scheduler: b.scheduler, + Flags: &b.flags, + Reals: b.reals, + AllowedSrcs: b.allowed, + Peers: b.peers, + } +} + +// --------------------------------------------------------------------------- +// Real constructors +// --------------------------------------------------------------------------- + +// R creates a real with weight 1 and a full mask (preserves entire src address). +func R(addr string) *balancerpb.Real { + return RW(addr, 1) +} + +// RW creates a real with the given weight and a full mask. +func RW(addr string, weight uint32) *balancerpb.Real { + ip := netip.MustParseAddr(addr) + var fullMask []byte + if ip.Is4() { + fullMask = []byte{255, 255, 255, 255} + } else { + fullMask = []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + } + } + return &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{Ip: ip.AsSlice()}, + Weight: weight, + Src: &filterpb.IPNet{ + Addr: ip.AsSlice(), + Mask: fullMask, + }, + } +} + +// RealWithSrc creates a real with custom source address and mask. +func RealWithSrc(addr string, weight uint32, srcAddr, srcMask string) *balancerpb.Real { + ip := netip.MustParseAddr(addr) + src := netip.MustParseAddr(srcAddr) + mask := netip.MustParseAddr(srcMask) + return &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{Ip: ip.AsSlice()}, + Weight: weight, + Src: &filterpb.IPNet{ + Addr: src.AsSlice(), + Mask: mask.AsSlice(), + }, + } +} + +// --------------------------------------------------------------------------- +// Generators +// --------------------------------------------------------------------------- + +// GenerateReals generates count random real servers with random weights. +func GenerateReals(count int, rng *rand.Rand) []*balancerpb.Real { + reals := make([]*balancerpb.Real, count) + for i := range count { + ip := generateRealIP(i, rng.Intn(2) == 0) + weight := uint32(rng.Intn(100)) + 1 + reals[i] = RW(ip.String(), weight) + } + return reals +} + +// GenerateVSList generates count virtual services with realsPerVS reals each. +func GenerateVSList(count, realsPerVS int, rng *rand.Rand) []*balancerpb.VirtualService { + vsList := make([]*balancerpb.VirtualService, count) + for i := range count { + isV6 := rng.Intn(2) == 0 + isTCP := rng.Intn(2) == 0 + port := uint16(1000 + i) + + var addr netip.Addr + if isV6 { + addr = generateV6Addr(rng) + } else { + addr = generateV4Addr(10, rng) + } + + var b *VSBuilder + if isTCP { + b = NewTCPVS(addr.String(), port) + } else { + b = NewUDPVS(addr.String(), port) + } + + // Random scheduler + if rng.Intn(2) == 0 { + b.WithScheduler(balancerpb.VsScheduler_WRR) + } + + // Random flags + if rng.Intn(4) == 0 { + b.GRE() + } + if rng.Intn(4) == 0 { + b.OPS() + } + + b.AllowAll() + + reals := make([]*balancerpb.Real, realsPerVS) + for j := range realsPerVS { + realIsV6 := rng.Intn(2) == 0 + realIP := generateRealIP(i*realsPerVS+j, realIsV6) + weight := uint32(rng.Intn(100)) + 1 + reals[j] = RW(realIP.String(), weight) + } + b.AddReal(reals...) + + vsList[i] = b.Build() + } + return vsList +} + +func generateV4Addr(prefix byte, rng *rand.Rand) netip.Addr { + return netip.AddrFrom4([4]byte{ + prefix, + byte(rng.Intn(256)), + byte(rng.Intn(256)), + byte(rng.Intn(254)) + 1, + }) +} + +func generateV6Addr(rng *rand.Rand) netip.Addr { + var b [16]byte + b[0] = 0x20 + b[1] = 0x01 + b[2] = 0x0d + b[3] = 0xb8 + for i := 4; i < 16; i++ { + b[i] = byte(rng.Intn(256)) + } + if b[15] == 0 { + b[15] = 1 + } + return netip.AddrFrom16(b) +} + +func generateRealIP(index int, isV6 bool) netip.Addr { + if isV6 { + var b [16]byte + b[0] = 0xfd + b[1] = 0x00 + b[14] = byte(index >> 8) + b[15] = byte(index&0xff) + 1 + return netip.AddrFrom16(b) + } + return netip.AddrFrom4([4]byte{ + 192, + 168, + byte(index >> 8), + byte(index&0xff) + 1, + }) +} + +// --------------------------------------------------------------------------- +// Helpers for building RealUpdate +// --------------------------------------------------------------------------- + +// EnableReal creates a RealUpdate that enables a real. +func EnableReal( + vsID *balancerpb.VsIdentifier, + realID *balancerpb.RelativeRealIdentifier, +) *balancerpb.RealUpdate { + enable := true + return &balancerpb.RealUpdate{ + RealId: &balancerpb.RealIdentifier{ + Vs: vsID, + Real: realID, + }, + Enable: &enable, + } +} + +// DisableReal creates a RealUpdate that disables a real. +func DisableReal( + vsID *balancerpb.VsIdentifier, + realID *balancerpb.RelativeRealIdentifier, +) *balancerpb.RealUpdate { + enable := false + return &balancerpb.RealUpdate{ + RealId: &balancerpb.RealIdentifier{ + Vs: vsID, + Real: realID, + }, + Enable: &enable, + } +} + +// SetWeight creates a RealUpdate that changes a real's weight. +func SetWeight( + vsID *balancerpb.VsIdentifier, + realID *balancerpb.RelativeRealIdentifier, + weight uint32, +) *balancerpb.RealUpdate { + return &balancerpb.RealUpdate{ + RealId: &balancerpb.RealIdentifier{ + Vs: vsID, + Real: realID, + }, + Weight: &weight, + } +} + +// --------------------------------------------------------------------------- +// Convenience: QuickConfig +// --------------------------------------------------------------------------- + +// QuickConfig creates a BalancerConfig from virtual services with all defaults. +func QuickConfig(vs ...*balancerpb.VirtualService) *balancerpb.BalancerConfig { + return NewConfigBuilder().AddVS(vs...).Build() +} + +// QuickTestSetup is a shorthand for creating a single-worker test setup with sensible memory defaults. +func QuickTestSetup(config *balancerpb.BalancerConfig) *TestConfig { + return &TestConfig{ + Mock: SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), + Balancer: config, + } +} + +// --------------------------------------------------------------------------- +// Address helpers +// --------------------------------------------------------------------------- + +// Addr parses an IP address string, panicking on failure. +func Addr(s string) netip.Addr { + return netip.MustParseAddr(s) +} + +// AddrBytes returns the raw bytes of a parsed IP address. +func AddrBytes(s string) []byte { + return netip.MustParseAddr(s).AsSlice() +} + +// IPNet creates a filterpb.IPNet from address and mask strings. +func IPNet(addr, mask string) *filterpb.IPNet { + return &filterpb.IPNet{ + Addr: AddrBytes(addr), + Mask: AddrBytes(mask), + } +} + +// AllowedNet creates an AllowedSources entry from a network prefix. +func AllowedNet(addr, mask string) *balancerpb.AllowedSources { + return &balancerpb.AllowedSources{ + Nets: []*filterpb.IPNet{IPNet(addr, mask)}, + } +} diff --git a/modules/balancer/tests/go/utils/gen.go b/modules/balancer/tests/go/utils/gen.go new file mode 100644 index 000000000..9b871cd42 --- /dev/null +++ b/modules/balancer/tests/go/utils/gen.go @@ -0,0 +1,271 @@ +package utils + +import ( + "math/rand/v2" + "net/netip" + + "github.com/yanet-platform/yanet2/common/filterpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +func generateContinuousNet(rng *rand.Rand, len int) ([]byte, []byte) { + addr := make([]byte, len) + mask := make([]byte, len) + for i := range len { + addr[i] = byte(rng.IntN(256)) + } + prefixLen := rng.IntN(len*8) + 1 + for i := range prefixLen { + mask[i/8] |= byte(1 << (7 - i%8)) + } + return addr, mask +} + +// GenerateIPv4Net generates a random contiguous IPv4 network. +func GenerateIPv4Net(rng *rand.Rand) *filterpb.IPNet { + addr, mask := generateContinuousNet(rng, 4) + for i := range 4 { + addr[i] = 0 + mask[i] = 0 + } + return &filterpb.IPNet{ + Addr: addr, + Mask: mask, + } +} + +// GenerateIPv6Net generates a random continuously IPv6 network with hole in the middle. +func GenerateIPv6Net(rng *rand.Rand) *filterpb.IPNet { + addr1, mask1 := generateContinuousNet(rng, 8) + addr2, mask2 := generateContinuousNet(rng, 8) + for i := range 8 { + addr1[i] = 0 + addr2[i] = 0 + mask1[i] = 0 + mask2[i] = 0 + } + return &filterpb.IPNet{ + Addr: append(addr1, addr2...), + Mask: append(mask1, mask2...), + } +} + +func GenerateNet(rng *rand.Rand) *filterpb.IPNet { + if rng.IntN(2) == 0 { + return GenerateIPv4Net(rng) + } + return GenerateIPv6Net(rng) +} + +// GenerateAddressFromNet generates a random address from a network. +func GenerateAddressFromNet(rng *rand.Rand, net *filterpb.IPNet) netip.Addr { + addr := net.Addr + mask := net.Mask + res := make([]byte, len(addr)) + for i := range len(addr) { + res[i] = (addr[i] & mask[i]) | (byte(rng.IntN(256)) & ^mask[i]) + } + resAddr, ok := netip.AddrFromSlice(res) + if !ok { + panic("failed to convert slice to address") + } + return resAddr +} + +func GeneratePortRange(rng *rand.Rand) *filterpb.PortRange { + base := 10050 + from := rng.IntN(65536-base) + base + to := rng.IntN(65535-base) + base + if from > to { + from, to = to, from + } + return &filterpb.PortRange{ + From: uint32(from), + To: uint32(to), + } +} + +func GenerateAllowedSources( + rng *rand.Rand, + netCount int, + portCount int, + ipv6 bool, +) *balancerpb.AllowedSources { + nets := make([]*filterpb.IPNet, netCount) + for i := range nets { + if ipv6 { + nets[i] = GenerateIPv6Net(rng) + } else { + nets[i] = GenerateIPv4Net(rng) + } + } + ports := make([]*filterpb.PortRange, portCount) + for i := range ports { + ports[i] = GeneratePortRange(rng) + } + return &balancerpb.AllowedSources{ + Nets: nets, + Ports: ports, + } +} + +func GenerateIPv4Address(rng *rand.Rand) netip.Addr { + addr := [4]byte{} + for i := range 4 { + addr[i] = byte(rng.IntN(256)) + } + return netip.AddrFrom4(addr) +} + +func GenerateIPv6Address(rng *rand.Rand) netip.Addr { + addr := [16]byte{} + for i := range 16 { + addr[i] = byte(rng.IntN(256)) + } + return netip.AddrFrom16(addr) +} + +func GenerateAddress(rng *rand.Rand) netip.Addr { + if rng.IntN(2) == 0 { + return GenerateIPv4Address(rng) + } + return GenerateIPv6Address(rng) +} + +func GenerateReal(rng *rand.Rand) *balancerpb.Real { + ip := GenerateAddress(rng) + var src *filterpb.IPNet + if ip.Is4() { + src = &filterpb.IPNet{ + Addr: GenerateIPv4Address(rng).AsSlice(), + Mask: GenerateIPv4Address(rng).AsSlice(), + } + } else { + src = &filterpb.IPNet{ + Addr: GenerateIPv6Address(rng).AsSlice(), + Mask: GenerateIPv6Address(rng).AsSlice(), + } + } + return &balancerpb.Real{ + Id: &balancerpb.RelativeRealIdentifier{ + Ip: ip.AsSlice(), + Port: 0, + }, + Weight: uint32(rng.IntN(10) + 1), + Src: src, + } +} + +func GenerateSrcFromAllowed( + rng *rand.Rand, + allowedSource *balancerpb.AllowedSources, +) (netip.Addr, uint16) { + netID := rng.IntN(len(allowedSource.Nets)) + net := allowedSource.Nets[netID] + addr := GenerateAddressFromNet(rng, net) + portID := rng.IntN(len(allowedSource.Ports)) + from := allowedSource.Ports[portID].From + to := allowedSource.Ports[portID].To + port := uint16(rng.IntN(int(to-from+1)) + int(from)) + return addr, port +} + +func GenerateAllowedSrcForVS(rng *rand.Rand, vs *balancerpb.VirtualService) (netip.Addr, uint16) { + allowedSrcs := vs.AllowedSrcs + allowedSrcID := rng.IntN(len(allowedSrcs)) + allowedSrc := allowedSrcs[allowedSrcID] + return GenerateSrcFromAllowed(rng, allowedSrc) +} + +func GenerateVS( + rng *rand.Rand, + realsCount int, + allowedSrcCount int, + flags *balancerpb.VsFlags, +) *balancerpb.VirtualService { + reals := make([]*balancerpb.Real, realsCount) + for i := range realsCount { + reals[i] = GenerateReal(rng) + } + vsAddr := GenerateAddress(rng) + allowedSrcs := make([]*balancerpb.AllowedSources, allowedSrcCount) + for i := range allowedSrcCount { + allowedSrcs[i] = GenerateAllowedSources( + rng, + rand.IntN(10)+1, + rand.IntN(10)+1, + !vsAddr.Is4(), + ) + } + proto := balancerpb.TransportProto_TCP + if rng.IntN(2) == 0 { + proto = balancerpb.TransportProto_UDP + } + minPort := 10005 + return &balancerpb.VirtualService{ + Id: &balancerpb.VsIdentifier{ + Addr: vsAddr.AsSlice(), + Port: uint32(rng.IntN(65535-minPort) + minPort), + Proto: proto, + }, + AllowedSrcs: allowedSrcs, + Reals: reals, + Flags: flags, + Scheduler: balancerpb.VsScheduler_WRR, + } +} + +func VSGenerateRealUpdates(vs *balancerpb.VirtualService, rng *rand.Rand) []*balancerpb.RealUpdate { + reals := vs.Reals + updates := make([]*balancerpb.RealUpdate, 0, 2*len(reals)/3) + for _, r := range reals { + var weight *uint32 + switch rng.IntN(3) { + case 0: + val := uint32(rng.IntN(10) + 1) + weight = &val + case 1: + continue + } + if rng.IntN(2) == 0 { + val := uint32(rng.IntN(10) + 1) + weight = &val + } + var enable *bool + switch rng.IntN(3) { + case 0: + val := false + enable = &val + case 1: + val := true + enable = &val + } + updates = append(updates, &balancerpb.RealUpdate{ + RealId: &balancerpb.RealIdentifier{ + Vs: vs.Id, + Real: r.Id, + }, + Enable: enable, + Weight: weight, + }) + } + + // Enable some real + + if len(updates) > 0 { + idx := rng.IntN(len(updates)) + enable := true + updates[idx].Enable = &enable + } else { + enable := true + updates = append(updates, &balancerpb.RealUpdate{ + RealId: &balancerpb.RealIdentifier{ + Vs: vs.Id, + Real: vs.Reals[0].Id, + }, + Enable: &enable, + }) + } + + return updates +} diff --git a/modules/balancer/tests/go/utils/icmp.go b/modules/balancer/tests/go/utils/icmp.go deleted file mode 100644 index 20284ce15..000000000 --- a/modules/balancer/tests/go/utils/icmp.go +++ /dev/null @@ -1,401 +0,0 @@ -// Package utils provides testing utilities for balancer tests. -package utils - -// This file implements ICMP packet creation and validation functions for testing -// ICMP echo requests, destination unreachable messages, and broadcast functionality. - -import ( - "encoding/binary" - "net" - "net/netip" - "testing" - - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/tests/functional/framework" -) - -// ICMPBroadcastIdent is the magic value used to mark broadcasted packets -// This must match the value in modules/balancer/dataplane/icmp/error/broadcast.h -const ICMPBroadcastIdent uint16 = 0x0BDC - -// MakeICMPv4EchoRequest creates an ICMPv4 Echo Request packet -func MakeICMPv4EchoRequest( - srcIP netip.Addr, - dstIP netip.Addr, - id uint16, - seq uint16, -) []gopacket.SerializableLayer { - src := net.IP(srcIP.AsSlice()) - dst := net.IP(dstIP.AsSlice()) - - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := &layers.IPv4{ - Version: 4, - IHL: 5, - TTL: 64, - Protocol: layers.IPProtocolICMPv4, - SrcIP: src, - DstIP: dst, - } - - icmp := &layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), - Id: id, - Seq: seq, - } - - payload := []byte("ICMP Echo Request Payload") - - return []gopacket.SerializableLayer{ - eth, - ip, - icmp, - gopacket.Payload(payload), - } -} - -// MakeICMPv6EchoRequest creates an ICMPv6 Echo Request packet -func MakeICMPv6EchoRequest( - srcIP netip.Addr, - dstIP netip.Addr, - id uint16, - seq uint16, -) []gopacket.SerializableLayer { - src := net.IP(srcIP.AsSlice()) - dst := net.IP(dstIP.AsSlice()) - - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv6, - } - - ip := &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolICMPv6, - HopLimit: 64, - SrcIP: src, - DstIP: dst, - } - - icmp := &layers.ICMPv6{ - TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), - } - _ = icmp.SetNetworkLayerForChecksum(ip) - - // ICMPv6 Echo uses the same ID/Seq format as ICMPv4 - payload := make([]byte, 4+len("ICMP Echo Request Payload")) - payload[0] = byte(id >> 8) - payload[1] = byte(id) - payload[2] = byte(seq >> 8) - payload[3] = byte(seq) - copy(payload[4:], []byte("ICMP Echo Request Payload")) - - return []gopacket.SerializableLayer{ - eth, - ip, - icmp, - gopacket.Payload(payload), - } -} - -// MakeICMPv4DestUnreachable creates an ICMPv4 Destination Unreachable error packet -// containing the original packet that triggered the error -func MakeICMPv4DestUnreachable( - srcIP netip.Addr, - dstIP netip.Addr, - originalPacket gopacket.Packet, -) []gopacket.SerializableLayer { - return MakeICMPv4DestUnreachableWithIdent(srcIP, dstIP, originalPacket, 0) -} - -// MakeICMPv6DestUnreachable creates an ICMPv6 Destination Unreachable error packet -// containing the original packet that triggered the error -func MakeICMPv6DestUnreachable( - srcIP netip.Addr, - dstIP netip.Addr, - originalPacket gopacket.Packet, -) []gopacket.SerializableLayer { - return MakeICMPv6DestUnreachableWithIdent(srcIP, dstIP, originalPacket, 0) -} - -// MakeICMPv4DestUnreachableWithIdent creates an ICMPv4 Destination Unreachable -// error packet with a custom icmp_ident value -func MakeICMPv4DestUnreachableWithIdent( - srcIP netip.Addr, - dstIP netip.Addr, - originalPacket gopacket.Packet, - icmpIdent uint16, -) []gopacket.SerializableLayer { - src := net.IP(srcIP.AsSlice()) - dst := net.IP(dstIP.AsSlice()) - - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv4, - } - - ip := &layers.IPv4{ - Version: 4, - IHL: 5, - TTL: 64, - Protocol: layers.IPProtocolICMPv4, - SrcIP: src, - DstIP: dst, - } - - // For ICMP error messages, we need to manually construct the header - // because gopacket's ICMPv4 layer doesn't properly handle the unused field - // ICMP Error format: [type:1][code:1][checksum:2][unused:4][original packet...] - // We use the first 2 bytes of unused for our broadcast marker - - icmpType := uint8(layers.ICMPv4TypeDestinationUnreachable) - icmpCode := uint8(3) // Port unreachable - - // Extract the original IP packet - originalData := originalPacket.Data() - ipStart := 14 // Ethernet header size - var originalIPPacket []byte - if ipStart < len(originalData) { - originalIPPacket = originalData[ipStart:] - } - - // Build the complete ICMP packet manually - // [type:1][code:1][checksum:2][unused_marker:2][unused_rest:2][original packet...] - icmpPacket := make([]byte, 8+len(originalIPPacket)) - icmpPacket[0] = icmpType - icmpPacket[1] = icmpCode - // checksum at [2:4] will be calculated later - binary.BigEndian.PutUint16(icmpPacket[4:6], icmpIdent) // Our marker - // bytes [6:8] remain zero (rest of unused field) - copy(icmpPacket[8:], originalIPPacket) - - // Calculate checksum - checksum := uint32(0) - for i := 0; i < len(icmpPacket); i += 2 { - if i+1 < len(icmpPacket) { - checksum += uint32(icmpPacket[i])<<8 | uint32(icmpPacket[i+1]) - } else { - checksum += uint32(icmpPacket[i]) << 8 - } - } - for checksum > 0xffff { - checksum = (checksum & 0xffff) + (checksum >> 16) - } - binary.BigEndian.PutUint16(icmpPacket[2:4], ^uint16(checksum)) - - return []gopacket.SerializableLayer{ - eth, - ip, - gopacket.Payload(icmpPacket), - } -} - -// MakeICMPv6DestUnreachableWithIdent creates an ICMPv6 Destination Unreachable -// error packet with a custom icmp_ident value -func MakeICMPv6DestUnreachableWithIdent( - srcIP netip.Addr, - dstIP netip.Addr, - originalPacket gopacket.Packet, - icmpIdent uint16, -) []gopacket.SerializableLayer { - src := net.IP(srcIP.AsSlice()) - dst := net.IP(dstIP.AsSlice()) - - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv6, - } - - ip := &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolICMPv6, - HopLimit: 64, - SrcIP: src, - DstIP: dst, - } - - // For ICMPv6 error messages, manually construct the header - // ICMPv6 Error format: [type:1][code:1][checksum:2][unused:4][original packet...] - // We use the first 2 bytes of unused for our broadcast marker - - icmpType := uint8(layers.ICMPv6TypeDestinationUnreachable) - icmpCode := uint8(4) // Port unreachable - - // Extract the original IP packet - originalData := originalPacket.Data() - ipStart := 14 // Ethernet header size - var originalIPPacket []byte - if ipStart < len(originalData) { - originalIPPacket = originalData[ipStart:] - } - - // Build the complete ICMPv6 packet manually - // [type:1][code:1][checksum:2][unused_marker:2][unused_rest:2][original packet...] - icmpPacket := make([]byte, 8+len(originalIPPacket)) - icmpPacket[0] = icmpType - icmpPacket[1] = icmpCode - // checksum at [2:4] will be calculated later - binary.BigEndian.PutUint16(icmpPacket[4:6], icmpIdent) // Our marker - // bytes [6:8] remain zero (rest of unused field) - copy(icmpPacket[8:], originalIPPacket) - - // Calculate ICMPv6 checksum (includes pseudo-header) - // Pseudo-header: [src:16][dst:16][length:4][zeros:3][next_header:1] - pseudoHeader := make([]byte, 40) - copy(pseudoHeader[0:16], src) - copy(pseudoHeader[16:32], dst) - binary.BigEndian.PutUint32(pseudoHeader[32:36], uint32(len(icmpPacket))) - pseudoHeader[39] = uint8(layers.IPProtocolICMPv6) - - checksumData := append(pseudoHeader, icmpPacket...) - checksum := uint32(0) - for i := 0; i < len(checksumData); i += 2 { - if i+1 < len(checksumData) { - checksum += uint32(checksumData[i])<<8 | uint32(checksumData[i+1]) - } else { - checksum += uint32(checksumData[i]) << 8 - } - } - for checksum > 0xffff { - checksum = (checksum & 0xffff) + (checksum >> 16) - } - binary.BigEndian.PutUint16(icmpPacket[2:4], ^uint16(checksum)) - - return []gopacket.SerializableLayer{ - eth, - ip, - gopacket.Payload(icmpPacket), - } -} - -// MakeTunneledICMPv4DestUnreachable creates an IP-in-IP tunneled ICMPv4 -// Destination Unreachable packet with a custom icmp_ident -func MakeTunneledICMPv4DestUnreachable( - tunnelSrcIP netip.Addr, - tunnelDstIP netip.Addr, - icmpSrcIP netip.Addr, - icmpDstIP netip.Addr, - originalPacket gopacket.Packet, - icmpIdent uint16, -) []gopacket.SerializableLayer { - // Create the inner ICMP packet - innerLayers := MakeICMPv4DestUnreachableWithIdent( - icmpSrcIP, - icmpDstIP, - originalPacket, - icmpIdent, - ) - - // Serialize the inner packet (skip Ethernet header) - innerPacketLayers := innerLayers[1:] // Skip Ethernet - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - err := gopacket.SerializeLayers(buf, opts, innerPacketLayers...) - if err != nil { - panic(err) - } - innerPacketData := buf.Bytes() - - // Create outer tunnel headers - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv4, - } - - outerIP := &layers.IPv4{ - Version: 4, - IHL: 5, - TTL: 64, - Protocol: 4, // IPIP protocol - SrcIP: net.IP(tunnelSrcIP.AsSlice()), - DstIP: net.IP(tunnelDstIP.AsSlice()), - } - - return []gopacket.SerializableLayer{ - eth, - outerIP, - gopacket.Payload(innerPacketData), - } -} - -// MakeTunneledICMPv6DestUnreachable creates an IPv6-in-IPv6 tunneled ICMPv6 -// Destination Unreachable packet with a custom icmp_ident -func MakeTunneledICMPv6DestUnreachable( - tunnelSrcIP netip.Addr, - tunnelDstIP netip.Addr, - icmpSrcIP netip.Addr, - icmpDstIP netip.Addr, - originalPacket gopacket.Packet, - icmpIdent uint16, -) []gopacket.SerializableLayer { - // Create the inner ICMP packet - innerLayers := MakeICMPv6DestUnreachableWithIdent( - icmpSrcIP, - icmpDstIP, - originalPacket, - icmpIdent, - ) - - // Serialize the inner packet (skip Ethernet header) - innerPacketLayers := innerLayers[1:] // Skip Ethernet - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - err := gopacket.SerializeLayers(buf, opts, innerPacketLayers...) - if err != nil { - panic(err) - } - innerPacketData := buf.Bytes() - - // Create outer tunnel headers - eth := &layers.Ethernet{ - SrcMAC: net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - DstMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - EthernetType: layers.EthernetTypeIPv6, - } - - outerIP := &layers.IPv6{ - Version: 6, - NextHeader: 41, // IPv6 protocol - HopLimit: 64, - SrcIP: net.IP(tunnelSrcIP.AsSlice()), - DstIP: net.IP(tunnelDstIP.AsSlice()), - } - - return []gopacket.SerializableLayer{ - eth, - outerIP, - gopacket.Payload(innerPacketData), - } -} - -// VerifyBroadcastedICMPPacket checks that a broadcasted packet is properly -// tunneled and has the ICMP_BROADCAST_IDENT marker set -func VerifyBroadcastedICMPPacket( - t *testing.T, - packet *framework.PacketInfo, - expectedDstIP net.IP, -) { - t.Helper() - - // Verify packet is tunneled - require.True(t, packet.IsTunneled, "broadcasted packet should be tunneled") - - // Verify destination is a peer - require.Equal( - t, - expectedDstIP, - packet.DstIP, - "packet should be sent to peer", - ) -} diff --git a/modules/balancer/tests/go/utils/packet.go b/modules/balancer/tests/go/utils/packet.go index a54406a48..426a89952 100644 --- a/modules/balancer/tests/go/utils/packet.go +++ b/modules/balancer/tests/go/utils/packet.go @@ -1,28 +1,29 @@ package utils -// TCP and UDP packet creation utilities for balancer testing, supporting both IPv4 and IPv6 -// with MSS option manipulation for testing packet modification and encapsulation scenarios - import ( - "errors" + "cmp" "fmt" + "math/rand/v2" "net" "net/netip" + "testing" "github.com/gopacket/gopacket" "github.com/gopacket/gopacket/layers" + "github.com/yanet-platform/yanet2/common/go/xpacket" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "github.com/yanet-platform/yanet2/tests/functional/framework" ) -// MakeTCPPacket creates a TCP packet with the specified parameters. +// MakeTCPPacketLayers creates a TCP packet with the specified parameters. // Supports both IPv4 and IPv6. -func MakeTCPPacket( +func MakeTCPPacketLayers( srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16, tcp *layers.TCP, ) []gopacket.SerializableLayer { - // Ensure both addresses are the same IP version if srcIP.Is4() != dstIP.Is4() { panic(fmt.Sprintf("IP version mismatch: src=%v dst=%v", srcIP, dstIP)) } @@ -65,25 +66,22 @@ func MakeTCPPacket( _ = tcp.SetNetworkLayerForChecksum(ip) payload := []byte("BALANCER TEST PAYLOAD 12345678910") - packetLayers := []gopacket.SerializableLayer{ + return []gopacket.SerializableLayer{ eth, ip.(gopacket.SerializableLayer), tcp, gopacket.Payload(payload), } - - return packetLayers } -// MakeUDPPacket creates a UDP packet with the specified parameters. +// MakeUDPPacketLayers creates a UDP packet with the specified parameters. // Supports both IPv4 and IPv6. -func MakeUDPPacket( +func MakeUDPPacketLayers( srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16, ) []gopacket.SerializableLayer { - // Ensure both addresses are the same IP version if srcIP.Is4() != dstIP.Is4() { panic(fmt.Sprintf("IP version mismatch: src=%v dst=%v", srcIP, dstIP)) } @@ -127,18 +125,16 @@ func MakeUDPPacket( } _ = udp.SetNetworkLayerForChecksum(ip) - payload := []byte("PING TEST PAYLOAD 1234567890") - packetLayers := []gopacket.SerializableLayer{ + payload := []byte("BALANCER TEST PAYLOAD 12345678910") + return []gopacket.SerializableLayer{ eth, ip.(gopacket.SerializableLayer), udp, gopacket.Payload(payload), } - - return packetLayers } -// MakePacketLayers creates packet layers based on whether TCP or UDP is specified. +// MakePacketLayers creates TCP or UDP packet layers. // If tcp is nil, creates a UDP packet; otherwise creates a TCP packet. func MakePacketLayers( srcIP netip.Addr, @@ -148,138 +144,232 @@ func MakePacketLayers( tcp *layers.TCP, ) []gopacket.SerializableLayer { if tcp == nil { - return MakeUDPPacket(srcIP, srcPort, dstIP, dstPort) + return MakeUDPPacketLayers(srcIP, srcPort, dstIP, dstPort) } - return MakeTCPPacket(srcIP, srcPort, dstIP, dstPort, tcp) + return MakeTCPPacketLayers(srcIP, srcPort, dstIP, dstPort, tcp) } -// padTCPOptions pads TCP options to 4-byte boundary with NOPs -func padTCPOptions(opts []layers.TCPOption) ([]layers.TCPOption, error) { - // Compute current options length (bytes) - length := 0 - for _, o := range opts { - switch o.OptionType { - case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: - length += 1 - default: - if o.OptionLength == 0 { - return nil, errors.New("TCP option with zero length") - } - length += int(o.OptionLength) - } +type PacketInfo struct { + VsID VsID + ClientAddr netip.Addr + ClientPort uint16 + RealID RealID + + // Might be nil. + Packet *framework.PacketInfo +} + +func (pkt *PacketInfo) Compare(other *PacketInfo) int { + if cmp := pkt.VsID.Compare(&other.VsID); cmp != 0 { + return cmp } - if length > 40 { - return nil, fmt.Errorf("TCP options exceed 40 bytes (%d)", length) + if cmp := pkt.ClientAddr.Compare(other.ClientAddr); cmp != 0 { + return cmp } - // Pad with NOPs to 4-byte boundary - for (length % 4) != 0 { - opts = append( - opts, - layers.TCPOption{OptionType: layers.TCPOptionKindNop}, - ) - length++ + if cmp := cmp.Compare(pkt.ClientPort, other.ClientPort); cmp != 0 { + return cmp } - return opts, nil + return pkt.RealID.Compare(&other.RealID) } -// InsertOrUpdateMSS inserts or updates the MSS option in a TCP packet -func InsertOrUpdateMSS( - p gopacket.Packet, - newMSS uint16, -) (*gopacket.Packet, error) { - tcpL := p.Layer(layers.LayerTypeTCP) - if tcpL == nil { - return nil, errors.New("no TCP layer") +func (pkt *PacketInfo) String() string { + addrStr := pkt.ClientAddr.String() + if pkt.ClientAddr.Is6() { + addrStr = fmt.Sprintf("[%s]", addrStr) + } + return fmt.Sprintf("[%s -> %s -> %s]", addrStr, pkt.VsID.String(), pkt.RealID.String()) +} + +func PacketInfoFromSessionPb(s *balancerpb.Session) (PacketInfo, error) { + addr, ok := netip.AddrFromSlice(s.ClientAddr) + if !ok { + return PacketInfo{}, fmt.Errorf("invalid client address: %v", s.ClientAddr) + } + if s.ClientPort > 65535 { + return PacketInfo{}, fmt.Errorf("invalid client port: %d", s.ClientPort) + } + port := uint16(s.ClientPort) + return PacketInfo{ + VsID: VsIDFromPb(s.VsId), + ClientAddr: addr, + ClientPort: port, + RealID: RealIDFromPb(s.RealId), + }, nil +} + +func SendAndValidateTCP( + ts *TestSetup, + srcIP netip.Addr, + srcPort uint16, + dstIP netip.Addr, + dstPort uint16, + tcp *layers.TCP, +) (PacketInfo, error) { + pktLayers := MakeTCPPacketLayers(srcIP, srcPort, dstIP, dstPort, tcp) + pkt, err := xpacket.LayersToPacketChecked(pktLayers...) + if err != nil { + return PacketInfo{}, fmt.Errorf("failed to convert layers to packet: %w", err) + } + + result, err := ts.Mock.HandlePackets(pkt) + if err != nil { + return PacketInfo{}, fmt.Errorf("failed to handle packet: %w", err) } - ip4L := p.Layer(layers.LayerTypeIPv4) - ip6L := p.Layer(layers.LayerTypeIPv6) - if ip4L == nil && ip6L == nil { - return nil, errors.New("no IPv4/IPv6 layer") + + if len(result.Output) != 1 { + return PacketInfo{}, fmt.Errorf("expected 1 output packet, got %d", len(result.Output)) } - tcp := *tcpL.(*layers.TCP) - if !tcp.SYN { - return nil, errors.New("MSS option is only valid on SYN/SYN-ACK") + if len(result.Drop) != 0 { + return PacketInfo{}, fmt.Errorf("expected no drops, got %d", len(result.Drop)) } - // Update existing MSS or insert a new one - found := false - for i, o := range tcp.Options { - if o.OptionType == layers.TCPOptionKindMSS && o.OptionLength >= 4 { - tcp.Options[i].OptionData = []byte{byte(newMSS >> 8), byte(newMSS)} - found = true - break - } + return ValidatePacket(ts.Balancer.Config(), pkt, result.Output[0]) +} + +func SendAndValidateUDP( + ts *TestSetup, + srcIP netip.Addr, + srcPort uint16, + dstIP netip.Addr, + dstPort uint16, +) (PacketInfo, error) { + pktLayers := MakeUDPPacketLayers(srcIP, srcPort, dstIP, dstPort) + pkt, err := xpacket.LayersToPacketChecked(pktLayers...) + if err != nil { + return PacketInfo{}, fmt.Errorf("failed to convert layers to packet: %w", err) + } + + result, err := ts.Mock.HandlePackets(pkt) + if err != nil { + return PacketInfo{}, fmt.Errorf("failed to handle packet: %w", err) + } + + if len(result.Output) != 1 { + return PacketInfo{}, fmt.Errorf("expected 1 output packet, got %d", len(result.Output)) } - if !found { - mssOpt := layers.TCPOption{ - OptionType: layers.TCPOptionKindMSS, - OptionLength: 4, - OptionData: []byte{byte(newMSS >> 8), byte(newMSS)}, + + if len(result.Drop) != 0 { + return PacketInfo{}, fmt.Errorf("expected no drops, got %d", len(result.Drop)) + } + + return ValidatePacket(ts.Balancer.Config(), pkt, result.Output[0]) +} + +func SendAndValidate( + ts *TestSetup, + srcIP netip.Addr, + srcPort uint16, + dstIP netip.Addr, + dstPort uint16, + tcp *layers.TCP, +) (PacketInfo, error) { + if tcp == nil { + return SendAndValidateUDP(ts, srcIP, srcPort, dstIP, dstPort) + } + return SendAndValidateTCP(ts, srcIP, srcPort, dstIP, dstPort, tcp) +} + +func SendAndValidateRandomSrcPorts( + ts *TestSetup, + srcIP netip.Addr, + dstIP netip.Addr, + dstPort uint16, + tcp *layers.TCP, + count int, +) ([]PacketInfo, error) { + inputPackets := make([]gopacket.Packet, 0, count) + + for idx := range count { + var pktLayers []gopacket.SerializableLayer + if tcp == nil { + pktLayers = MakeUDPPacketLayers(srcIP, uint16(rand.Uint32()%60000+1024), dstIP, dstPort) + } else { + pktLayers = MakeTCPPacketLayers(srcIP, uint16(rand.Uint32()%60000+1024), dstIP, dstPort, tcp) + } + pkt, err := xpacket.LayersToPacketChecked(pktLayers...) + if err != nil { + return nil, fmt.Errorf("failed to convert layers to packet at index %d: %w", idx, err) } - // Conventionally MSS is first - tcp.Options = append([]layers.TCPOption{mssOpt}, tcp.Options...) + inputPackets = append(inputPackets, pkt) } - // Pad options and check size - var err error - tcp.Options, err = padTCPOptions(tcp.Options) + result, err := ts.Mock.HandlePackets(inputPackets...) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to handle packets: %w", err) } - var serLayers []gopacket.SerializableLayer + if len(result.Output) != count { + return nil, fmt.Errorf("expected %d output packets, got %d", count, len(result.Output)) + } - var netBeforeTCP gopacket.NetworkLayer + if len(result.Drop) != 0 { + return nil, fmt.Errorf("expected no drops, got %d", len(result.Drop)) + } - for _, l := range p.Layers() { - if l.LayerType() == layers.LayerTypeTCP { - break + outputPackets := make([]PacketInfo, 0, count) + for i, output := range result.Output { + pktInfo, err := ValidatePacket(ts.Balancer.Config(), inputPackets[i], output) + if err != nil { + return nil, fmt.Errorf("validation failed for packet %d: %w", i, err) } - if nl, ok := l.(gopacket.NetworkLayer); ok { - netBeforeTCP = nl + outputPackets = append(outputPackets, pktInfo) + } + + return outputPackets, nil +} + +func SendAndValidateMany(t *testing.T, ts *TestSetup, rng *rand.Rand) error { + config := ts.Balancer.Config() + for _, vs := range config.PacketHandler.Vs { + vsID := VsIDFromPb(vs.Id) + srcIP, srcPort := GenerateAllowedSrcForVS(rng, vs) + + dstIP, ok := netip.AddrFromSlice(vs.Id.Addr) + if !ok { + return fmt.Errorf("failed to parse destination IP for vs %s", vs.Id) } - if sl, ok := l.(gopacket.SerializableLayer); ok { - // Make a value-copy for common layers to avoid mutating the original packet - switch v := l.(type) { - case *layers.Ethernet: - c := *v - serLayers = append(serLayers, &c) - case *layers.Dot1Q: - c := *v - serLayers = append(serLayers, &c) - case *layers.IPv4: - c := *v - serLayers = append(serLayers, &c) - case *layers.IPv6: - c := *v - serLayers = append(serLayers, &c) - case *layers.IPv6HopByHop: - c := *v - serLayers = append(serLayers, &c) - case *layers.IPv6Fragment: - c := *v - serLayers = append(serLayers, &c) - case *layers.UDP: - c := *v - serLayers = append(serLayers, &c) - default: - // Fallback: use as-is (most gopacket layers are already SerializableLayer) - serLayers = append(serLayers, sl) + + dstPort := uint16(vs.Id.Port) + + var tcp *layers.TCP + if vs.Id.Proto == balancerpb.TransportProto_TCP { + tcp = &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + ACK: false, } } - } - _ = tcp.SetNetworkLayerForChecksum(netBeforeTCP) - serLayers = append(serLayers, &tcp) - serLayers = append(serLayers, gopacket.Payload(tcp.Payload)) + _, err := SendAndValidate(ts, srcIP, srcPort, dstIP, dstPort, tcp) + if err == nil { + continue + } + + resErr := fmt.Errorf("failed to send and validate packets for vs %s: %w", &vsID, err) + + state, err := ts.Balancer.GetState( + PacketHandlerRef(), + nil, + true, + ts.Mock.CurrentTime(), + ) + if err != nil { + return resErr + } - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - if err := gopacket.SerializeLayers(buf, opts, serLayers...); err != nil { - return nil, err + t.Logf("balancer common stats: %v", state[0].CommonStats) + t.Logf("balancer l4 stats: %v", state[0].L4Stats) + for _, vsState := range state[0].VirtualServices { + vsStateID := VsIDFromPb(vsState.Id) + if vsID.Compare(&vsStateID) == 0 { + t.Logf("vs stats: %v", vsState.Stats) + } + } + + return resErr } - out := buf.Bytes() - p2 := gopacket.NewPacket(out, layers.LayerTypeEthernet, gopacket.Default) - return &p2, nil + + return nil } diff --git a/modules/balancer/tests/go/utils/real.go b/modules/balancer/tests/go/utils/real.go new file mode 100644 index 000000000..46eca5844 --- /dev/null +++ b/modules/balancer/tests/go/utils/real.go @@ -0,0 +1,50 @@ +package utils + +import ( + "bytes" + "cmp" + "net/netip" + + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +type RealID struct { + addrLen int + addr [16]byte +} + +func (r *RealID) Compare(other *RealID) int { + if r.addrLen != other.addrLen { + return cmp.Compare(r.addrLen, other.addrLen) + } + return bytes.Compare(r.addr[:r.addrLen], other.addr[:other.addrLen]) +} + +func (r *RealID) String() string { + ip, _ := netip.AddrFromSlice(r.addr[:r.addrLen]) + return ip.String() +} + +func RealIDFromPb(r *balancerpb.RelativeRealIdentifier) RealID { + addr := [16]byte{} + copy(addr[:], r.Ip) + return RealID{ + addrLen: len(r.Ip), + addr: addr, + } +} + +func RealStatsEqual(a, b *balancerpb.RealStats) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.PacketsRealDisabled == b.PacketsRealDisabled && + a.OpsPackets == b.OpsPackets && + a.ErrorIcmpPackets == b.ErrorIcmpPackets && + a.CreatedSessions == b.CreatedSessions && + a.Packets == b.Packets && + a.Bytes == b.Bytes +} diff --git a/modules/balancer/tests/go/utils/setup.go b/modules/balancer/tests/go/utils/setup.go index 408a57bc0..103d77702 100644 --- a/modules/balancer/tests/go/utils/setup.go +++ b/modules/balancer/tests/go/utils/setup.go @@ -1,36 +1,35 @@ package utils -// Test setup utilities for creating balancer test environments with mock dataplane, -// configuring YANET infrastructure (devices, pipelines, functions), and managing -// test lifecycle including balancer agent and manager initialization. - import ( "fmt" "testing" "github.com/c2h5oh/datasize" - "github.com/yanet-platform/yanet2/common/go/logging" "github.com/yanet-platform/yanet2/controlplane/ffi" mock "github.com/yanet-platform/yanet2/mock/go" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "go.uber.org/zap/zapcore" + balancer "github.com/yanet-platform/yanet2/modules/balancer/controlplane" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" ) var ( - DeviceName string = "01:00.0" - PipelineName string = "pipeline0" - FunctionName string = "function0" - ChainName string = "chain0" - BalancerName string = "balancer0" + DeviceName = "01:00.0" + PipelineName = "pipeline0" + FunctionName = "function0" + ChainName = "chain0" + BalancerName = "balancer0" ) -//////////////////////////////////////////////////////////////////////////////// - type TestConfig struct { Mock *mock.YanetMockConfig Balancer *balancerpb.BalancerConfig - AgentMemory *datasize.ByteSize + AgentMemory datasize.ByteSize // 0 means default (4 MB) +} + +type TestSetup struct { + Mock *mock.YanetMock + Agent *balancer.Agent + Balancer *balancer.Balancer } func SingleWorkerMockConfig( @@ -50,145 +49,128 @@ func SingleWorkerMockConfig( } } -type TestSetup struct { - Mock *mock.YanetMock - Agent *balancer.BalancerAgent - Balancer *balancer.BalancerManager -} - func Make(config *TestConfig) (*TestSetup, error) { if config.Mock.AgentsMemory < 8*datasize.MB { return nil, fmt.Errorf("CP memory must be at least 8MB") } - mock, err := mock.NewYanetMock(config.Mock) + + m, err := mock.NewYanetMock(config.Mock) if err != nil { - return nil, fmt.Errorf("failed to create new mock: %v", err) + return nil, fmt.Errorf("create mock: %w", err) } - logLevel := zapcore.InfoLevel - sugaredLogger, _, _ := logging.Init(&logging.Config{ - Level: logLevel, - }) + agentMemory := 4 * datasize.MB - if config.AgentMemory != nil { - agentMemory = *config.AgentMemory + if config.AgentMemory != 0 { + agentMemory = config.AgentMemory } - agent, err := balancer.NewBalancerAgent( - mock.SharedMemory(), + + log := zap.NewNop().Sugar() + + agent, err := balancer.ReattachAgent( + m.SharedMemory(), + 0, agentMemory, - sugaredLogger, + log, ) if err != nil { - return nil, fmt.Errorf("failed to create new balancer agent: %v", err) - } - if err := agent.NewBalancerManager(BalancerName, config.Balancer); err != nil { - return nil, fmt.Errorf("failed to create new balancer manager: %v", err) + m.Free() + return nil, fmt.Errorf("attach balancer agent: %w", err) } - balancer, err := agent.BalancerManager(BalancerName) + + b, err := balancer.NewBalancer(agent, BalancerName, config.Balancer, log) if err != nil { - panic("failed to get balancer after successful creation") + m.Free() + return nil, fmt.Errorf("create balancer: %w", err) } - bootstrap, err := mock.SharedMemory().AgentReattach("bootstrap", 0, 1<<20) + bootstrap, err := m.SharedMemory().AgentReattach("bootstrap", 0, 1<<20) if err != nil { - return nil, fmt.Errorf("failed to attach to bootstrap agent: %v", err) + b.Destroy() + m.Free() + return nil, fmt.Errorf("attach bootstrap agent: %w", err) } if err := setupCp(bootstrap); err != nil { - return nil, fmt.Errorf("failed to setup controlplane: %v", err) + b.Destroy() + m.Free() + return nil, fmt.Errorf("setup controlplane: %w", err) } return &TestSetup{ - Mock: mock, + Mock: m, Agent: agent, - Balancer: balancer, + Balancer: b, }, nil } func setupCp(agent *ffi.Agent) error { - { - functionConfig := ffi.FunctionConfig{ - Name: FunctionName, - Chains: []ffi.FunctionChainConfig{ - { - Weight: 1, - Chain: ffi.ChainConfig{ - Name: ChainName, - Modules: []ffi.ChainModuleConfig{ - { - Type: "balancer", - Name: BalancerName, - }, + functionConfig := ffi.FunctionConfig{ + Name: FunctionName, + Chains: []ffi.FunctionChainConfig{ + { + Weight: 1, + Chain: ffi.ChainConfig{ + Name: ChainName, + Modules: []ffi.ChainModuleConfig{ + { + Type: "balancer", + Name: BalancerName, }, }, }, }, - } - - if err := agent.UpdateFunction(functionConfig); err != nil { - return fmt.Errorf("failed to update function: %w", err) - } + }, } - - // update pipelines - { - inputPipelineConfig := ffi.PipelineConfig{ - Name: PipelineName, - Functions: []string{FunctionName}, - } - - dummyPipelineConfig := ffi.PipelineConfig{ - Name: "dummy", - Functions: []string{}, - } - - if err := agent.UpdatePipeline(inputPipelineConfig); err != nil { - return fmt.Errorf("failed to update pipeline: %w", err) - } - - if err := agent.UpdatePipeline(dummyPipelineConfig); err != nil { - return fmt.Errorf("failed to update pipeline: %w", err) - } + if err := agent.UpdateFunction(functionConfig); err != nil { + return fmt.Errorf("update function: %w", err) } - // update devices - { - deviceConfig := ffi.DeviceConfig{ - Name: DeviceName, - Input: []ffi.DevicePipelineConfig{ - { - Name: PipelineName, - Weight: 1, - }, - }, - Output: []ffi.DevicePipelineConfig{ - { - Name: "dummy", - Weight: 1, - }, - }, - } + inputPipeline := ffi.PipelineConfig{ + Name: PipelineName, + Functions: []string{FunctionName}, + } + dummyPipeline := ffi.PipelineConfig{ + Name: "dummy", + Functions: []string{}, + } + if err := agent.UpdatePipeline(inputPipeline); err != nil { + return fmt.Errorf("update input pipeline: %w", err) + } + if err := agent.UpdatePipeline(dummyPipeline); err != nil { + return fmt.Errorf("update dummy pipeline: %w", err) + } - if err := agent.UpdatePlainDevices([]ffi.DeviceConfig{deviceConfig}); err != nil { - return fmt.Errorf("failed to update pipelines: %w", err) - } + deviceConfig := ffi.DeviceConfig{ + Name: DeviceName, + Input: []ffi.DevicePipelineConfig{ + {Name: PipelineName, Weight: 1}, + }, + Output: []ffi.DevicePipelineConfig{ + {Name: "dummy", Weight: 1}, + }, + } + if err := agent.UpdatePlainDevices([]ffi.DeviceConfig{deviceConfig}); err != nil { + return fmt.Errorf("update devices: %w", err) } return nil } func (ts *TestSetup) Free() { - ts.Balancer.Free() + ts.Balancer.Destroy() ts.Mock.Free() } -// EnableAllReals enables all real servers in the balancer configuration func EnableAllReals(t *testing.T, ts *TestSetup) { t.Helper() config := ts.Balancer.Config() - var updates []*balancerpb.RealUpdate - enableTrue := true + if config.PacketHandler == nil { + return + } + enableTrue := true + var updates []*balancerpb.RealUpdate for _, vs := range config.PacketHandler.Vs { for _, real := range vs.Reals { updates = append(updates, &balancerpb.RealUpdate{ @@ -206,3 +188,12 @@ func EnableAllReals(t *testing.T, ts *TestSetup) { t.Fatalf("failed to enable reals: %v", err) } } + +func PacketHandlerRef() *balancerpb.PacketHandlerRef { + return &balancerpb.PacketHandlerRef{ + Device: &DeviceName, + Pipeline: &PipelineName, + Function: &FunctionName, + Chain: &ChainName, + } +} diff --git a/modules/balancer/tests/go/utils/snapshots.go b/modules/balancer/tests/go/utils/snapshots.go new file mode 100644 index 000000000..a58acc74b --- /dev/null +++ b/modules/balancer/tests/go/utils/snapshots.go @@ -0,0 +1,111 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +type RealSnapshot struct { + ActiveSessions uint64 + Stats *balancerpb.RealStats +} + +type VsSnapshot struct { + Stats *balancerpb.VsStats + Reals map[RealID]RealSnapshot +} + +func CaptureVsSnapshots( + t *testing.T, + ts *TestSetup, + vsToCapture []*balancerpb.VirtualService, +) map[VsID]VsSnapshot { + t.Helper() + state, err := ts.Balancer.GetState(nil, nil, true, ts.Mock.CurrentTime()) + assert.NoError(t, err) + assert.Len(t, state, 1) + + stateByID := make(map[VsID]*balancerpb.VsState, len(state[0].VirtualServices)) + for _, vsState := range state[0].VirtualServices { + stateByID[VsIDFromPb(vsState.Id)] = vsState + } + + snapshots := make(map[VsID]VsSnapshot, len(vsToCapture)) + for _, vs := range vsToCapture { + id := VsIDFromPb(vs.Id) + vsState, ok := stateByID[id] + if !ok { + continue + } + snap := VsSnapshot{ + Reals: make(map[RealID]RealSnapshot, len(vsState.Reals)), + } + snap.Stats = vsState.Stats + for _, r := range vsState.Reals { + rid := RealIDFromPb(r.Id) + snap.Reals[rid] = RealSnapshot{ + ActiveSessions: r.ActiveSessions, + Stats: r.RealStats, + } + } + snapshots[id] = snap + } + return snapshots +} + +// VerifyInheritedStats asserts that every VS present in both snapshots and current +// state has the same number packet counters. VS missing from +// current state are silently skipped. Reals no longer +// present are also skipped. +func VerifyInheritedStats( + t *testing.T, + ts *TestSetup, + snapshots map[VsID]VsSnapshot, +) { + t.Helper() + state, err := ts.Balancer.GetState(nil, nil, true, ts.Mock.CurrentTime()) + assert.NoError(t, err) + assert.Len(t, state, 1) + + currentByID := make(map[VsID]*balancerpb.VsState, len(state[0].VirtualServices)) + for _, vsState := range state[0].VirtualServices { + currentByID[VsIDFromPb(vsState.Id)] = vsState + } + + for vsID, snap := range snapshots { + vsState, ok := currentByID[vsID] + if !ok { + continue + } + assert.True(t, VsStatsEquals(vsState.Stats, snap.Stats), + "VS %s: stats not inherited after update", vsID.String()) + + realByID := make(map[RealID]*balancerpb.RealState, len(vsState.Reals)) + for _, r := range vsState.Reals { + realByID[RealIDFromPb(r.Id)] = r + } + for realID, realSnap := range snap.Reals { + r, ok := realByID[realID] + if !ok { + continue + } + assert.Equal( + t, + r.ActiveSessions, + realSnap.ActiveSessions, + "VS %s, real %s: active sessions not inherited after update", + vsID.String(), + realID.String(), + ) + assert.True( + t, + RealStatsEqual(r.RealStats, realSnap.Stats), + "VS %s, real %s: stats not inherited after update", + vsID.String(), + realID.String(), + ) + } + } +} diff --git a/modules/balancer/tests/go/utils/validation.go b/modules/balancer/tests/go/utils/validation.go index 3687f8271..de5e40a13 100644 --- a/modules/balancer/tests/go/utils/validation.go +++ b/modules/balancer/tests/go/utils/validation.go @@ -1,207 +1,200 @@ package utils -// Packet validation utilities for verifying balancer behavior including tunnel structure, -// ToS/TrafficClass preservation, protocol consistency, service/real matching, weight distribution, -// and tunnel source address calculation according to balancer specifications. - import ( + "bytes" "fmt" "math" "net" "net/netip" - "testing" "github.com/gopacket/gopacket" "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" "github.com/yanet-platform/yanet2/tests/functional/framework" ) // ValidatePacket validates that a packet has been properly processed by the balancer. -// It checks that the packet is tunneled and that the inner packet matches the original. func ValidatePacket( - t *testing.T, config *balancerpb.BalancerConfig, originalGoPacket gopacket.Packet, resultPacket *framework.PacketInfo, -) { - t.Helper() - - // Parse the original packet +) (PacketInfo, error) { parser := framework.NewPacketParser() originalPacket, err := parser.ParsePacket(originalGoPacket.Data()) - require.NoError(t, err, "failed to parse original packet") + if err != nil { + return PacketInfo{}, fmt.Errorf("failed to parse packet: %w", err) + } + + if err := validateTunnelStructure(originalPacket, resultPacket, originalGoPacket); err != nil { + return PacketInfo{}, fmt.Errorf("tunnel structure validation failed: %w", err) + } + + if err := validateTosPreservation(originalPacket, originalGoPacket, resultPacket); err != nil { + return PacketInfo{}, fmt.Errorf("ToS preservation validation failed: %w", err) + } - // Validate basic tunnel structure - validateTunnelStructure(t, originalPacket, resultPacket, originalGoPacket) + packetProto, err := validateProtocol(originalPacket, resultPacket) + if err != nil { + return PacketInfo{}, fmt.Errorf("protocol validation failed: %w", err) + } - // Validate ToS/TrafficClass preservation - validateTosPreservation(t, originalPacket, originalGoPacket, resultPacket) + vs, rl, err := validateServiceAndReal(config, originalPacket, resultPacket, packetProto) + if err != nil { + return PacketInfo{}, fmt.Errorf("service/real validation failed: %w", err) + } - // Validate protocol consistency - packetProto := validateProtocol(t, originalPacket, resultPacket) + if err := validateTunnelSourceAddress(config, originalPacket, resultPacket); err != nil { + return PacketInfo{}, fmt.Errorf("tunnel source address validation failed: %w", err) + } - // Find and validate matching service and real - validateServiceAndReal(t, config, originalPacket, resultPacket, packetProto) + clientIP, _ := netip.AddrFromSlice(originalPacket.SrcIP) + clientPort := originalPacket.SrcPort - // Validate tunnel source address - validateTunnelSourceAddress(t, config, originalPacket, resultPacket) + return PacketInfo{ + VsID: VsIDFromPb(vs.Id), + ClientAddr: clientIP, + ClientPort: clientPort, + RealID: RealIDFromPb(rl.Id), + Packet: resultPacket, + }, nil } -// validateTunnelStructure checks that the packet is properly tunneled with correct inner packet. func validateTunnelStructure( - t *testing.T, originalPacket *framework.PacketInfo, resultPacket *framework.PacketInfo, originalGoPacket gopacket.Packet, -) { - t.Helper() - - // Check that result packet is tunneled - require.True(t, resultPacket.IsTunneled, "result packet is not tunneled") +) error { + if !resultPacket.IsTunneled { + return fmt.Errorf("result packet is not tunneled") + } - // Check that inner packet exists resultInner := resultPacket.InnerPacket - require.NotNil(t, resultInner, "no inner packet in result") + if resultInner == nil { + return fmt.Errorf("no inner packet in result") + } - // Validate that inner packet matches original - assert.Equal(t, - originalPacket.DstIP.String(), - resultInner.DstIP.String(), - "encapsulated packet dst ip mismatch", - ) - assert.Equal(t, - originalPacket.SrcIP.String(), - resultInner.SrcIP.String(), - "encapsulated packet src ip mismatch", - ) - assert.Equal(t, - originalGoPacket.ApplicationLayer().Payload(), - resultPacket.Payload, - "payload mismatch", - ) + if !originalPacket.DstIP.Equal(resultInner.DstIP) { + return fmt.Errorf("encapsulated packet dst ip mismatch") + } + + if !originalPacket.SrcIP.Equal(resultInner.SrcIP) { + return fmt.Errorf("encapsulated packet src ip mismatch") + } + + if !bytes.Equal(originalGoPacket.ApplicationLayer().Payload(), resultPacket.Payload) { + return fmt.Errorf("payload mismatch") + } + + return nil } -// validateTosPreservation checks that ToS/TrafficClass is preserved through encapsulation. func validateTosPreservation( - t *testing.T, originalPacket *framework.PacketInfo, originalGoPacket gopacket.Packet, resultPacket *framework.PacketInfo, -) { - t.Helper() - - // Get original ToS/TrafficClass - originalToS := getOriginalTos(t, originalPacket, originalGoPacket) +) error { + originalToS, err := getOriginalTos(originalPacket, originalGoPacket) + if err != nil { + return fmt.Errorf("failed to get original ToS: %w", err) + } if originalToS == nil { - return // Error already reported + return nil } - // Parse the full tunneled packet tunneled := gopacket.NewPacket( resultPacket.RawData, layers.LayerTypeEthernet, gopacket.Default, ) if tunneled.ErrorLayer() != nil { - t.Errorf( - "failed to parse tunneled packet for ToS/TrafficClass check: %v", - tunneled.ErrorLayer().Error(), - ) - return + return fmt.Errorf("failed to parse tunneled packet: %v", tunneled.ErrorLayer().Error()) } - // Get outer ToS/TrafficClass - outerToS := getOuterTos(t, resultPacket, tunneled) + outerToS, err := getOuterTos(resultPacket, tunneled) + if err != nil { + return fmt.Errorf("failed to get outer ToS: %w", err) + } if outerToS == nil { - return // Error already reported + return nil } - // Get inner ToS/TrafficClass - innerToS := getInnerTos(t, tunneled) + innerToS, err := getInnerTos(tunneled) + if err != nil { + return fmt.Errorf("failed to get inner ToS: %w", err) + } if innerToS == nil { - return // Error already reported + return nil } - // Verify ToS/TrafficClass preservation - assert.Equal(t, - *originalToS, - *outerToS, - "outer packet ToS/TrafficClass mismatch with original", - ) - assert.Equal(t, - *originalToS, - *innerToS, - "inner packet ToS/TrafficClass mismatch with original", - ) + if *originalToS != *outerToS { + return fmt.Errorf( + "outer packet ToS/TrafficClass mismatch with original: expected %d, got %d", + *originalToS, *outerToS, + ) + } + if *originalToS != *innerToS { + return fmt.Errorf( + "inner packet ToS/TrafficClass mismatch with original: expected %d, got %d", + *originalToS, *innerToS, + ) + } + + return nil } -// getOriginalTos extracts ToS/TrafficClass from the original packet. func getOriginalTos( - t *testing.T, originalPacket *framework.PacketInfo, originalGoPacket gopacket.Packet, -) *uint8 { - t.Helper() - +) (*uint8, error) { var tos uint8 if originalPacket.IsIPv4 { if ipv4 := originalGoPacket.Layer(layers.LayerTypeIPv4); ipv4 != nil { tos = ipv4.(*layers.IPv4).TOS } else { - t.Error("no IPv4 layer in original packet to read TOS") - return nil + return nil, fmt.Errorf("no IPv4 layer in original packet") } } else if originalPacket.IsIPv6 { if ipv6 := originalGoPacket.Layer(layers.LayerTypeIPv6); ipv6 != nil { tos = ipv6.(*layers.IPv6).TrafficClass } else { - t.Error("no IPv6 layer in original packet to read TrafficClass") - return nil + return nil, fmt.Errorf("no IPv6 layer in original packet") } } - return &tos + return &tos, nil } -// getOuterTos extracts ToS/TrafficClass from the outer packet header. func getOuterTos( - t *testing.T, resultPacket *framework.PacketInfo, tunneled gopacket.Packet, -) *uint8 { - t.Helper() - +) (*uint8, error) { var tos uint8 - if resultPacket.IsIPv4 { - if ipv4 := tunneled.Layer(layers.LayerTypeIPv4); ipv4 != nil { - tos = ipv4.(*layers.IPv4).TOS - } else { - t.Error("no outer IPv4 layer to read TOS") - return nil + + switch { + case resultPacket.IsIPv4: + ipv4 := tunneled.Layer(layers.LayerTypeIPv4) + if ipv4 == nil { + return nil, fmt.Errorf("no outer IPv4 layer") } - } else if resultPacket.IsIPv6 { - if ipv6 := tunneled.Layer(layers.LayerTypeIPv6); ipv6 != nil { - tos = ipv6.(*layers.IPv6).TrafficClass - } else { - t.Error("no outer IPv6 layer to read TrafficClass") - return nil + tos = ipv4.(*layers.IPv4).TOS + + case resultPacket.IsIPv6: + ipv6 := tunneled.Layer(layers.LayerTypeIPv6) + if ipv6 == nil { + return nil, fmt.Errorf("no outer IPv6 layer") } - } else { - t.Error("unknown outer IP version for tunneled packet") - return nil + tos = ipv6.(*layers.IPv6).TrafficClass + + default: + return nil, fmt.Errorf("unknown outer IP version") } - return &tos -} -// getInnerTos extracts ToS/TrafficClass from the inner packet header. -func getInnerTos(t *testing.T, tunneled gopacket.Packet) *uint8 { - t.Helper() + return &tos, nil +} +func getInnerTos(tunneled gopacket.Packet) (*uint8, error) { var innerToS uint8 ipCount := 0 - foundInner := false + found := false for _, l := range tunneled.Layers() { switch l.LayerType() { @@ -209,200 +202,167 @@ func getInnerTos(t *testing.T, tunneled gopacket.Packet) *uint8 { ipCount++ if ipCount == 2 { innerToS = l.(*layers.IPv4).TOS - foundInner = true + found = true } case layers.LayerTypeIPv6: ipCount++ if ipCount == 2 { innerToS = l.(*layers.IPv6).TrafficClass - foundInner = true + found = true } } - if foundInner { + if found { break } } - if !foundInner { - t.Error("failed to locate inner IP header to read ToS/TrafficClass") - return nil + if !found { + return nil, fmt.Errorf("failed to locate inner IP header") } - return &innerToS + return &innerToS, nil } -// validateProtocol checks protocol consistency between original and encapsulated packet. func validateProtocol( - t *testing.T, originalPacket *framework.PacketInfo, resultPacket *framework.PacketInfo, -) balancerpb.TransportProto { - t.Helper() - +) (balancerpb.TransportProto, error) { resultInner := resultPacket.InnerPacket var originPacketProto layers.IPProtocol if originalPacket.IsIPv4 { - assert.Equal(t, - originalPacket.Protocol, - resultInner.Protocol, - "encapsulated packet protocol mismatch", - ) + if originalPacket.Protocol != resultInner.Protocol { + return 0, fmt.Errorf( + "encapsulated packet protocol mismatch: original %v, result %v", + originalPacket.Protocol, resultInner.Protocol, + ) + } originPacketProto = originalPacket.Protocol } else { - assert.Equal(t, - originalPacket.NextHeader, - resultInner.NextHeader, - "encapsulated packet protocol mismatch", - ) + if originalPacket.NextHeader != resultInner.NextHeader { + return 0, fmt.Errorf( + "encapsulated packet protocol mismatch: original %v, result %v", + originalPacket.NextHeader, resultInner.NextHeader, + ) + } originPacketProto = originalPacket.NextHeader } - // Determine packet proto - var packetProto balancerpb.TransportProto if originPacketProto.LayerType() == layers.LayerTypeTCP { - packetProto = balancerpb.TransportProto_TCP - } else if originPacketProto.LayerType() == layers.LayerTypeUDP { - packetProto = balancerpb.TransportProto_UDP - } else { - t.Errorf("invalid packet protocol: %s", originPacketProto.String()) + return balancerpb.TransportProto_TCP, nil } - - return packetProto + if originPacketProto.LayerType() == layers.LayerTypeUDP { + return balancerpb.TransportProto_UDP, nil + } + return 0, fmt.Errorf("invalid packet protocol: %s", originPacketProto.String()) } -// validateServiceAndReal finds the matching virtual service and real server. func validateServiceAndReal( - t *testing.T, config *balancerpb.BalancerConfig, originalPacket *framework.PacketInfo, resultPacket *framework.PacketInfo, packetProto balancerpb.TransportProto, -) { - t.Helper() - +) (*balancerpb.VirtualService, *balancerpb.Real, error) { if config.PacketHandler == nil { - t.Error("packet handler config is nil") - return + return nil, nil, fmt.Errorf("packet handler config is nil") } originalDstIP := netip.MustParseAddr(originalPacket.DstIP.String()) - for idx := range config.PacketHandler.Vs { - service := config.PacketHandler.Vs[idx] - vsAddr, _ := netip.AddrFromSlice(service.Id.Addr.Bytes) + for _, service := range config.PacketHandler.Vs { + vsAddr, _ := netip.AddrFromSlice(service.Id.Addr) if vsAddr.Compare(originalDstIP) == 0 && (service.Id.Port == uint32(originalPacket.DstPort) || service.Flags.PureL3) && service.Id.Proto == packetProto { - // Found matching service - validateTunnelType(t, service, vsAddr, resultPacket) - if findMatchingReal(t, service, resultPacket) { - return // Success + if err := validateTunnelType(service, vsAddr, resultPacket); err != nil { + return nil, nil, err } - t.Error("not found real which can accept packet sent by balancer") - t.Logf("user packet: %v", originalPacket) - t.Logf("balancer packet: %v", resultPacket) - return + if rl := findMatchingReal(service, resultPacket); rl != nil { + return service, rl, nil + } + + return nil, nil, fmt.Errorf( + "no real found that matches packet destination (original: %v, result: %v)", + originalPacket, resultPacket, + ) } } - t.Error("not found service which could serve packet") - t.Logf("user packet: %v", originalPacket) - t.Logf("balancer packet: %v", resultPacket) + return nil, nil, fmt.Errorf( + "no service found that matches packet (original: %v, result: %v)", + originalPacket, resultPacket, + ) } -// validateTunnelType checks that the tunnel type matches the service configuration. func validateTunnelType( - t *testing.T, service *balancerpb.VirtualService, vsAddr netip.Addr, resultPacket *framework.PacketInfo, -) { - t.Helper() - +) error { if service.Flags.Gre { expectedTunnelType := "gre-ip4" if vsAddr.Is6() { expectedTunnelType = "gre-ip6" } - assert.Equal(t, - expectedTunnelType, - resultPacket.TunnelType, - "packet tunnel type must be gre", - ) + if resultPacket.TunnelType != expectedTunnelType { + return fmt.Errorf( + "packet tunnel type must be %s, got %s", + expectedTunnelType, resultPacket.TunnelType, + ) + } } + return nil } -// findMatchingReal searches for a real server that matches the result packet destination. func findMatchingReal( - t *testing.T, service *balancerpb.VirtualService, resultPacket *framework.PacketInfo, -) bool { - t.Helper() - +) *balancerpb.Real { resultDstIP := netip.MustParseAddr(resultPacket.DstIP.String()) - for realIdx := range service.Reals { - real := service.Reals[realIdx] - realAddr, _ := netip.AddrFromSlice(real.Id.Ip.Bytes) - + for _, real := range service.Reals { + realAddr, _ := netip.AddrFromSlice(real.Id.Ip) if realAddr.Compare(resultDstIP) == 0 { - return true // Found matching real + return real } } - - return false + return nil } // ExtractDestinationReal extracts the destination IP (real server) from a tunneled packet. -// Returns the real server IP that the packet was forwarded to. func ExtractDestinationReal(packet *framework.PacketInfo) (netip.Addr, error) { if !packet.IsTunneled { return netip.Addr{}, fmt.Errorf("packet is not tunneled") } - // The destination IP of the outer packet is the real server dstIP, ok := netip.AddrFromSlice(packet.DstIP) if !ok { - return netip.Addr{}, fmt.Errorf( - "failed to parse destination IP: %v", - packet.DstIP, - ) + return netip.Addr{}, fmt.Errorf("failed to parse destination IP: %v", packet.DstIP) } - return dstIP, nil } // CountPacketsPerReal counts how many packets went to each real server. -// Returns a map from real server IP to packet count. -func CountPacketsPerReal(packets []*framework.PacketInfo) map[netip.Addr]int { +func CountPacketsPerReal(packets []PacketInfo) (map[netip.Addr]int, error) { counts := make(map[netip.Addr]int) - for _, packet := range packets { - realIP, err := ExtractDestinationReal(packet) + realIP, err := ExtractDestinationReal(packet.Packet) if err != nil { - continue // Skip non-tunneled packets + return nil, err } counts[realIP]++ } - - return counts + return counts, nil } // ValidateWeightDistribution checks if packet distribution matches expected weights. -// Uses tolerance-based validation (e.g., 0.15 for 15% tolerance). func ValidateWeightDistribution( - t *testing.T, counts map[netip.Addr]int, expectedWeights map[netip.Addr]uint32, tolerance float64, -) { - t.Helper() - - // Calculate total packets and total weight +) error { totalPackets := 0 for _, count := range counts { totalPackets += count @@ -414,16 +374,12 @@ func ValidateWeightDistribution( } if totalPackets == 0 { - t.Error("no packets to validate") - return + return fmt.Errorf("no packets to validate") } - if totalWeight == 0 { - t.Error("total weight is zero") - return + return fmt.Errorf("total weight is zero") } - // Check each real's distribution for realIP, expectedWeight := range expectedWeights { actualCount := counts[realIP] expectedRatio := float64(expectedWeight) / float64(totalWeight) @@ -431,7 +387,7 @@ func ValidateWeightDistribution( diff := math.Abs(actualRatio - expectedRatio) if diff > tolerance { - t.Errorf( + return fmt.Errorf( "weight distribution mismatch for real %s: expected ratio %.3f (weight %d/%d), got %.3f (%d/%d packets), diff %.3f > tolerance %.3f", realIP, expectedRatio, @@ -445,24 +401,24 @@ func ValidateWeightDistribution( ) } } + + return nil } -// AllPacketsToSameReal checks if all packets went to the same real server. -// Returns the real server IP and true if all packets went to the same real, or empty addr and false otherwise. -func AllPacketsToSameReal(packets []*framework.PacketInfo) (netip.Addr, bool) { - if len(packets) == 0 { - return netip.Addr{}, false +// AllSessionsToSameReal checks if all packets went to the same real server. +func AllSessionsToSameReal(sessions []PacketInfo) (netip.Addr, bool) { + if len(sessions) == 0 { + return netip.Addr{}, true } var firstReal netip.Addr firstSet := false - for _, packet := range packets { - realIP, err := ExtractDestinationReal(packet) - if err != nil { + for _, packet := range sessions { + realIP, ok := netip.AddrFromSlice(packet.RealID.addr[:]) + if !ok { return netip.Addr{}, false } - if !firstSet { firstReal = realIP firstSet = true @@ -474,43 +430,25 @@ func AllPacketsToSameReal(packets []*framework.PacketInfo) (netip.Addr, bool) { return firstReal, true } -// PacketsDistributedAcrossReals checks if packets are distributed across multiple reals. -// Returns true if packets went to more than one real server. -func PacketsDistributedAcrossReals(packets []*framework.PacketInfo) bool { - counts := CountPacketsPerReal(packets) - return len(counts) > 1 -} - -// validateTunnelSourceAddress validates that the tunnel source address is correctly calculated -// according to the formula: tunnel_src = client_ip & !real_mask | real_src & real_mask -// This matches the implementation in modules/balancer/dataplane/tunnel.h func validateTunnelSourceAddress( - t *testing.T, config *balancerpb.BalancerConfig, originalPacket *framework.PacketInfo, resultPacket *framework.PacketInfo, -) { - t.Helper() - +) error { if !resultPacket.IsTunneled { - return // Not a tunneled packet, nothing to validate + return nil } - // Get the client IP (source of original packet) clientIP := originalPacket.SrcIP if clientIP == nil { - t.Error("original packet has no source IP") - return + return fmt.Errorf("original packet has no source IP") } - // Get the tunnel source IP (source of outer packet) tunnelSrcIP := resultPacket.SrcIP if tunnelSrcIP == nil { - t.Error("result packet has no source IP") - return + return fmt.Errorf("result packet has no source IP") } - // Find the matching virtual service and real originalDstIP := netip.MustParseAddr(originalPacket.DstIP.String()) resultDstIP := netip.MustParseAddr(resultPacket.DstIP.String()) @@ -530,166 +468,111 @@ func validateTunnelSourceAddress( } if config.PacketHandler == nil { - t.Error("packet handler config is nil") - return + return fmt.Errorf("packet handler config is nil") } - // Find the matching virtual service for _, service := range config.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(service.Id.Addr.Bytes) + vsAddr, _ := netip.AddrFromSlice(service.Id.Addr) if vsAddr.Compare(originalDstIP) == 0 && (service.Id.Port == uint32(originalPacket.DstPort) || service.Flags.PureL3) && service.Id.Proto == packetProto { - // Find the matching real server for _, real := range service.Reals { - realAddr, _ := netip.AddrFromSlice(real.Id.Ip.Bytes) - + realAddr, _ := netip.AddrFromSlice(real.Id.Ip) if realAddr.Compare(resultDstIP) == 0 { - // Found the matching real, now validate source address - validateSourceAddressCalculation( - t, - clientIP, - tunnelSrcIP, - real, - ) - return + return validateSourceAddressCalculation(clientIP, tunnelSrcIP, real) } } } } + + return nil } -// validateSourceAddressCalculation validates the tunnel source address calculation -// Formula: tunnel_src = client_ip & !real_mask | real_src & real_mask -// The tunnel source IP protocol is determined by the real server's IP protocol, not the client's. func validateSourceAddressCalculation( - t *testing.T, clientIP net.IP, tunnelSrcIP net.IP, - real *balancerpb.Real, -) { - t.Helper() - - if real.SrcAddr == nil || real.SrcMask == nil { - t.Error("real server has no SrcAddr or SrcMask configured") - return + r *balancerpb.Real, +) error { + if r.Src == nil { + return fmt.Errorf("real server has no Src configured") } - if real.Id == nil || real.Id.Ip == nil { - t.Error("real server has no Id or Ip configured") - return - } - - realSrc := real.SrcAddr.Bytes - realMask := real.SrcMask.Bytes - realIP := real.Id.Ip.Bytes + realSrc := r.Src.Addr + realMask := r.Src.Mask + realIP := r.Id.Ip - // Determine real server's IP protocol from its address length realIsIPv6 := len(realIP) == 16 realIsIPv4 := len(realIP) == 4 if !realIsIPv4 && !realIsIPv6 { - t.Errorf("unexpected real IP address length: %d", len(realIP)) - return + return fmt.Errorf("unexpected real IP address length: %d", len(realIP)) } - // Normalize client IP var clientIPBytes []byte - if len(clientIP) == 4 || (len(clientIP) == 16 && clientIP.To4() != nil) { - // Client is IPv4 + switch { + case len(clientIP) == 4: clientIPv4 := clientIP.To4() if clientIPv4 == nil { - t.Error("failed to convert client IP to IPv4") - return + return fmt.Errorf("failed to convert client IP to IPv4") } clientIPBytes = []byte(clientIPv4) - } else if len(clientIP) == 16 { - // Client is IPv6 + + case len(clientIP) == 16: clientIPBytes = []byte(clientIP) - } else { - t.Errorf("unexpected client IP address length: %d", len(clientIP)) - return + + default: + return fmt.Errorf("unexpected client IP address length: %d", len(clientIP)) } - // Validate based on real server's IP protocol if realIsIPv6 { - // Tunnel to IPv6 real: tunnel source MUST be IPv6 - if len(tunnelSrcIP) != 16 || tunnelSrcIP.To4() != nil { - t.Errorf( - "tunnel source IP should be IPv6 when tunneling to IPv6 real, got %s", - tunnelSrcIP, - ) - return + if len(tunnelSrcIP) != 16 { + return fmt.Errorf("tunnel source IP should be IPv6 for IPv6 real, got %s", tunnelSrcIP) } - // Calculate expected source: client_ip & !real_mask | real_src & real_mask expectedSrc := make([]byte, 16) - - // Determine how many bytes to use from client IP - clientLen := len(clientIPBytes) - if clientLen > 16 { - clientLen = 16 - } - - for i := 0; i < 16; i++ { + clientLen := min(len(clientIPBytes), 16) + for i := range 16 { var clientByte byte if i < clientLen { clientByte = clientIPBytes[i] - } else { - clientByte = 0 } expectedSrc[i] = (clientByte & ^realMask[i]) | (realSrc[i] & realMask[i]) } - expectedSrcIP := net.IP(expectedSrc) - if !tunnelSrcIP.Equal(expectedSrcIP) { - t.Errorf( - "tunnel source address mismatch: expected %s, got %s (client=%s, real_src=%s, real_mask=%s, real_ip=%s)", - expectedSrcIP, + if !tunnelSrcIP.Equal(net.IP(expectedSrc)) { + return fmt.Errorf( + "tunnel source address mismatch: expected %s, got %s (client=%s, src=%s, mask=%s)", + net.IP(expectedSrc), tunnelSrcIP, clientIP, net.IP(realSrc), net.IP(realMask), - net.IP(realIP), ) } } else { - // Tunnel to IPv4 real: tunnel source MUST be IPv4 tunnelSrcIPv4 := tunnelSrcIP.To4() if tunnelSrcIPv4 == nil { - t.Errorf( - "tunnel source IP should be IPv4 when tunneling to IPv4 real, got %s", - tunnelSrcIP, - ) - return + return fmt.Errorf("tunnel source IP should be IPv4 for IPv4 real, got %s", tunnelSrcIP) } - // Calculate expected source: client_ip & !real_mask | real_src & real_mask - // Use only first 4 bytes of client IP (whether IPv4 or IPv6) expectedSrc := make([]byte, 4) - for i := 0; i < 4; i++ { + for i := range 4 { var clientByte byte if i < len(clientIPBytes) { clientByte = clientIPBytes[i] - } else { - clientByte = 0 } expectedSrc[i] = (clientByte & ^realMask[i]) | (realSrc[i] & realMask[i]) } - expectedSrcIP := net.IP(expectedSrc) - if !tunnelSrcIPv4.Equal(expectedSrcIP) { - t.Errorf( - "tunnel source address mismatch: expected %s, got %s (client=%s, real_src=%s, real_mask=%s, real_ip=%s)", - expectedSrcIP, - tunnelSrcIPv4, - clientIP, - net.IP(realSrc), - net.IP(realMask), - net.IP(realIP), + if !tunnelSrcIPv4.Equal(net.IP(expectedSrc)) { + return fmt.Errorf( + "tunnel source address mismatch: expected %s, got %s (client=%s, src=%s, mask=%s)", + net.IP(expectedSrc), tunnelSrcIPv4, clientIP, net.IP(realSrc), net.IP(realMask), ) } } + + return nil } diff --git a/modules/balancer/tests/go/utils/vs.go b/modules/balancer/tests/go/utils/vs.go new file mode 100644 index 000000000..8b87e0052 --- /dev/null +++ b/modules/balancer/tests/go/utils/vs.go @@ -0,0 +1,155 @@ +package utils + +import ( + "bytes" + "cmp" + "fmt" + "math/rand/v2" + "net/netip" + + balancer "github.com/yanet-platform/yanet2/modules/balancer/controlplane" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" +) + +type VsID struct { + addrLen int + addr [16]byte + port uint16 + proto balancerpb.TransportProto +} + +func (vs *VsID) Compare(other *VsID) int { + if vs.addrLen != other.addrLen { + return cmp.Compare(vs.addrLen, other.addrLen) + } + if cmp := bytes.Compare(vs.addr[:vs.addrLen], other.addr[:other.addrLen]); cmp != 0 { + return cmp + } + if cmp := cmp.Compare(vs.port, other.port); cmp != 0 { + return cmp + } + return cmp.Compare(vs.proto, other.proto) +} + +func (vs *VsID) String() string { + ip, _ := netip.AddrFromSlice(vs.addr[:vs.addrLen]) + proto := "TCP" + if vs.proto == balancerpb.TransportProto_UDP { + proto = "UDP" + } + ips := ip.String() + if ip.Is6() { + ips = fmt.Sprintf("[%s]", ips) + } + return fmt.Sprintf("%s:%d/%s", ips, vs.port, proto) +} + +func VsIDFromPb(vs *balancerpb.VsIdentifier) VsID { + if len(vs.Addr) != 4 && len(vs.Addr) != 16 { + panic(fmt.Sprintf("invalid address length: %d", len(vs.Addr))) + } + if vs.Port > 65535 { + panic(fmt.Sprintf("invalid port: %d", vs.Port)) + } + if vs.Proto != balancerpb.TransportProto_TCP && vs.Proto != balancerpb.TransportProto_UDP { + panic(fmt.Sprintf("invalid protocol: %s", vs.Proto)) + } + addr := [16]byte{} + copy(addr[:], vs.Addr) + return VsID{ + addrLen: len(vs.Addr), + addr: addr, + port: uint16(vs.Port), + proto: vs.Proto, + } +} + +func VsStatsEquals(a *balancerpb.VsStats, b *balancerpb.VsStats) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.IncomingPackets == b.IncomingPackets && + a.IncomingBytes == b.IncomingBytes && + a.PacketSrcNotAllowed == b.PacketSrcNotAllowed && + a.NoReals == b.NoReals && + a.SessionTableOverflow == b.SessionTableOverflow && + a.EchoIcmpPackets == b.EchoIcmpPackets && + a.ErrorIcmpPackets == b.ErrorIcmpPackets && + a.RealIsDisabled == b.RealIsDisabled && + a.RealIsRemoved == b.RealIsRemoved && + a.NotRescheduledPackets == b.NotRescheduledPackets && + a.BroadcastedIcmpPackets == b.BroadcastedIcmpPackets && + a.CreatedSessions == b.CreatedSessions && + a.OutgoingPackets == b.OutgoingPackets && + a.OutgoingBytes == b.OutgoingBytes +} + +func VsCount(b *balancer.Balancer) int { + return len(b.Config().PacketHandler.Vs) +} + +func SelectVS( + cnt int, + vs []*balancerpb.VirtualService, + rng *rand.Rand, +) []*balancerpb.VirtualService { + if cnt >= len(vs) { + panic(fmt.Sprintf("selectVS: cnt >= len(vs): %d >= %d", cnt, len(vs))) + } + indices := make([]int, len(vs)) + for i := range indices { + indices[i] = i + } + for i := len(indices) - 1; i > 0; i-- { + j := rng.IntN(i + 1) + indices[i], indices[j] = indices[j], indices[i] + } + result := make([]*balancerpb.VirtualService, cnt) + for i := range cnt { + result[i] = vs[indices[i]] + } + return result +} + +func VSUpdateSomeReals( + vs *balancerpb.VirtualService, + rng *rand.Rand, +) *balancerpb.VirtualService { + reals := make([]*balancerpb.Real, len(vs.Reals)) + copy(reals, vs.Reals) + for i := len(reals) - 1; i > 0; i-- { + j := rng.IntN(i + 1) + reals[i], reals[j] = reals[j], reals[i] + } + delta := 2 - rng.IntN(5) // -2 to 2 + replaceCnt := max(0, min(len(reals), len(reals)/2+delta)) + for i := range replaceCnt { + reals[i] = GenerateReal(rng) + } + newCnt := rng.IntN(3) + for range newCnt { + reals = append(reals, GenerateReal(rng)) + } + return &balancerpb.VirtualService{ + Id: vs.Id, + AllowedSrcs: vs.AllowedSrcs, + Flags: vs.Flags, + Scheduler: vs.Scheduler, + Reals: reals, + } +} + +func GenerateRealUpdates( + services []*balancerpb.VirtualService, + rng *rand.Rand, +) []*balancerpb.RealUpdate { + updates := make([]*balancerpb.RealUpdate, 0) + for _, vs := range services { + upd := VSGenerateRealUpdates(vs, rng) + updates = append(updates, upd...) + } + return updates +} diff --git a/modules/balancer/tests/go/wlc_test.go b/modules/balancer/tests/go/wlc_test.go deleted file mode 100644 index 795fed6d0..000000000 --- a/modules/balancer/tests/go/wlc_test.go +++ /dev/null @@ -1,1030 +0,0 @@ -package balancer_test - -// TestWlc validates the Weighted Least Connection (WLC) scheduling algorithm: -// -// # Initial Configuration -// - Virtual service: 1.1.1.1:80 (TCP) with WLC enabled -// - Three real servers with weights: Real1=1, Real2=1, Real3=2 -// - Initially only Real1 and Real2 are enabled -// -// # Stage 1: Two Reals with Equal Weights -// - Sends 500 random TCP SYN packets -// - Validates uniform distribution (250 packets each to Real1 and Real2) -// - Verifies Stats, Info, and Sessions APIs show correct counts -// - Enables Real3 (weight=2) -// -// # Stage 1 Continued: Three Reals with Weights 1:1:2 -// - Sends 2500 more random TCP SYN packets -// - Validates distribution proportional to weights -// - Verifies Real3 receives approximately 2× traffic of Real1/Real2 -// -// # Stage 2: State Persistence with New Agent -// - Creates new BalancerAgent attached to same shared memory -// - Verifies all reals are enabled via Graph() -// - Confirms original weights (1, 1, 2) via Config() -// - Disables Real1 and sends 100 packets -// * Validates packets only go to Real2 and Real3 -// * Expected distribution: Real2 ~33%, Real3 ~67% (weights 1:2) -// - Re-enables Real1 and sends 300 more packets -// - Validates session distribution proportional to weights (1:1:2) -// * Expected ratios: Real1=25%, Real2=25%, Real3=50% -// * Tolerance: ±15% -// -// # Stage 3: Multi-VS Configuration Update -// - Updates config to 4 virtual services: -// * VS1: Original VS with WLC=true, weights 1:1:2 -// * VS2: New VS with WLC=true, weights 1:2:1 -// * VS3: New VS with WLC=false (ROUND_ROBIN), weights 1:1 -// * VS4: New VS with WLC=true, weights 2:2:1 -// - Verifies Config() matches updated configuration -// - Creates third BalancerAgent and verifies config persistence -// - Confirms all 4 virtual services present with correct settings - -import ( - "fmt" - "math" - "net/netip" - "testing" - - "github.com/c2h5oh/datasize" - "github.com/gopacket/gopacket/layers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yanet-platform/yanet2/common/go/logging" - "github.com/yanet-platform/yanet2/common/go/xpacket" - "github.com/yanet-platform/yanet2/modules/balancer/agent/balancerpb" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go" - "github.com/yanet-platform/yanet2/modules/balancer/tests/go/utils" - "go.uber.org/zap/zapcore" - "google.golang.org/protobuf/types/known/durationpb" -) - -func TestWlc(t *testing.T) { - vsIP := netip.MustParseAddr("1.1.1.1") - vsPort := uint16(80) - real1Ip := netip.MustParseAddr("2.2.2.2") - real2Ip := netip.MustParseAddr("3.3.3.3") - real3Ip := netip.MustParseAddr("4.4.4.4") - - client := func(id int) netip.Addr { - return netip.MustParseAddr( - fmt.Sprintf("10.%d.%d.%d", id/(256*256)%256, (id/256)%256, id%256), - ) - } - - config := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - { - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{ - Bytes: vsIP.AsSlice(), - }, - Port: uint32(vsPort), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("10.0.0.0"). - AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.0.0.0"). - AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: true, - }, - Reals: []*balancerpb.Real{ - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: real1Ip.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: real1Ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: real2Ip.AsSlice(), - }, - Port: 0, - }, - Weight: 1, - SrcAddr: &balancerpb.Addr{ - Bytes: real2Ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - { - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{ - Bytes: real3Ip.AsSlice(), - }, - Port: 0, - }, - Weight: 2, - SrcAddr: &balancerpb.Addr{ - Bytes: real3Ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255"). - AsSlice(), - }, - }, - }, - Peers: []*balancerpb.Addr{}, - }, - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(8000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - ts, err := utils.Make(&utils.TestConfig{ - Mock: utils.SingleWorkerMockConfig(64*datasize.MB, 4*datasize.MB), - Balancer: config, - AgentMemory: func() *datasize.ByteSize { - memory := 16 * datasize.MB - return &memory - }(), - }) - require.NoError(t, err, "failed to setup test") - defer ts.Free() - - mock := ts.Mock - balancerMgr := ts.Balancer - - // Enable only the first two reals initially - initialConfig := balancerMgr.Config() - enableTrue := true - enableFalse := false - initialUpdates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: initialConfig.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: initialConfig.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real2Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - { - RealId: &balancerpb.RealIdentifier{ - Vs: initialConfig.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real3Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - } - _, err = balancerMgr.UpdateReals(initialUpdates, false) - require.NoError(t, err, "failed to enable first two reals") - - // Send random packets - packets := 500 - - // Send random SYNs to the first two reals - // Expect uniform distribution - now := mock.CurrentTime() - - t.Run("Send_Random_SYNs", func(t *testing.T) { - for packetIdx := range packets { - clientIP := client(packetIdx) - packetLayers := utils.MakeTCPPacket( - clientIP, - 1000, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output)) - require.Empty(t, result.Drop) - } - - // Get stats - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - stats, err := balancerMgr.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, stats) - require.NotEmpty(t, stats.Vs) - - vsStats := stats.Vs[0] - require.NotEmpty(t, vsStats.Reals) - - assert.Equal( - t, - uint64(packets/2), - vsStats.Reals[0].Stats.CreatedSessions, - ) - assert.Equal( - t, - uint64(packets/2), - vsStats.Reals[1].Stats.CreatedSessions, - ) - assert.Equal(t, uint64(0), vsStats.Reals[2].Stats.CreatedSessions) - - // NEW: Validate Stats - t.Run("Validate_Stats", func(t *testing.T) { - assert.Equal( - t, - uint64(packets), - vsStats.Stats.IncomingPackets, - "VS incoming packets", - ) - assert.Equal( - t, - uint64(packets), - vsStats.Stats.OutgoingPackets, - "VS outgoing packets", - ) - assert.Equal( - t, - uint64(packets/2), - vsStats.Reals[0].Stats.Packets, - "Real1 packets", - ) - assert.Equal( - t, - uint64(packets/2), - vsStats.Reals[1].Stats.Packets, - "Real2 packets", - ) - assert.Equal( - t, - uint64(0), - vsStats.Reals[2].Stats.Packets, - "Real3 packets", - ) - }) - - // NEW: Validate Info - t.Run("Validate_Info", func(t *testing.T) { - info, err := balancerMgr.Info(now) - require.NoError(t, err) - require.NotNil(t, info) - require.NotEmpty(t, info.Vs) - - vsInfo := info.Vs[0] - require.NotEmpty(t, vsInfo.Reals) - - assert.Equal( - t, - uint64(packets), - info.ActiveSessions, - "total active sessions", - ) - assert.Equal( - t, - uint64(packets), - vsInfo.ActiveSessions, - "VS active sessions", - ) - assert.Equal( - t, - uint64(packets/2), - vsInfo.Reals[0].ActiveSessions, - "Real1 active sessions", - ) - assert.Equal( - t, - uint64(packets/2), - vsInfo.Reals[1].ActiveSessions, - "Real2 active sessions", - ) - assert.Equal( - t, - uint64(0), - vsInfo.Reals[2].ActiveSessions, - "Real3 active sessions", - ) - }) - - // NEW: Validate Sessions - t.Run("Validate_Sessions", func(t *testing.T) { - sessions, err := balancerMgr.Sessions(now) - require.NoError(t, err) - require.NotNil(t, sessions) - - assert.Equal(t, packets, len(sessions), "total sessions count") - - // Count sessions per real - real1Sessions := 0 - real2Sessions := 0 - real3Sessions := 0 - - for _, session := range sessions { - realAddr, _ := netip.AddrFromSlice(session.RealId.Real.Ip.Bytes) - switch realAddr { - case real1Ip: - real1Sessions++ - case real2Ip: - real2Sessions++ - case real3Ip: - real3Sessions++ - } - } - - assert.Equal(t, packets/2, real1Sessions, "Real1 sessions") - assert.Equal(t, packets/2, real2Sessions, "Real2 sessions") - assert.Equal(t, 0, real3Sessions, "Real3 sessions") - }) - - // Refresh to scan active sessions, update them, - // and recalculate effective weights - err = balancerMgr.Refresh(now) - require.NoError(t, err) - }) - - // Enable third real - t.Run("Enable_Third_Real", func(t *testing.T) { - config := balancerMgr.Config() - enableTrue := true - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: config.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real3Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - } - _, err := balancerMgr.UpdateReals(updates, false) - require.NoError(t, err) - }) - - // Send more random packets - // Ensure we have good distribution after - t.Run("Send_Random_SYNs_Again", func(t *testing.T) { - firstClient := packets - packets = 5 * packets - for packetIdx := range packets { - if packetIdx%50 == 0 { - if err := balancerMgr.Refresh(now); err != nil { - t.Errorf( - "failed to refresh: packetIdx=%d, error=%v", - packetIdx, - err, - ) - } - } - clientIP := client(firstClient + packetIdx) - packetLayers := utils.MakeTCPPacket( - clientIP, - 1000, - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output)) - require.Empty(t, result.Drop) - } - - // Get stats - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - stats, err := balancerMgr.Stats(statsRef) - require.NoError(t, err) - require.NotNil(t, stats) - require.NotEmpty(t, stats.Vs) - - vsStats := stats.Vs[0] - require.NotEmpty(t, vsStats.Reals) - - firstTwoPackets := vsStats.Reals[0].Stats.CreatedSessions + vsStats.Reals[1].Stats.CreatedSessions - thirdPackets := vsStats.Reals[2].Stats.CreatedSessions - - rel := float64(thirdPackets) / float64(firstTwoPackets) - assert.Less(t, math.Abs(rel-1.0), 0.3) - }) - - // NEW: Stage 2 - New Balancer Agent with State Persistence - t.Run("Stage2_New_Agent_State_Persistence", func(t *testing.T) { - // Create new balancer agent using same shared memory - logLevel := zapcore.InfoLevel - sugaredLogger, _, _ := logging.Init(&logging.Config{ - Level: logLevel, - }) - - agentMemory := 16 * datasize.MB - newAgent, err := balancer.NewBalancerAgent( - ts.Mock.SharedMemory(), // Same shared memory - agentMemory, - sugaredLogger, - ) - require.NoError(t, err, "failed to create new balancer agent") - - // Attach to existing BalancerManager - newBalancer, err := newAgent.BalancerManager(utils.BalancerName) - require.NoError(t, err, "failed to attach to existing balancer manager") - require.NotNil(t, newBalancer, "balancer manager should not be nil") - - // 2.1: Verify all reals are enabled using Graph() - t.Run("Verify_All_Reals_Enabled", func(t *testing.T) { - graph := newBalancer.Graph() - require.NotNil(t, graph) - require.NotEmpty(t, graph.VirtualServices) - - vs := graph.VirtualServices[0] - require.Equal(t, 3, len(vs.Reals), "should have 3 reals") - - // Verify all reals are enabled - for i, real := range vs.Reals { - assert.True(t, real.Enabled, "Real%d should be enabled", i+1) - } - }) - - // 2.2: Verify config contains original weights (1, 1, 2) - t.Run("Verify_Original_Weights", func(t *testing.T) { - config := newBalancer.Config() - require.NotNil(t, config) - require.NotNil(t, config.PacketHandler) - require.NotEmpty(t, config.PacketHandler.Vs) - - vs := config.PacketHandler.Vs[0] - require.Equal(t, 3, len(vs.Reals), "should have 3 reals") - - assert.Equal(t, uint32(1), vs.Reals[0].Weight, "Real1 weight") - assert.Equal(t, uint32(1), vs.Reals[1].Weight, "Real2 weight") - assert.Equal(t, uint32(2), vs.Reals[2].Weight, "Real3 weight") - }) - - // 2.3: Disable first real and send packets - t.Run("Disable_Real1_And_Send_Packets", func(t *testing.T) { - config := newBalancer.Config() - enableFalse := false - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: config.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableFalse, - }, - } - _, err := newBalancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to disable Real1") - - // Send 100 new packets - firstClient := 3500 - for i := 0; i < 100; i++ { - clientIP := client(firstClient + i) - packetLayers := utils.MakeTCPPacket( - clientIP, - uint16(2000+i), - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output)) - require.Empty(t, result.Drop) - } - - // Verify packets only go to Real2 and Real3 - // Expected distribution: Real2 ~33%, Real3 ~67% (weights 1:2) - statsRef := &balancerpb.PacketHandlerRef{ - Device: &utils.DeviceName, - Pipeline: &utils.PipelineName, - Function: &utils.FunctionName, - Chain: &utils.ChainName, - } - stats, err := newBalancer.Stats(statsRef) - require.NoError(t, err) - - vsStats := stats.Vs[0] - // Note: Stats are cumulative, so we can't easily verify just these 100 packets - // But we can verify Real1 didn't get any new sessions - t.Logf( - "After disabling Real1: Real1=%d, Real2=%d, Real3=%d sessions", - vsStats.Reals[0].Stats.CreatedSessions, - vsStats.Reals[1].Stats.CreatedSessions, - vsStats.Reals[2].Stats.CreatedSessions, - ) - }) - - // 2.4: Re-enable first real and send more packets - t.Run("Enable_Real1_And_Send_More_Packets", func(t *testing.T) { - config := newBalancer.Config() - enableTrue := true - updates := []*balancerpb.RealUpdate{ - { - RealId: &balancerpb.RealIdentifier{ - Vs: config.PacketHandler.Vs[0].Id, - Real: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: real1Ip.AsSlice()}, - Port: 0, - }, - }, - Enable: &enableTrue, - }, - } - _, err := newBalancer.UpdateReals(updates, false) - require.NoError(t, err, "failed to enable Real1") - - // Send 300 more packets - firstClient := 3600 - for i := 0; i < 300; i++ { - if i%50 == 0 { - err := newBalancer.Refresh(now) - require.NoError(t, err) - } - clientIP := client(firstClient + i) - packetLayers := utils.MakeTCPPacket( - clientIP, - uint16(3000+i), - vsIP, - vsPort, - &layers.TCP{SYN: true}, - ) - packet := xpacket.LayersToPacket(t, packetLayers...) - result, err := mock.HandlePackets(packet) - require.NoError(t, err) - require.Equal(t, 1, len(result.Output)) - require.Empty(t, result.Drop) - } - }) - - // 2.5: Validate session distribution proportional to weights (1:1:2) - t.Run("Validate_Session_Distribution", func(t *testing.T) { - info, err := newBalancer.Info(now) - require.NoError(t, err) - require.NotNil(t, info) - - vsInfo := info.Vs[0] - real1Sessions := vsInfo.Reals[0].ActiveSessions - real2Sessions := vsInfo.Reals[1].ActiveSessions - real3Sessions := vsInfo.Reals[2].ActiveSessions - - totalSessions := real1Sessions + real2Sessions + real3Sessions - require.Greater( - t, - totalSessions, - uint64(0), - "should have active sessions", - ) - - // Calculate ratios - real1Ratio := float64(real1Sessions) / float64(totalSessions) - real2Ratio := float64(real2Sessions) / float64(totalSessions) - real3Ratio := float64(real3Sessions) / float64(totalSessions) - - // Expected ratios based on weights 1:1:2 - // Total weight = 4, so Real1=25%, Real2=25%, Real3=50% - expectedReal1Ratio := 0.25 - expectedReal2Ratio := 0.25 - expectedReal3Ratio := 0.50 - - tolerance := 0.15 // 15% tolerance - - t.Logf( - "Session distribution: Real1=%.2f%% (expected 25%%), Real2=%.2f%% (expected 25%%), Real3=%.2f%% (expected 50%%)", - real1Ratio*100, - real2Ratio*100, - real3Ratio*100, - ) - - assert.InDelta( - t, - expectedReal1Ratio, - real1Ratio, - tolerance, - "Real1 session ratio", - ) - assert.InDelta( - t, - expectedReal2Ratio, - real2Ratio, - tolerance, - "Real2 session ratio", - ) - assert.InDelta( - t, - expectedReal3Ratio, - real3Ratio, - tolerance, - "Real3 session ratio", - ) - }) - }) - - // NEW: Stage 3 - Multi-VS Configuration Update - t.Run("Stage3_Multi_VS_Configuration", func(t *testing.T) { - // 3.1: Create updated config with 4 virtual services - vs2Ip := netip.MustParseAddr("10.10.1.1") - vs3Ip := netip.MustParseAddr("10.10.2.1") - vs4Ip := netip.MustParseAddr("10.10.3.1") - - real4Ip := netip.MustParseAddr("192.168.1.1") - real5Ip := netip.MustParseAddr("192.168.1.2") - real6Ip := netip.MustParseAddr("192.168.1.3") - real7Ip := netip.MustParseAddr("192.168.2.1") - real8Ip := netip.MustParseAddr("192.168.2.2") - real9Ip := netip.MustParseAddr("192.168.3.1") - real10Ip := netip.MustParseAddr("192.168.3.2") - real11Ip := netip.MustParseAddr("192.168.3.3") - - createReal := func(ip netip.Addr, weight uint32) *balancerpb.Real { - return &balancerpb.Real{ - Id: &balancerpb.RelativeRealIdentifier{ - Ip: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: 0, - }, - Weight: weight, - SrcAddr: &balancerpb.Addr{ - Bytes: ip.AsSlice(), - }, - SrcMask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("255.255.255.255").AsSlice(), - }, - } - } - - createVS := func(ip netip.Addr, port uint16, wlc bool, reals []*balancerpb.Real) *balancerpb.VirtualService { - return &balancerpb.VirtualService{ - Id: &balancerpb.VsIdentifier{ - Addr: &balancerpb.Addr{Bytes: ip.AsSlice()}, - Port: uint32(port), - Proto: balancerpb.TransportProto_TCP, - }, - Scheduler: balancerpb.VsScheduler_ROUND_ROBIN, - AllowedSrcs: []*balancerpb.AllowedSources{ - { - Nets: []*balancerpb.Net{{ - Addr: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - Mask: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("0.0.0.0").AsSlice(), - }, - }}, - }, - }, - Flags: &balancerpb.VsFlags{ - Gre: false, - FixMss: false, - Ops: false, - PureL3: false, - Wlc: wlc, - }, - Reals: reals, - Peers: []*balancerpb.Addr{}, - } - } - - updatedConfig := &balancerpb.BalancerConfig{ - PacketHandler: &balancerpb.PacketHandlerConfig{ - SourceAddressV4: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("5.5.5.5").AsSlice(), - }, - SourceAddressV6: &balancerpb.Addr{ - Bytes: netip.MustParseAddr("fe80::5").AsSlice(), - }, - Vs: []*balancerpb.VirtualService{ - // Keep original VS with WLC=true - createVS(vsIP, vsPort, true, []*balancerpb.Real{ - createReal(real1Ip, 1), - createReal(real2Ip, 1), - createReal(real3Ip, 2), - }), - // Add VS2 with WLC=true, weights 1,2,1 - createVS(vs2Ip, 80, true, []*balancerpb.Real{ - createReal(real4Ip, 1), - createReal(real5Ip, 2), - createReal(real6Ip, 1), - }), - // Add VS3 with WLC=false (RR), weights 1,1 - createVS(vs3Ip, 80, false, []*balancerpb.Real{ - createReal(real7Ip, 1), - createReal(real8Ip, 1), - }), - // Add VS4 with WLC=true, weights 2,2,1 - createVS(vs4Ip, 80, true, []*balancerpb.Real{ - createReal(real9Ip, 2), - createReal(real10Ip, 2), - createReal(real11Ip, 1), - }), - }, - DecapAddresses: []*balancerpb.Addr{}, - SessionsTimeouts: &balancerpb.SessionsTimeouts{ - TcpSynAck: 60, - TcpSyn: 60, - TcpFin: 60, - Tcp: 60, - Udp: 60, - Default: 60, - }, - }, - State: &balancerpb.StateConfig{ - SessionTableCapacity: func() *uint64 { v := uint64(8000); return &v }(), - SessionTableMaxLoadFactor: func() *float32 { v := float32(0.5); return &v }(), - RefreshPeriod: durationpb.New(0), - Wlc: &balancerpb.WlcConfig{ - Power: func() *uint64 { v := uint64(10); return &v }(), - MaxWeight: func() *uint32 { v := uint32(1000); return &v }(), - }, - }, - } - - _, err := balancerMgr.Update(updatedConfig, now) - require.NoError(t, err, "failed to update configuration") - - // 3.2: Verify Config() matches updated configuration - t.Run("Verify_Config_Matches_Update", func(t *testing.T) { - retrievedConfig := balancerMgr.Config() - require.NotNil(t, retrievedConfig) - require.NotNil(t, retrievedConfig.PacketHandler) - - // Verify 4 virtual services - assert.Equal( - t, - 4, - len(retrievedConfig.PacketHandler.Vs), - "should have 4 virtual services", - ) - - // Verify each VS - vs1Found := false - vs2Found := false - vs3Found := false - vs4Found := false - - for _, vs := range retrievedConfig.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - switch vsAddr { - case vsIP: - vs1Found = true - require.NotNil(t, vs.Flags, "VS1 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS1 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS1 should have 3 reals") - assert.Equal( - t, - uint32(1), - vs.Reals[0].Weight, - "VS1 Real1 weight", - ) - assert.Equal( - t, - uint32(1), - vs.Reals[1].Weight, - "VS1 Real2 weight", - ) - assert.Equal( - t, - uint32(2), - vs.Reals[2].Weight, - "VS1 Real3 weight", - ) - case vs2Ip: - vs2Found = true - require.NotNil(t, vs.Flags, "VS2 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS2 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS2 should have 3 reals") - assert.Equal( - t, - uint32(1), - vs.Reals[0].Weight, - "VS2 Real1 weight", - ) - assert.Equal( - t, - uint32(2), - vs.Reals[1].Weight, - "VS2 Real2 weight", - ) - assert.Equal( - t, - uint32(1), - vs.Reals[2].Weight, - "VS2 Real3 weight", - ) - case vs3Ip: - vs3Found = true - require.NotNil(t, vs.Flags, "VS3 flags should not be nil") - assert.False(t, vs.Flags.Wlc, "VS3 should have WLC=false") - assert.Equal(t, 2, len(vs.Reals), "VS3 should have 2 reals") - assert.Equal( - t, - uint32(1), - vs.Reals[0].Weight, - "VS3 Real1 weight", - ) - assert.Equal( - t, - uint32(1), - vs.Reals[1].Weight, - "VS3 Real2 weight", - ) - case vs4Ip: - vs4Found = true - require.NotNil(t, vs.Flags, "VS4 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS4 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS4 should have 3 reals") - assert.Equal( - t, - uint32(2), - vs.Reals[0].Weight, - "VS4 Real1 weight", - ) - assert.Equal( - t, - uint32(2), - vs.Reals[1].Weight, - "VS4 Real2 weight", - ) - assert.Equal( - t, - uint32(1), - vs.Reals[2].Weight, - "VS4 Real3 weight", - ) - } - } - - assert.True(t, vs1Found, "VS1 should be present") - assert.True(t, vs2Found, "VS2 should be present") - assert.True(t, vs3Found, "VS3 should be present") - assert.True(t, vs4Found, "VS4 should be present") - }) - - // 3.3: Create new BalancerAgent and verify config persistence - t.Run("Verify_Config_Persistence_With_New_Agent", func(t *testing.T) { - // Create third balancer agent using same shared memory - logLevel := zapcore.InfoLevel - sugaredLogger, _, _ := logging.Init(&logging.Config{ - Level: logLevel, - }) - - agentMemory := 16 * datasize.MB - thirdAgent, err := balancer.NewBalancerAgent( - ts.Mock.SharedMemory(), // Same shared memory - agentMemory, - sugaredLogger, - ) - require.NoError(t, err, "failed to create third balancer agent") - - // Attach to existing BalancerManager - thirdBalancer, err := thirdAgent.BalancerManager(utils.BalancerName) - require.NoError( - t, - err, - "failed to attach to existing balancer manager", - ) - require.NotNil( - t, - thirdBalancer, - "balancer manager should not be nil", - ) - - // Verify config matches the updated multi-VS config - thirdConfig := thirdBalancer.Config() - require.NotNil(t, thirdConfig) - require.NotNil(t, thirdConfig.PacketHandler) - - // Verify 4 virtual services - assert.Equal( - t, - 4, - len(thirdConfig.PacketHandler.Vs), - "should have 4 virtual services", - ) - - // Verify each VS is present with correct configuration - vs1Found := false - vs2Found := false - vs3Found := false - vs4Found := false - - for _, vs := range thirdConfig.PacketHandler.Vs { - vsAddr, _ := netip.AddrFromSlice(vs.Id.Addr.Bytes) - switch vsAddr { - case vsIP: - vs1Found = true - require.NotNil(t, vs.Flags, "VS1 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS1 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS1 should have 3 reals") - case vs2Ip: - vs2Found = true - require.NotNil(t, vs.Flags, "VS2 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS2 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS2 should have 3 reals") - case vs3Ip: - vs3Found = true - require.NotNil(t, vs.Flags, "VS3 flags should not be nil") - assert.False(t, vs.Flags.Wlc, "VS3 should have WLC=false") - assert.Equal(t, 2, len(vs.Reals), "VS3 should have 2 reals") - case vs4Ip: - vs4Found = true - require.NotNil(t, vs.Flags, "VS4 flags should not be nil") - assert.True(t, vs.Flags.Wlc, "VS4 should have WLC=true") - assert.Equal(t, 3, len(vs.Reals), "VS4 should have 3 reals") - } - } - - assert.True( - t, - vs1Found, - "VS1 should be present in third agent config", - ) - assert.True( - t, - vs2Found, - "VS2 should be present in third agent config", - ) - assert.True( - t, - vs3Found, - "VS3 should be present in third agent config", - ) - assert.True( - t, - vs4Found, - "VS4 should be present in third agent config", - ) - - t.Log( - "Successfully verified config persistence across agent instances", - ) - }) - }) -} diff --git a/modules/balancer/tests/meson.build b/modules/balancer/tests/meson.build deleted file mode 100644 index 082b746f6..000000000 --- a/modules/balancer/tests/meson.build +++ /dev/null @@ -1 +0,0 @@ -subdir('unit') diff --git a/modules/balancer/tests/unit/active_sessions_tracker.c b/modules/balancer/tests/unit/active_sessions_tracker.c deleted file mode 100644 index 08ad9539c..000000000 --- a/modules/balancer/tests/unit/active_sessions_tracker.c +++ /dev/null @@ -1,444 +0,0 @@ -/* - * Unit tests for active_sessions_tracker. - * - * Constants: - * ACTIVE_SESSIONS_TRACKER_PRECISION = 16 - * active_sessions_tracker_now(ts) = ts / 16 - * active_sessions_tracker_until(ts) = (ts + 15) / 16 - * - * shard->count tracks the running total of active sessions. - * new_session() adds make()'s return value to count. - * prolong_session() adds prolong()'s return value to count. - */ - -#include "../../../../common/memory.h" -#include "../../../../common/memory_block.h" -#include "../../../../common/test_assert.h" -#include "../../controlplane/state/active_sessions.h" -#include "../../dataplane/active_sessions.h" -#include "lib/logging/log.h" -#include -#include - -/* ------------------------------------------------------------------ */ -/* Memory setup helpers */ -/* ------------------------------------------------------------------ */ - -/* - * A small static arena large enough for a few tracker shards. - * sizeof(active_sessions_tracker_shard) = 64 bytes (cache-line aligned). - * 8 shards = 512 bytes; add padding for alignment. - */ -#define TEST_ARENA_SIZE (4096) - -static uint8_t g_arena[TEST_ARENA_SIZE] __attribute__((aligned(64))); - -static void -setup_mctx(struct block_allocator *ba, struct memory_context *mctx) { - block_allocator_init(ba); - block_allocator_put_arena(ba, g_arena, TEST_ARENA_SIZE); - memory_context_init(mctx, "test", ba); -} - -/* ------------------------------------------------------------------ */ -/* Test 1: create initialises all shards to count=0 */ -/* ------------------------------------------------------------------ */ -static int -test_create_initialises_shards(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 3; - const uint32_t now = 0; - - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, now); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - for (size_t i = 0; i < shards; ++i) { - TEST_ASSERT_EQUAL( - (int64_t)tracker[i].count, - 0, - "shard[%zu].count should be 0 after create", - i - ); - } - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 2: new_session increments count by 1 */ -/* ------------------------------------------------------------------ */ -static int -test_new_session_increments_count(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 2; - /* - * now=0, timeout=32: - * now_tick = 0/16 = 0 - * until_tick = (0+32+15)/16 = 47/16 = 2 - * make(0, 2): diff[0]+=1 consumed -> +1; diff[2]-=1 pending - * count += 1 => count = 1 - */ - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - active_sessions_tracker_new_session(tracker, 0, 0, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 1, - "count should be 1 after one new_session" - ); - TEST_ASSERT_EQUAL( - (int64_t)tracker[1].count, 0, "shard[1].count should remain 0" - ); - - /* Second session on shard 0 at the same timestamp. */ - active_sessions_tracker_new_session(tracker, 0, 0, 48); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 2, - "count should be 2 after two new_sessions on shard 0" - ); - - /* Session on shard 1. */ - active_sessions_tracker_new_session(tracker, 1, 0, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[1].count, - 1, - "shard[1].count should be 1 after one new_session" - ); - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 3: session expires and count decrements */ -/* ------------------------------------------------------------------ */ -static int -test_session_expires_decrements_count(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 1; - /* - * Create session at ts=0 with timeout=32: - * now_tick=0, until_tick=(0+32+15)/16=2 - * make(0,2): +1 -> count=1; diff[2]=-1 pending - * - * Advance to ts=32 (tick=2) by creating a new session: - * now_tick=32/16=2, until_tick=(32+32+15)/16=79/16=4 - * make(2,4): advance sweeps [0..2): slot 0 (=0), slot 1 (=0); - * consume slot 2 (=-1) -> change=-1 - * diff[2]+=1 -> 0; diff[4]-=1 - * return 0 + (-1) = -1 - * count += -1 => count = 1 + (-1) = 0 - * Then the new session's +1 is included: actually make returns - * the net change including the new +1. - */ - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - /* Session A: ts=0, timeout=32 -> expires at tick 2 */ - active_sessions_tracker_new_session(tracker, 0, 0, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, 1, "count=1 after session A" - ); - - /* - * Session B at ts=32 (tick=2): expiry of A fires, new B starts. - * Net change = 0 (expiry -1 + new start +1). - * count stays 1. - */ - active_sessions_tracker_new_session(tracker, 0, 32, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 1, - "count=1: A expired(-1) + B started(+1) = net 0" - ); - - /* - * Advance to ts=64 (tick=4): expiry of B fires. - * Session C at ts=64, timeout=32 -> until_tick=(64+32+15)/16=6 - * make(4,6): - * diff[4%8=4] += 1 -> diff[4] = -1+1 = 0 - * diff[6%8=6] -= 1 - * advance(4): sweep [2..4): slot 2 (=0), slot 3 (=0); - * consume slot 4 (=0) -> 0 - * return 0 - * count stays 1. - */ - active_sessions_tracker_new_session(tracker, 0, 64, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 1, - "count=1: B expired(-1) + C started(+1) = net 0" - ); - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 4: prolong_session does not change count at current time */ -/* ------------------------------------------------------------------ */ -static int -test_prolong_session_does_not_change_count(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 1; - /* - * Session at ts=0, timeout=32: - * now_tick=0, until_tick=2 - * count=1 - * - * Prolong at ts=0: prev_timeout=32, new_timeout=64 - * prev_until_tick = (0+32+15)/16 = 2 - * new_until_tick = (0+64+15)/16 = 4 - * prolong(now_tick=0, prev_until=2, new_until=4): - * diff[2]+=1 -> 0 (cancel old expiry) - * diff[4]-=1 - * advance(0): consume slot 0 (=0) -> 0 - * return 0 - * count += 0 => count stays 1 - */ - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - active_sessions_tracker_new_session(tracker, 0, 0, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, 1, "count=1 after new_session" - ); - - active_sessions_tracker_prolong_session(tracker, 0, 0, 32, 0, 64); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 1, - "prolong at same timestamp should not change count" - ); - - /* - * Advance to ts=32 (tick=2): old expiry was moved to tick=4, - * so no expiry fires here. - * Session B at ts=32, timeout=32 -> until_tick=4. - * make(2, 4): - * diff[2]+=1 (=0+1=1); diff[4]-=1 (=-1-1=-2) - * advance(2): sweep [0..2): slot 0 (=0), slot 1 (=0); - * consume slot 2 (=1) -> +1 - * return 0 + 1 = 1 - * count += 1 => count = 2 - * - * No expiry at tick=2 (prolong moved it to tick=4). - */ - active_sessions_tracker_new_session(tracker, 0, 32, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].count, - 2, - "no expiry at old until=tick2; new session starts -> count=2" - ); - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 5: multiple shards are independent */ -/* ------------------------------------------------------------------ */ -static int -test_multiple_shards_are_independent(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 4; - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - /* Add different numbers of sessions to each shard. */ - for (uint32_t w = 0; w < (uint32_t)shards; ++w) { - for (uint32_t s = 0; s <= w; ++s) { - active_sessions_tracker_new_session( - tracker, w, s * 16, 32 - ); - } - } - - /* - * Expected counts with timeout=32 and spacing=16: - * each session lasts 2 tracker ticks, and we start one every tick, - * so at most 2 sessions overlap on a shard. - * - * shard 0: 1 session (ts=0) - * shard 1: 2 sessions (ts=0,16) - * shard 2: 2 sessions (ts=0 expired when ts=32 started) - * shard 3: 2 sessions (steady-state overlap of 2) - */ - static const uint32_t expected[] = {1, 2, 2, 2}; - for (uint32_t w = 0; w < (uint32_t)shards; ++w) { - TEST_ASSERT_EQUAL( - (int64_t)tracker[w].count, - (int64_t)expected[w], - "shard[%u].count should be %u", - w, - expected[w] - ); - } - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 6: last_packet_timestamp is updated by new_session and prolong */ -/* ------------------------------------------------------------------ */ -static int -test_last_packet_timestamp_is_updated(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 1; - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - /* - * new_session at ts=32: - * now_tick=2, until_tick=(32+32+15)/16=4 - * last_packet_timestamp = 32 - */ - active_sessions_tracker_new_session(tracker, 0, 32, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].last_packet_timestamp, - 32, - "last_packet_timestamp should be 32 after new_session(ts=32)" - ); - - /* - * prolong at ts=48 (tick=3): - * prev_until_tick = (32+32+15)/16 = 4 (>= now_tick=3 ✓) - * new_until_tick = (48+32+15)/16 = 5 - * 5 - 3 = 2 < 8 ✓ - * last_packet_timestamp = 48 - */ - active_sessions_tracker_prolong_session(tracker, 0, 32, 32, 48, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].last_packet_timestamp, - 48, - "last_packet_timestamp should be 48 after prolong(now=48)" - ); - - /* - * new_session at ts=80 (tick=5): - * now_tick=5, until_tick=(80+32+15)/16=7 - * last_packet_timestamp = 80 - */ - active_sessions_tracker_new_session(tracker, 0, 80, 32); - TEST_ASSERT_EQUAL( - (int64_t)tracker[0].last_packet_timestamp, - 80, - "last_packet_timestamp should be 80 after new_session(ts=80)" - ); - - active_sessions_tracker_destroy(tracker, shards, &mctx); - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 7: destroy frees memory (balloc_count == bfree_count) */ -/* ------------------------------------------------------------------ */ -static int -test_destroy_frees_memory(void) { - struct block_allocator ba; - struct memory_context mctx; - setup_mctx(&ba, &mctx); - - const size_t shards = 2; - struct active_sessions_tracker_shard *tracker = - active_sessions_tracker_create(&mctx, shards, 0); - TEST_ASSERT_NOT_NULL(tracker, "tracker_create should not return NULL"); - - TEST_ASSERT_EQUAL( - (int64_t)mctx.balloc_count, - 1, - "exactly one allocation should have been made" - ); - TEST_ASSERT_EQUAL( - (int64_t)mctx.bfree_count, 0, "no frees before destroy" - ); - - active_sessions_tracker_destroy(tracker, shards, &mctx); - - TEST_ASSERT_EQUAL( - (int64_t)mctx.bfree_count, 1, "exactly one free after destroy" - ); - TEST_ASSERT_EQUAL( - (int64_t)mctx.balloc_count, - (int64_t)mctx.bfree_count, - "balloc_count should equal bfree_count after destroy" - ); - - return TEST_SUCCESS; -} - -int -main(void) { - log_enable_name("debug"); - - LOG(INFO, "test_create_initialises_shards..."); - TEST_ASSERT_SUCCESS( - test_create_initialises_shards(), - "test_create_initialises_shards failed" - ); - - LOG(INFO, "test_new_session_increments_count..."); - TEST_ASSERT_SUCCESS( - test_new_session_increments_count(), - "test_new_session_increments_count failed" - ); - - LOG(INFO, "test_session_expires_decrements_count..."); - TEST_ASSERT_SUCCESS( - test_session_expires_decrements_count(), - "test_session_expires_decrements_count failed" - ); - - LOG(INFO, "test_prolong_session_does_not_change_count..."); - TEST_ASSERT_SUCCESS( - test_prolong_session_does_not_change_count(), - "test_prolong_session_does_not_change_count failed" - ); - - LOG(INFO, "test_multiple_shards_are_independent..."); - TEST_ASSERT_SUCCESS( - test_multiple_shards_are_independent(), - "test_multiple_shards_are_independent failed" - ); - - LOG(INFO, "test_last_packet_timestamp_is_updated..."); - TEST_ASSERT_SUCCESS( - test_last_packet_timestamp_is_updated(), - "test_last_packet_timestamp_is_updated failed" - ); - - LOG(INFO, "test_destroy_frees_memory..."); - TEST_ASSERT_SUCCESS( - test_destroy_frees_memory(), "test_destroy_frees_memory failed" - ); - - return TEST_SUCCESS; -} diff --git a/modules/balancer/tests/unit/interval_counter.c b/modules/balancer/tests/unit/interval_counter.c deleted file mode 100644 index 4373ac23f..000000000 --- a/modules/balancer/tests/unit/interval_counter.c +++ /dev/null @@ -1,424 +0,0 @@ -/* - * Unit tests for rt_interval_counter. - * - * Ring size = 8 (RT_INTERVAL_COUNTER_RING_SIZE), mask = 7. - * - * The counter stores int32_t diff[8]. make(now, until) writes +1 at - * now%8 and -1 at until%8, then advances past `now` (consuming and - * clearing that slot). The return value is the net change to the - * caller's running total. - * - * prolong(now, prev_until, new_until) moves the -1 from prev_until%8 - * to new_until%8 by writing +1 at prev_until%8 and -1 at new_until%8, - * then advances past `now`. It does NOT change the count at `now`. - * - * When now - last_timestamp >= 8 (stale gap), try_reset() sums all - * remaining diffs, clears the ring, and resets last_timestamp = now. - * That sum is included in the return value of the next make/prolong. - */ - -#include "../../controlplane/state/interval_counter.h" -#include "../../../../common/test_assert.h" -#include "../../dataplane/interval_counter.h" -#include "lib/logging/log.h" -#include - -/* ------------------------------------------------------------------ */ -/* Test 1: each make() at a new timestamp returns +1 */ -/* ------------------------------------------------------------------ */ -static int -test_make_reports_visible_count_changes(void) { - struct rt_interval_counter counter; - - rt_interval_counter_init(&counter, 10); - - /* - * make(10, 13): - * diff[10%8=2] += 1 -> diff[2] = 1 - * diff[13%8=5] -= 1 -> diff[5] = -1 - * advance(10): consume slot 2 -> change = +1, diff[2] = 0 - * return 0 + 1 = 1 - */ - int64_t change = rt_interval_counter_make(&counter, 10, 13); - TEST_ASSERT_EQUAL(change, 1, "make(10,13) should return +1"); - - /* - * make(11, 15): - * diff[11%8=3] += 1 -> diff[3] = 1 - * diff[15%8=7] -= 1 -> diff[7] = -1 - * advance(11): sweep slot 10%8=2 (=0) -> change=0; - * consume slot 11%8=3 (=1) -> change=1, diff[3]=0 - * return 0 + 1 = 1 - */ - change = rt_interval_counter_make(&counter, 11, 15); - TEST_ASSERT_EQUAL(change, 1, "make(11,15) should return +1"); - - /* - * make(12, 16): - * diff[12%8=4] += 1 - * diff[16%8=0] -= 1 - * advance(12): sweep slot 11%8=3 (=0); consume slot 12%8=4 (=1) - * return 1 - */ - change = rt_interval_counter_make(&counter, 12, 16); - TEST_ASSERT_EQUAL(change, 1, "make(12,16) should return +1"); - - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 2: prolong() returns 0; expiry fires at new time, not old */ -/* ------------------------------------------------------------------ */ -static int -test_prolong_moves_expiry_without_changing_current_count(void) { - struct rt_interval_counter counter; - - rt_interval_counter_init(&counter, 20); - - /* - * make(20, 23): - * diff[20%8=4] += 1; diff[23%8=7] -= 1 - * advance(20): consume slot 4 -> +1 - * return 1 - */ - int64_t change = rt_interval_counter_make(&counter, 20, 23); - TEST_ASSERT_EQUAL(change, 1, "make(20,23) should return +1"); - - /* - * prolong(20, prev_until=23, new_until=26): - * diff[23%8=7] += 1 -> diff[7] = -1+1 = 0 (cancel old expiry) - * diff[26%8=2] -= 1 -> diff[2] = -1 (new expiry) - * advance(20): consume slot 20%8=4 (=0) -> change=0 - * return 0 + 0 = 0 - */ - change = rt_interval_counter_prolong(&counter, 20, 23, 26); - TEST_ASSERT_EQUAL(change, 0, "prolong(20,23,26) should return 0"); - - /* - * make(23, 27): a new interval starts at t=23. - * diff[23%8=7] += 1 -> diff[7] = 0+1 = 1 (slot was cleared by - * prolong) diff[27%8=3] -= 1 -> diff[3] = -1 advance(23): sweep - * [20..23): slot 20%8=4 (=0), slot 21%8=5 (=0), slot 22%8=6 (=0) - * consume slot 23%8=7 (=1) -> change=+1, diff[7]=0 - * return 0 + 1 = 1 - * - * The prolong cancelled the -1 expiry at slot 7, so there is no - * expiry firing here. The +1 is purely from the new interval start. - */ - change = rt_interval_counter_make(&counter, 23, 27); - TEST_ASSERT_EQUAL( - change, - 1, - "make(23,27): new interval starts (+1), no expiry at old " - "until=23" - ); - - /* - * make(26, 30): the prolong moved the expiry to slot 26%8=2. - * diff[26%8=2] += 1 -> diff[2] = -1+1 = 0 (cancels prolong expiry) - * diff[30%8=6] -= 1 -> diff[6] = -1 - * advance(26): sweep [23..26): - * slot 23%8=7 (=0), slot 24%8=0 (=0), slot 25%8=1 (=0) - * consume slot 26%8=2 (=0) -> change=0 - * return 0 + 0 = 0 - * - * The +1 from make(26,30) cancels the prolong's -1 at slot 2. - * Net = 0: the prolonged interval expired (-1) and a new one started - * (+1). - */ - change = rt_interval_counter_make(&counter, 26, 30); - TEST_ASSERT_EQUAL( - change, 0, "at t=26: prolong expiry(-1) + new interval(+1) = 0" - ); - - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 3: multiple make() at the same now each return +1 */ -/* ------------------------------------------------------------------ */ -static int -test_same_timestamp_operations_accumulate_once(void) { - struct rt_interval_counter counter; - - rt_interval_counter_init(&counter, 30); - - /* - * Three intervals all starting at now=30. - * Each make() writes +1 at slot 30%8=6, then advance() consumes - * and clears slot 6 immediately, so each call returns +1. - * The -1 expiry slots accumulate independently at slots 3, 4, 5. - * - * 35%8=3, 36%8=4, 37%8=5 - */ - int64_t change = rt_interval_counter_make(&counter, 30, 35); - TEST_ASSERT_EQUAL(change, 1, "1st make(30,35) should return +1"); - - change = rt_interval_counter_make(&counter, 30, 36); - TEST_ASSERT_EQUAL(change, 1, "2nd make(30,36) should return +1"); - - change = rt_interval_counter_make(&counter, 30, 37); - TEST_ASSERT_EQUAL(change, 1, "3rd make(30,37) should return +1"); - - /* - * State: diff[3]=-1, diff[4]=-1, diff[5]=-1, last_timestamp=30. - * - * Jump far ahead (stale gap >= 8) to flush all three pending - * expiries at once via try_reset, then start a new interval. - * - * make(200, 202): gap=200-30=170 >= 8 -> try_reset fires. - * try_reset: sum all diffs: diff[3]+diff[4]+diff[5] = -3 - * clear ring; last_timestamp=200; return -3 - * diff[200%8=0] += 1; diff[202%8=2] -= 1 - * advance(200): consume slot 0 (=1) -> +1 - * return -3 + 1 = -2 - * - * Net -2: three intervals expired (-3) and one new started (+1). - */ - change = rt_interval_counter_make(&counter, 200, 202); - TEST_ASSERT_EQUAL( - change, - -2, - "stale gap flushes 3 pending expiries(-3) + new interval(+1) = " - "-2" - ); - - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 4: stale gap flushes pending diffs before the new interval */ -/* ------------------------------------------------------------------ */ -static int -test_stale_gap_returns_pending_sum_before_new_interval(void) { - struct rt_interval_counter counter; - const uint32_t now = 100; - - rt_interval_counter_init(&counter, now); - - /* - * make(100, 103): - * diff[100%8=4] += 1; diff[103%8=7] -= 1 - * advance(100): consume slot 4 -> +1 - * return 1 - */ - int64_t change = rt_interval_counter_make(&counter, 100, 103); - TEST_ASSERT_EQUAL(change, 1, "make(100,103) should return +1"); - - /* - * make(200, 205): gap = 200-100 = 100 >= 8 -> try_reset fires. - * try_reset: sum all diffs: diff[7]=-1, rest 0 -> sum=-1 - * clear ring; last_timestamp=200; return -1 - * diff[200%8=0] += 1; diff[205%8=5] -= 1 - * advance(200): consume slot 0 (=1) -> +1 - * return -1 + 1 = 0 - * - * Net 0: the pending expiry (-1) and the new interval (+1) cancel. - */ - change = rt_interval_counter_make(&counter, 200, 205); - TEST_ASSERT_EQUAL( - change, - 0, - "stale gap: pending expiry(-1) + new interval(+1) = 0" - ); - - /* - * make(201, 206): no stale gap (201-200=1 < 8). - * diff[201%8=1] += 1; diff[206%8=6] -= 1 - * advance(201): sweep slot 200%8=0 (=0); consume slot 1 (=1) -> +1 - * return 0 + 1 = 1 - */ - change = rt_interval_counter_make(&counter, 201, 206); - TEST_ASSERT_EQUAL( - change, 1, "make(201,206) after reset should return +1" - ); - - /* - * Another large jump: make(400, 405). - * gap = 400-201 = 199 >= 8 -> try_reset fires. - * Remaining diffs: diff[205%8=5]=-1, diff[206%8=6]=-1, rest 0 - * sum = -2; clear; last_timestamp=400 - * diff[400%8=0] += 1; diff[405%8=5] -= 1 - * advance(400): consume slot 0 (=1) -> +1 - * return -2 + 1 = -1 - */ - change = rt_interval_counter_make(&counter, 400, 405); - TEST_ASSERT_EQUAL( - change, - -1, - "second stale gap: two pending expiries(-2) + new interval(+1) " - "= -1" - ); - - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 5: init zeroes the ring and sets last_timestamp correctly */ -/* ------------------------------------------------------------------ */ -static int -test_init_zeroes_ring_and_sets_timestamp(void) { - struct rt_interval_counter counter; - - /* Dirty the memory first to ensure init actually zeroes it. */ - for (size_t i = 0; i < RT_INTERVAL_COUNTER_RING_SIZE; ++i) { - counter.diff[i] = (int32_t)(i + 1); - } - counter.last_timestamp = 0xDEADBEEFu; - - rt_interval_counter_init(&counter, 42); - - TEST_ASSERT_EQUAL( - (int64_t)counter.last_timestamp, - 42, - "last_timestamp should be set to now=42" - ); - - for (size_t i = 0; i < RT_INTERVAL_COUNTER_RING_SIZE; ++i) { - TEST_ASSERT_EQUAL( - (int64_t)counter.diff[i], - 0, - "diff[%zu] should be zeroed after init", - i - ); - } - - return TEST_SUCCESS; -} - -/* ------------------------------------------------------------------ */ -/* Test 6: expiry fires at the exact timestamp, not before or after */ -/* ------------------------------------------------------------------ */ -static int -test_expiry_fires_at_correct_timestamp(void) { - struct rt_interval_counter counter; - - /* - * Constraint: until - now < 8 (ring size). - * Use short intervals of length 3 and 4. - * - * Interval A: [10, 13) diff[10%8=2]+=1 consumed, diff[13%8=5]-=1 - * Interval B: [11, 14) diff[11%8=3]+=1 consumed, diff[14%8=6]-=1 - */ - rt_interval_counter_init(&counter, 10); - - int64_t change = rt_interval_counter_make(&counter, 10, 13); - TEST_ASSERT_EQUAL(change, 1, "make(10,13) should return +1"); - - change = rt_interval_counter_make(&counter, 11, 14); - TEST_ASSERT_EQUAL(change, 1, "make(11,14) should return +1"); - - /* - * Advance to t=12 (before either expiry) via make(12, 15): - * 15%8=7, safe (not 5 or 6). - * diff[12%8=4]+=1; diff[15%8=7]-=1 - * advance(12): sweep slot 11%8=3 (=0); consume slot 12%8=4 (=1) -> +1 - * return 0 + 1 = 1 - * A new interval starts; no expiry yet. - */ - change = rt_interval_counter_make(&counter, 12, 15); - TEST_ASSERT_EQUAL( - change, 1, "make(12,15): new interval, no expiry yet" - ); - - /* - * Advance to t=13 (expiry of A) via make(13, 16): - * 16%8=0, safe (not 5,6,7). - * diff[13%8=5]+=1 -> diff[5]=-1+1=0 (cancel A's expiry) - * diff[16%8=0]-=1 - * advance(13): sweep slot 12%8=4 (=0); consume slot 13%8=5 (=0) -> 0 - * return 0 - * Net = 0: A expired (-1) and new interval started (+1). - */ - change = rt_interval_counter_make(&counter, 13, 16); - TEST_ASSERT_EQUAL( - change, 0, "at t=13: expiry of A(-1) + new start(+1) = 0" - ); - - /* - * Advance to t=14 (expiry of B) via make(14, 17): - * 17%8=1, safe (not 6,7,0). - * diff[14%8=6]+=1 -> diff[6]=-1+1=0 (cancel B's expiry) - * diff[17%8=1]-=1 - * advance(14): sweep slot 13%8=5 (=0); consume slot 14%8=6 (=0) -> 0 - * return 0 - * Net = 0: B expired (-1) and new interval started (+1). - */ - change = rt_interval_counter_make(&counter, 14, 17); - TEST_ASSERT_EQUAL( - change, 0, "at t=14: expiry of B(-1) + new start(+1) = 0" - ); - - /* - * Advance to t=15 (expiry of make(12,15)) via make(15, 18): - * 18%8=2, safe (not 7,1). - * diff[15%8=7]+=1 -> diff[7]=-1+1=0 (cancel make(12,15) expiry) - * diff[18%8=2]-=1 - * advance(15): sweep slot 14%8=6 (=0); consume slot 15%8=7 (=0) -> 0 - * return 0 - */ - change = rt_interval_counter_make(&counter, 15, 18); - TEST_ASSERT_EQUAL( - change, - 0, - "at t=15: expiry of make(12,15) (-1) + new start(+1) = 0" - ); - - /* - * Verify no spurious expiry between t=13 and t=14: - * Use a fresh counter and advance from t=12 to t=13 without - * starting a new interval — observe only the expiry. - * - * Since make() always starts a new interval, we can't call it - * without the +1 side effect. Instead, verify by checking that - * make(13, 16) returns 0 (not -1), confirming the expiry and - * start cancel exactly. This was already asserted above. - */ - - return TEST_SUCCESS; -} - -int -main(void) { - log_enable_name("debug"); - - LOG(INFO, "test_make_reports_visible_count_changes..."); - TEST_ASSERT_SUCCESS( - test_make_reports_visible_count_changes(), - "test_make_reports_visible_count_changes failed" - ); - - LOG(INFO, - "test_prolong_moves_expiry_without_changing_current_count..."); - TEST_ASSERT_SUCCESS( - test_prolong_moves_expiry_without_changing_current_count(), - "test_prolong_moves_expiry_without_changing_current_count " - "failed" - ); - - LOG(INFO, "test_same_timestamp_operations_accumulate_once..."); - TEST_ASSERT_SUCCESS( - test_same_timestamp_operations_accumulate_once(), - "test_same_timestamp_operations_accumulate_once failed" - ); - - LOG(INFO, "test_stale_gap_returns_pending_sum_before_new_interval..."); - TEST_ASSERT_SUCCESS( - test_stale_gap_returns_pending_sum_before_new_interval(), - "test_stale_gap_returns_pending_sum_before_new_interval failed" - ); - - LOG(INFO, "test_init_zeroes_ring_and_sets_timestamp..."); - TEST_ASSERT_SUCCESS( - test_init_zeroes_ring_and_sets_timestamp(), - "test_init_zeroes_ring_and_sets_timestamp failed" - ); - - LOG(INFO, "test_expiry_fires_at_correct_timestamp..."); - TEST_ASSERT_SUCCESS( - test_expiry_fires_at_correct_timestamp(), - "test_expiry_fires_at_correct_timestamp failed" - ); - - return TEST_SUCCESS; -} diff --git a/modules/balancer/tests/unit/meson.build b/modules/balancer/tests/unit/meson.build deleted file mode 100644 index 4cdfc4823..000000000 --- a/modules/balancer/tests/unit/meson.build +++ /dev/null @@ -1,33 +0,0 @@ -dependencies = [lib_balancer_state_dep] -includes = include_directories('../../..', '../../../', '../../../../') - -interval_counter_test = executable( - 'interval_counter', - [ - 'interval_counter.c', - '../../controlplane/state/interval_counter.c', - ], - c_args: yanet_test_c_args, - link_args: yanet_link_args, - include_directories: includes, - dependencies: dependencies, -) -test( - 'interval_counter', - interval_counter_test, - suite: ['balancer', 'unit'] -) - -active_sessions_tracker_test = executable( - 'active_sessions_tracker', - ['active_sessions_tracker.c',], - c_args: yanet_test_c_args, - link_args: yanet_link_args, - include_directories: includes, - dependencies: dependencies, -) -test( - 'active_sessions_tracker', - active_sessions_tracker_test, - suite: ['balancer', 'unit'] -) \ No newline at end of file diff --git a/modules/route/cli/route/src/main.rs b/modules/route/cli/route/src/main.rs index 928bdc53f..993be6761 100644 --- a/modules/route/cli/route/src/main.rs +++ b/modules/route/cli/route/src/main.rs @@ -242,7 +242,7 @@ impl RouteService { let mut entries = response.routes.into_iter().map(RouteEntry::from).collect::>(); - entries.sort_by(|a, b| a.prefix.0.cmp(&b.prefix.0)); + entries.sort_by_key(|a| a.prefix.0); print_table(entries); diff --git a/tests/controlplane/go/update_test.go b/tests/controlplane/go/update_test.go index 3df029ecd..90fe667d6 100644 --- a/tests/controlplane/go/update_test.go +++ b/tests/controlplane/go/update_test.go @@ -2,16 +2,53 @@ package controlplane_test import ( "fmt" - "net/netip" "testing" "github.com/c2h5oh/datasize" "github.com/stretchr/testify/require" "github.com/yanet-platform/yanet2/controlplane/ffi" mock "github.com/yanet-platform/yanet2/mock/go" - balancer "github.com/yanet-platform/yanet2/modules/balancer/agent/go/ffi" + balancer "github.com/yanet-platform/yanet2/modules/balancer/controlplane" + "github.com/yanet-platform/yanet2/modules/balancer/controlplane/balancerpb" + "go.uber.org/zap" + "google.golang.org/protobuf/types/known/durationpb" ) +func minimalBalancerConfig() *balancerpb.BalancerConfig { + var ( + capacity uint64 = 1 + mlf float32 = 0.5 + power uint64 = 0 + maxWeight uint32 = 0 + ) + return &balancerpb.BalancerConfig{ + PacketHandler: &balancerpb.PacketHandlerConfig{ + SourceAddressV4: []byte{1, 1, 1, 1}, + SourceAddressV6: []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + }, + DecapAddresses: [][]byte{}, + SessionsTimeouts: &balancerpb.SessionsTimeouts{ + TcpSynAck: 25, + TcpSyn: 20, + TcpFin: 15, + Tcp: 60, + Udp: 30, + }, + }, + State: &balancerpb.StateConfig{ + SessionTableCapacity: &capacity, + SessionTableMaxLoadFactor: &mlf, + Wlc: &balancerpb.WlcConfig{ + Power: &power, + MaxWeight: &maxWeight, + }, + RefreshPeriod: durationpb.New(0), + }, + } +} + func TestControlplaneUpdates(t *testing.T) { config := mock.YanetMockConfig{ AgentsMemory: 128 * datasize.MB, @@ -115,24 +152,15 @@ func TestControlplaneUpdates(t *testing.T) { require.Equal(t, []string{"function0", "function1"}, pipelines[0].Functions) }) - balancerAgent, err := balancer.NewBalancerAgent(yanet.SharedMemory(), uint(64*datasize.MB)) + balancerAgent, err := balancer.AttachNewAgent(yanet.SharedMemory(), 0, 64*datasize.MB) require.NoError(t, err, "failed to create balancer agent") // Register only balancer0 first, leaving balancer1 unresolved in function0. t.Run("RegisterBalancer0", func(t *testing.T) { - config := balancer.BalancerManagerConfig{ - Balancer: balancer.BalancerConfig{ - Handler: balancer.PacketHandlerConfig{ - SourceV4: netip.MustParseAddr("1.1.1.1"), - SourceV6: netip.MustParseAddr("::1"), - }, - State: balancer.StateConfig{ - TableCapacity: 1, - }, - }, - } - _, err := balancerAgent.NewManager("balancer0", &config) - require.NoError(t, err, "failed to create balancer manager") + log := zap.NewNop().Sugar() + b, err := balancer.NewBalancer(balancerAgent, "balancer0", minimalBalancerConfig(), log) + require.NoError(t, err, "failed to create balancer") + balancerAgent.PutBalancer("balancer0", b) }) // Add function1 with a fully defined module config so pipeline0 can reference it. @@ -211,19 +239,10 @@ func TestControlplaneUpdates(t *testing.T) { // Register balancer1 so all module references used by pipeline0 become valid. t.Run("RegisterBalancer1", func(t *testing.T) { - config := balancer.BalancerManagerConfig{ - Balancer: balancer.BalancerConfig{ - Handler: balancer.PacketHandlerConfig{ - SourceV4: netip.MustParseAddr("1.1.1.1"), - SourceV6: netip.MustParseAddr("::1"), - }, - State: balancer.StateConfig{ - TableCapacity: 1, - }, - }, - } - _, err := balancerAgent.NewManager("balancer1", &config) - require.NoError(t, err, "failed to create balancer manager") + log := zap.NewNop().Sugar() + b, err := balancer.NewBalancer(balancerAgent, "balancer1", minimalBalancerConfig(), log) + require.NoError(t, err, "failed to create balancer") + balancerAgent.PutBalancer("balancer1", b) }) // Device linking should also fail when the referenced input pipeline does not exist. diff --git a/tests/functional/balancer_test.sh b/tests/functional/balancer_test.sh index 5ab2e6778..ef7abf616 100755 --- a/tests/functional/balancer_test.sh +++ b/tests/functional/balancer_test.sh @@ -32,5 +32,3 @@ sleep 3 /mnt/target/release/yanet-cli-pipeline update --name=dummy /mnt/target/release/yanet-cli-device-plain update --name=01:00.0 --input test:1 --output dummy:1 - -/mnt/target/release/yanet-cli-balancer stats --name=balancer0 --device=01:00.0 --pipeline=test --function=test --chain=ch0 \ No newline at end of file diff --git a/tests/functional/main/balancer_test.go b/tests/functional/main/balancer_test.go index e4c72a2a9..58464804c 100644 --- a/tests/functional/main/balancer_test.go +++ b/tests/functional/main/balancer_test.go @@ -11,7 +11,7 @@ import ( "github.com/yanet-platform/yanet2/tests/functional/framework" ) -func createTcpPacket(srcIP, dstIP net.IP, srcPort, dstPort int, payload []byte, SYN bool) []byte { +func createTCPPacket(srcIP, dstIP net.IP, srcPort, dstPort int, payload []byte, SYN bool) []byte { eth := layers.Ethernet{ SrcMAC: framework.MustParseMAC(framework.SrcMAC), DstMAC: framework.MustParseMAC(framework.DstMAC), @@ -72,9 +72,6 @@ func TestBalancer(t *testing.T) { // Configure devices "/mnt/target/release/yanet-cli-device-plain update --name=01:00.0 --input test:1 --output dummy:1", - // Show config stats - "/mnt/target/release/yanet-cli-balancer stats --name=balancer0 --device=01:00.0 --pipeline=test --function=test --chain=ch0", - // Enable single real "/mnt/target/release/yanet-cli-balancer reals enable --name=balancer0 --vs 192.0.2.1:80/tcp --reals 10.1.1.1", "/mnt/target/release/yanet-cli-balancer reals flush --name=balancer0", @@ -85,7 +82,7 @@ func TestBalancer(t *testing.T) { }) fw.Run("Test_IPv4_Packet", func(fw *framework.F, t *testing.T) { - packet := createTcpPacket( + packet := createTCPPacket( net.ParseIP("192.168.2.2"), net.ParseIP("192.0.2.1"), 12345,