diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index ef4f4ec3..596b30b4 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "lists.go", "math.go", "native.go", + "network.go", "protos.go", "regex.go", "sets.go", @@ -60,6 +61,7 @@ go_test( "lists_test.go", "math_test.go", "native_test.go", + "network_test.go", "protos_test.go", "regex_test.go", "sets_test.go", diff --git a/ext/network.go b/ext/network.go new file mode 100644 index 00000000..41030ea0 --- /dev/null +++ b/ext/network.go @@ -0,0 +1,635 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +const ( + // Version1 is the initial version of the Network library, providing + // parity with Kubernetes v1.30+ CEL network functions. + Version1 uint32 = 1 +) + +// Network returns a cel.EnvOption to configure extended functions for network +// address parsing, inspection, and CIDR range manipulation. +// +// Note: This library defines global functions `ip`, `cidr`, `isIP`, `isCIDR` +// and `ip.isCanonical`. If you are currently using variables named `ip` or +// `cidr`, these functions will likely work as intended, however there is a +// chance for collision. +// +// The library closely mirrors the behavior of the Kubernetes CEL network +// libraries, treating IP addresses and CIDR ranges as opaque types. It parses +// IPs strictly: IPv4-mapped IPv6 addresses and IP zones are not allowed. +// +// This library includes a TypeAdapter that allows `netip.Addr` and +// `netip.Prefix` Go types to be passed directly into the CEL environment. +// +// # IP Addresses +// +// The `ip` function converts a string to an IP address (IPv4 or IPv6). If the +// string is not a valid IP, an error is returned. The `isIP` function checks +// if a string is a valid IP address without throwing an error. +// +// ip(string) -> ip +// isIP(string) -> bool +// +// Examples: +// +// ip('127.0.0.1') +// ip('::1') +// isIP('1.2.3.4') // true +// isIP('invalid') // false +// +// # CIDR Ranges +// +// The `cidr` function converts a string to a Classless Inter-Domain Routing +// (CIDR) range. If the string is not valid, an error is returned. +// +// The `isCIDR` function checks if a string is a valid CIDR notation. Note that +// `isCIDR` is "loose" and allows CIDRs with non-zero host bits (e.g., +// '10.0.0.1/8'). For strict validation of subnets, use `isStrictCIDR`. +// +// The `isStrictCIDR` function checks if a string is a valid canonical CIDR +// (no host bits). +// +// The `isInterfaceAddress` function is an alias for `isCIDR` that explicitly +// signifies the intent to allow host bits. +// +// cidr(string) -> cidr +// isCIDR(string) -> bool +// isStrictCIDR(string) -> bool +// isInterfaceAddress(string) -> bool +// +// Examples: +// +// cidr('192.168.0.0/24') +// cidr('::1/128') +// isCIDR('10.0.0.0/8') // true +// isStrictCIDR('10.0.0.1/8') // false +// isInterfaceAddress('10.0.0.1/8') // true +// +// # IP Inspection and Canonicalization +// +// IP objects support various inspection methods. +// +// .family() -> int +// .isLoopback() -> bool +// .isGlobalUnicast() -> bool +// .isLinkLocalMulticast() -> bool +// .isLinkLocalUnicast() -> bool +// .isUnspecified() -> bool +// +// The `ip.isCanonical` function takes a string and returns true if it matches +// the RFC 5952 canonical string representation of that address. +// +// ip.isCanonical(string) -> bool +// +// Examples: +// +// ip('127.0.0.1').family() == 4 +// ip('::1').family() == 6 +// ip('127.0.0.1').isLoopback() == true +// ip.isCanonical('2001:db8::1') == true // RFC 5952 format +// ip.isCanonical('2001:DB8::1') == false // Uppercase is not canonical +// ip.isCanonical('2001:db8:0:0:0:0:0:1') == false // Expanded is not canonical +// +// # CIDR Member Functions +// +// CIDR objects support containment checks and property extraction. +// +// .containsIP(ip|string) -> bool +// .containsCIDR(cidr|string) -> bool +// .ip() -> ip +// .masked() -> cidr +// .prefixLength() -> int +// +// Examples: +// +// cidr('10.0.0.0/8').containsIP(ip('10.0.0.1')) == true +// cidr('10.0.0.0/8').containsIP('10.0.0.1') == true +// cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16') == true +// cidr('192.168.1.5/24').ip() == ip('192.168.1.5') +// cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24') +// cidr('192.168.1.0/24').prefixLength() == 24 +func Network(opts ...NetworkOption) cel.EnvOption { + lib := &networkLib{version: Version1} + for _, o := range opts { + lib = o(lib) + } + return func(e *cel.Env) (*cel.Env, error) { + // Install the library (Types and Functions) + e, err := cel.Lib(lib)(e) + if err != nil { + return nil, err + } + + // Install the Adapter (Wrapping the existing one) + adapter := &networkAdapter{Adapter: e.CELTypeAdapter()} + return cel.CustomTypeAdapter(adapter)(e) + } +} + +// NetworkOption declares a functional operator for configuring the Network library behavior. +type NetworkOption func(*networkLib) *networkLib + +// NetworkVersion sets the version of the network library to an explicit version. +func NetworkVersion(version uint32) NetworkOption { + return func(lib *networkLib) *networkLib { + lib.version = version + return lib + } +} + +const ( + // Function names matching the original Kubernetes implementation of this networking library. + // isStrictCIDR and isInterfaceAddress are added to enable strict isCIDR parsing without breaking + // functionality for existing users. Ctx: https://github.com/kubernetes/kubernetes/issues/134224 + cidrFunc = "cidr" + cidrToString = "string" + containsCIDRFunc = "containsCIDR" + containsIPFunc = "containsIP" + familyFunc = "family" + ipFunc = "ip" + ipToString = "string" + isCanonicalFunc = "ip.isCanonical" + isCIDRFunc = "isCIDR" + isGlobalUnicastFunc = "isGlobalUnicast" + isInterfaceAddrFunc = "isInterfaceAddress" + isIPFunc = "isIP" + isLinkLocalMcastFunc = "isLinkLocalMulticast" + isLinkLocalUcastFunc = "isLinkLocalUnicast" + isLoopbackFunc = "isLoopback" + isStrictCIDRFunc = "isStrictCIDR" + isUnspecifiedFunc = "isUnspecified" + maskedFunc = "masked" + prefixLengthFunc = "prefixLength" +) + +var ( + // Definitions for the Opaque Types + IPType = types.NewOpaqueType("net.IP") + CIDRType = types.NewOpaqueType("net.CIDR") +) + +type networkLib struct { + version uint32 +} + +func (*networkLib) LibraryName() string { + return "cel.lib.ext.network" +} + +func (*networkLib) CompileOptions() []cel.EnvOption { + return []cel.EnvOption{ + // 1. Register Types + cel.Types( + IPType, + CIDRType, + ), + + // 2. Register Functions + cel.Function(cidrFunc, + // K8s Parity: Following the pattern, this is "string_to_cidr" + cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, CIDRType, + cel.UnaryBinding(netCIDRString)), + ), + cel.Function(cidrToString, + cel.Overload("cidr_to_string", []*cel.Type{CIDRType}, cel.StringType, + cel.UnaryBinding(netCIDRToString)), + ), + cel.Function(containsCIDRFunc, + cel.MemberOverload("cidr_contains_cidr", []*cel.Type{CIDRType, CIDRType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDR)), + cel.MemberOverload("cidr_contains_cidr_string", []*cel.Type{CIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsCIDRString)), + ), + cel.Function(containsIPFunc, + cel.MemberOverload("cidr_contains_ip_ip", []*cel.Type{CIDRType, IPType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIP)), + cel.MemberOverload("cidr_contains_ip_string", []*cel.Type{CIDRType, cel.StringType}, cel.BoolType, + cel.BinaryBinding(netCIDRContainsIPString)), + ), + cel.Function(familyFunc, + cel.MemberOverload("ip_family", []*cel.Type{IPType}, cel.IntType, + cel.UnaryBinding(netIPFamily)), + ), + cel.Function(ipFunc, + // K8s Parity: The global overload is named "string_to_ip" + cel.Overload("string_to_ip", []*cel.Type{cel.StringType}, IPType, + cel.UnaryBinding(netIPString)), + // K8s Parity: The member overload is named "cidr_ip" + cel.MemberOverload("cidr_ip", []*cel.Type{CIDRType}, IPType, + cel.UnaryBinding(netCIDRIP)), + ), + cel.Function(ipToString, + cel.Overload("ip_to_string", []*cel.Type{IPType}, cel.StringType, + cel.UnaryBinding(netIPToString)), + ), + cel.Function(isCanonicalFunc, + cel.Overload("ip_is_canonical", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIPIsCanonical)), + ), + cel.Function(isCIDRFunc, + cel.Overload("is_cidr", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsCIDR)), + ), + cel.Function(isGlobalUnicastFunc, + cel.MemberOverload("ip_is_global_unicast", []*cel.Type{IPType}, cel.BoolType, + cel.UnaryBinding(netIPIsGlobalUnicast)), + ), + cel.Function(isInterfaceAddrFunc, + cel.Overload("is_interface_address", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsCIDR)), + ), + cel.Function(isIPFunc, + cel.Overload("is_ip", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsIP)), + ), + cel.Function(isLinkLocalMcastFunc, + cel.MemberOverload("ip_is_link_local_multicast", []*cel.Type{IPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalMulticast)), + ), + cel.Function(isLinkLocalUcastFunc, + cel.MemberOverload("ip_is_link_local_unicast", []*cel.Type{IPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLinkLocalUnicast)), + ), + cel.Function(isLoopbackFunc, + cel.MemberOverload("ip_is_loopback", []*cel.Type{IPType}, cel.BoolType, + cel.UnaryBinding(netIPIsLoopback)), + ), + cel.Function(isStrictCIDRFunc, + cel.Overload("is_strict_cidr", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(netIsStrictCIDR)), + ), + cel.Function(isUnspecifiedFunc, + cel.MemberOverload("ip_is_unspecified", []*cel.Type{IPType}, cel.BoolType, + cel.UnaryBinding(netIPIsUnspecified)), + ), + cel.Function(maskedFunc, + cel.MemberOverload("cidr_masked", []*cel.Type{CIDRType}, CIDRType, + cel.UnaryBinding(netCIDRMasked)), + ), + cel.Function(prefixLengthFunc, + cel.MemberOverload("cidr_prefix_length", []*cel.Type{CIDRType}, cel.IntType, + cel.UnaryBinding(netCIDRPrefixLength)), + ), + cel.ASTValidators( + networkFormatValidator{funcName: ipFunc, argNum: 0, check: checkIP}, + networkFormatValidator{funcName: cidrFunc, argNum: 0, check: checkCIDR}, + ), + } +} + +func (*networkLib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +// networkAdapter adapts netip types while preserving existing adapters. +type networkAdapter struct { + types.Adapter +} + +func (a *networkAdapter) NativeToValue(value any) ref.Val { + switch v := value.(type) { + case netip.Addr: + return IP{Addr: v} + case netip.Prefix: + return CIDR{Prefix: v} + } + // Delegate to the wrapped adapter (e.g., Protobuf adapter) + return a.Adapter.NativeToValue(value) +} + +// --- Implementation Logic --- + +func netCIDRContainsCIDR(lhs, rhs ref.Val) ref.Val { + parent := lhs.(CIDR) + child := rhs.(CIDR) + return types.Bool(parent.Prefix.Overlaps(child.Prefix) && parent.Prefix.Bits() <= child.Prefix.Bits()) +} + +func netCIDRContainsCIDRString(lhs, rhs ref.Val) ref.Val { + parent := lhs.(CIDR) + s := rhs.(types.String) + childPrefix, err := parseCIDR(string(s)) + if err != nil { + return types.WrapErr(err) + } + return types.Bool(parent.Prefix.Overlaps(childPrefix) && parent.Prefix.Bits() <= childPrefix.Bits()) +} + +func netCIDRContainsIP(lhs, rhs ref.Val) ref.Val { + cidr := lhs.(CIDR) + ip := rhs.(IP) + return types.Bool(cidr.Prefix.Contains(ip.Addr)) +} + +func netCIDRContainsIPString(lhs, rhs ref.Val) ref.Val { + cidr := lhs.(CIDR) + s := rhs.(types.String) + addr, err := parseIPAddr(string(s)) + if err != nil { + return types.WrapErr(err) + } + return types.Bool(cidr.Prefix.Contains(addr)) +} + +func netCIDRIP(val ref.Val) ref.Val { + cidr := val.(CIDR) + return IP{Addr: cidr.Prefix.Addr()} +} + +func netCIDRMasked(val ref.Val) ref.Val { + cidr := val.(CIDR) + return CIDR{Prefix: cidr.Prefix.Masked()} +} + +func netCIDRPrefixLength(val ref.Val) ref.Val { + cidr := val.(CIDR) + return types.Int(cidr.Prefix.Bits()) +} + +func netCIDRString(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + prefix, err := parseCIDR(str) + if err != nil { + return types.WrapErr(err) + } + return CIDR{Prefix: prefix} +} + +func netCIDRToString(val ref.Val) ref.Val { + cidr := val.(CIDR) + return types.String(cidr.Prefix.String()) +} + +func netIPFamily(val ref.Val) ref.Val { + ip := val.(IP) + if ip.Addr.Is4() { + return types.Int(4) + } + return types.Int(6) +} + +func netIPIsCanonical(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + addr, err := parseIPAddr(str) + if err != nil { + return types.WrapErr(err) + } + return types.Bool(addr.String() == str) +} + +func netIPIsGlobalUnicast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsGlobalUnicast()) +} + +func netIPIsLinkLocalMulticast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalMulticast()) +} + +func netIPIsLinkLocalUnicast(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLinkLocalUnicast()) +} + +func netIPIsLoopback(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsLoopback()) +} + +func netIPIsUnspecified(val ref.Val) ref.Val { + ip := val.(IP) + return types.Bool(ip.Addr.IsUnspecified()) +} + +func netIPString(val ref.Val) ref.Val { + s := val.(types.String) + str := string(s) + addr, err := parseIPAddr(str) + if err != nil { + return types.WrapErr(err) + } + return IP{Addr: addr} +} + +func netIPToString(val ref.Val) ref.Val { + ip := val.(IP) + return types.String(ip.Addr.String()) +} + +func netIsCIDR(val ref.Val) ref.Val { + s := val.(types.String) + _, err := parseCIDR(string(s)) + return types.Bool(err == nil) +} + +func netIsIP(val ref.Val) ref.Val { + s := val.(types.String) + _, err := parseIPAddr(string(s)) + return types.Bool(err == nil) +} + +func netIsStrictCIDR(val ref.Val) ref.Val { + s := val.(types.String) + prefix, err := parseCIDR(string(s)) + if err != nil { + return types.False + } + // Strict check: address must match its masked version (no host bits) + return types.Bool(prefix.Addr() == prefix.Masked().Addr()) +} + +func parseCIDR(raw string) (netip.Prefix, error) { + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return netip.Prefix{}, fmt.Errorf("CIDR %q parse error during conversion from string: %v", raw, err) + } + if prefix.Addr().Zone() != "" { + return netip.Prefix{}, fmt.Errorf("CIDR %q with zone value is not allowed", raw) + } + if prefix.Addr().Is4In6() { + return netip.Prefix{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw) + } + return prefix, nil +} + +func parseIPAddr(raw string) (netip.Addr, error) { + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("IP Address %q parse error during conversion from string: %v", raw, err) + } + if addr.Zone() != "" { + return netip.Addr{}, fmt.Errorf("IP address %q with zone value is not allowed", raw) + } + if addr.Is4In6() { + return netip.Addr{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw) + } + return addr, nil +} + +// --- Opaque Type Wrappers --- + +type IP struct { + netip.Addr +} + +// ConvertToNative converts the IP value to a native Go type. +func (i IP) ConvertToNative(typeDesc reflect.Type) (any, error) { + if typeDesc == reflect.TypeFor[netip.Addr]() { + return i.Addr, nil + } + if typeDesc.Kind() == reflect.String { + return i.Addr.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +// ConvertToType converts the IP value to a CEL type. +func (i IP) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(i.Addr.String()) + case IPType: + return i + case types.TypeType: + return IPType + } + return types.NewErr("type conversion error from '%s' to '%s'", IPType, typeValue) +} + +// Equal returns true if this IP is equal to the other ref.Val. +func (i IP) Equal(other ref.Val) ref.Val { + o, ok := other.(IP) + if !ok { + return types.False + } + return types.Bool(i.Addr == o.Addr) +} + +// Type returns the CEL type of the IP. +func (i IP) Type() ref.Type { + return IPType +} + +// Value returns the raw Go value (netip.Addr) of the IP. +func (i IP) Value() any { + return i.Addr +} + +type CIDR struct { + netip.Prefix +} + +// ConvertToNative converts the CIDR value to a native Go type. +func (c CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) { + if typeDesc == reflect.TypeFor[netip.Prefix]() { + return c.Prefix, nil + } + if typeDesc.Kind() == reflect.String { + return c.Prefix.String(), nil + } + return nil, fmt.Errorf("unsupported type conversion to '%v'", typeDesc) +} + +// ConvertToType converts the CIDR value to a CEL type. +func (c CIDR) ConvertToType(typeValue ref.Type) ref.Val { + switch typeValue { + case types.StringType: + return types.String(c.Prefix.String()) + case CIDRType: + return c + case types.TypeType: + return CIDRType + } + return types.NewErr("type conversion error from '%s' to '%s'", CIDRType, typeValue) +} + +// Equal returns true if this CIDR is equal to the other ref.Val. +func (c CIDR) Equal(other ref.Val) ref.Val { + o, ok := other.(CIDR) + if !ok { + return types.False + } + return types.Bool(c.Prefix == o.Prefix) +} + +// Type returns the CEL type of the CIDR. +func (c CIDR) Type() ref.Type { + return CIDRType +} + +// Value returns the raw Go value (netip.Prefix) of the CIDR. +func (c CIDR) Value() any { + return c.Prefix +} + +// --- Static Validators --- + +type argChecker func(e *cel.Env, call, arg ast.Expr) error + +type networkFormatValidator struct { + funcName string + argNum int + check argChecker +} + +func (v networkFormatValidator) Name() string { + return fmt.Sprintf("cel.validator.network.%s", v.funcName) +} + +func (v networkFormatValidator) Validate(e *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { + root := ast.NavigateAST(a) + funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName)) + for _, call := range funcCalls { + callArgs := call.AsCall().Args() + if len(callArgs) <= v.argNum { + continue + } + litArg := callArgs[v.argNum] + if litArg.Kind() != ast.LiteralKind { + continue + } + if err := v.check(e, call, litArg); err != nil { + iss.ReportErrorAtID(litArg.ID(), "invalid %s argument: %v", v.funcName, err) + } + } +} + +func checkIP(e *cel.Env, call, arg ast.Expr) error { + pattern := arg.AsLiteral().Value().(string) + _, err := parseIPAddr(pattern) + return err +} + +func checkCIDR(e *cel.Env, call, arg ast.Expr) error { + pattern := arg.AsLiteral().Value().(string) + _, err := parseCIDR(pattern) + return err +} diff --git a/ext/network_test.go b/ext/network_test.go new file mode 100644 index 00000000..b84ca0ac --- /dev/null +++ b/ext/network_test.go @@ -0,0 +1,562 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "net/netip" + "reflect" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" +) + +func TestNetwork_Success(t *testing.T) { + // These test cases are ported from kubernetes/staging/src/k8s.io/apiserver/pkg/cel/library + // to ensure 1-to-1 parity with the Kubernetes implementation. + tests := []struct { + name string + expr string + out any + }{ + // CIDR Accessors + { + name: "cidr ip extraction", + expr: "cidr('192.168.0.0/24').ip() == ip('192.168.0.0')", + out: true, + }, + { + name: "cidr ip extraction (host bits set)", + // K8s behavior: cidr('1.2.3.4/24').ip() returns 1.2.3.4, not 1.2.3.0 + expr: "cidr('192.168.1.5/24').ip() == ip('192.168.1.5')", + out: true, + }, + { + name: "cidr masked", + // masked() zeroes out the host bits + expr: "cidr('192.168.1.5/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + { + name: "cidr masked identity", + expr: "cidr('192.168.1.0/24').masked() == cidr('192.168.1.0/24')", + out: true, + }, + { + name: "cidr prefixLength", + expr: "cidr('192.168.0.0/24').prefixLength()", + out: int64(24), + }, + { + name: "cidr to string IPv4", + expr: "string(cidr('10.0.0.0/8'))", + out: "10.0.0.0/8", + }, + { + name: "cidr to string IPv6", + expr: "string(cidr('::1/128'))", + out: "::1/128", + }, + + // Containment (CIDR in CIDR) + { + name: "containsCIDR different family", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('::1/128'))", + out: false, + }, + { + name: "containsCIDR disjoint", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('11.0.0.0/8'))", + out: false, + }, + { + name: "containsCIDR exact match", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.0.0.0/8'))", + out: true, + }, + { + name: "containsCIDR larger prefix (false)", + // /8 does not contain /4 + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('0.0.0.0/4'))", + out: false, + }, + { + name: "containsCIDR string overload", + expr: "cidr('10.0.0.0/8').containsCIDR('10.1.0.0/16')", + out: true, + }, + { + name: "containsCIDR subnet", + expr: "cidr('10.0.0.0/8').containsCIDR(cidr('10.1.0.0/16'))", + out: true, + }, + + // Containment (IP in CIDR) + { + name: "containsIP edge case (broadcast)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.255.255.255'))", + out: true, + }, + { + name: "containsIP edge case (network address)", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.0.0.0'))", + out: true, + }, + { + name: "containsIP false", + expr: "cidr('10.0.0.0/8').containsIP(ip('11.0.0.0'))", + out: false, + }, + { + name: "containsIP simple", + expr: "cidr('10.0.0.0/8').containsIP(ip('10.1.2.3'))", + out: true, + }, + { + name: "containsIP string overload", + expr: "cidr('10.0.0.0/8').containsIP('10.1.2.3')", + out: true, + }, + + // IP Constructors & Properties + { + name: "family IPv4", + expr: "ip('127.0.0.1').family()", + out: int64(4), + }, + { + name: "family IPv6", + expr: "ip('::1').family()", + out: int64(6), + }, + { + name: "ip equality IPv4", + expr: "ip('127.0.0.1') == ip('127.0.0.1')", + out: true, + }, + { + name: "ip equality IPv6 mixed case inputs", + // Logic check: The value is equal even if string rep was different + expr: "ip('2001:db8::1') == ip('2001:DB8::1')", + out: true, + }, + { + name: "ip inequality", + expr: "ip('127.0.0.1') == ip('1.2.3.4')", + out: false, + }, + { + name: "ip to string IPv4", + expr: "string(ip('1.2.3.4'))", + out: "1.2.3.4", + }, + { + name: "ip to string IPv6", + expr: "string(ip('2001:db8::1'))", + out: "2001:db8::1", + }, + + // IP Canonicalization + { + name: "isCanonical IPv4 simple", + expr: "ip.isCanonical('127.0.0.1')", + out: true, + }, + { + name: "isCanonical IPv6 expanded (invalid)", + expr: "ip.isCanonical('2001:db8:0:0:0:0:0:1')", + out: false, + }, + { + name: "isCanonical IPv6 standard", + expr: "ip.isCanonical('2001:db8::1')", + out: true, + }, + { + name: "isCanonical IPv6 uppercase (invalid)", + expr: "ip.isCanonical('2001:DB8::1')", + out: false, + }, + + // IP Types & Predicates + { + name: "isGlobalUnicast 8.8.8.8", + expr: "ip('8.8.8.8').isGlobalUnicast()", + out: true, + }, + { + name: "isLinkLocalMulticast", + expr: "ip('ff02::1').isLinkLocalMulticast()", + out: true, + }, + { + name: "isLoopback IPv4", + expr: "ip('127.0.0.1').isLoopback()", + out: true, + }, + { + name: "isLoopback IPv6", + expr: "ip('::1').isLoopback()", + out: true, + }, + { + name: "isUnspecified IPv4", + expr: "ip('0.0.0.0').isUnspecified()", + out: true, + }, + { + name: "isUnspecified IPv6", + expr: "ip('::').isUnspecified()", + out: true, + }, + + // Global Predicates (IP & CIDR) + { + name: "isCIDR invalid mask", + expr: "isCIDR('10.0.0.0/999')", + out: false, + }, + { + name: "isCIDR loose (host bits)", + expr: "isCIDR('10.0.0.1/8')", + out: true, + }, + { + name: "isCIDR valid", + expr: "isCIDR('10.0.0.0/8')", + out: true, + }, + { + name: "isInterfaceAddress valid (host bits)", + expr: "isInterfaceAddress('10.0.0.1/8')", + out: true, + }, + { + name: "isIP invalid", + expr: "isIP('not.an.ip')", + out: false, + }, + { + name: "isIP valid IPv4", + expr: "isIP('1.2.3.4')", + out: true, + }, + { + name: "isIP valid IPv6", + expr: "isIP('2001:db8::1')", + out: true, + }, + { + name: "isIP with port (invalid)", + expr: "isIP('127.0.0.1:80')", + out: false, + }, + { + name: "isStrictCIDR invalid (host bits)", + expr: "isStrictCIDR('10.0.0.1/8')", + out: false, + }, + { + name: "isStrictCIDR valid", + expr: "isStrictCIDR('10.0.0.0/8')", + out: true, + }, + } + + // Initialize the environment with the Network extension + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval(%q) failed: %v", tst.expr, err) + } + + // Convert the CEL result to a native Go value for comparison + got, err := out.ConvertToNative(reflect.TypeOf(tst.out)) + if err != nil { + t.Fatalf("ConvertToNative failed for expr %q: %v", tst.expr, err) + } + + if !reflect.DeepEqual(got, tst.out) { + t.Errorf("Expr %q result got %v, wanted %v", tst.expr, got, tst.out) + } + }) + } +} + +func TestNetwork_RuntimeErrors(t *testing.T) { + tests := []struct { + name string + expr string + errContains string + }{ + { + name: "containsIP string overload invalid", + expr: "cidr('10.0.0.0/8').containsIP('not-an-ip')", + errContains: "parse error", + }, + { + name: "containsCIDR string overload invalid", + expr: "cidr('10.0.0.0/8').containsCIDR('not-a-cidr')", + errContains: "parse error", + }, + } + + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + // Note: We only check runtime errors here. Compile errors are unexpected + // because these functions accept strings, so type-check passes. + t.Fatalf("Compile(%q) failed unexpectedly: %v", tst.expr, iss.Err()) + } + + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program(%q) failed: %v", tst.expr, err) + } + + _, _, err = prg.Eval(cel.NoVars()) + if err == nil { + t.Errorf("Expected runtime error for %q, got nil", tst.expr) + return + } + + // CEL errors are sometimes wrapped, so we check substring + if !types.IsError(types.NewErr("%s", err.Error())) { + // Just a sanity check that it is indeed a CEL-compatible error structure + // Not strictly necessary but good practice + } + + // Standard substring check + gotErr := err.Error() + // We just check if the message contains the specific error text we return in network.go + found := false + // Note: The actual error might be wrapped in "evaluation error: ..." + if len(tst.errContains) > 0 { + // Simple string contains check + for i := 0; i < len(gotErr)-len(tst.errContains)+1; i++ { + if gotErr[i:i+len(tst.errContains)] == tst.errContains { + found = true + break + } + } + } + + if !found { + t.Errorf("Expected error containing %q, got %q", tst.errContains, gotErr) + } + }) + } +} + +func TestNetwork_TypeConversions(t *testing.T) { + addr, _ := netip.ParseAddr("1.2.3.4") + prefix, _ := netip.ParsePrefix("10.0.0.0/8") + + ipVal := IP{Addr: addr} + cidrVal := CIDR{Prefix: prefix} + + // IP Conversions + t.Run("IP ConvertToNative netip.Addr", func(t *testing.T) { + got, err := ipVal.ConvertToNative(reflect.TypeOf(netip.Addr{})) + if err != nil { + t.Fatalf("ConvertToNative failed: %v", err) + } + if got != addr { + t.Errorf("got %v, want %v", got, addr) + } + }) + + t.Run("IP ConvertToNative string", func(t *testing.T) { + got, err := ipVal.ConvertToNative(reflect.TypeOf("")) + if err != nil { + t.Fatalf("ConvertToNative failed: %v", err) + } + if got != "1.2.3.4" { + t.Errorf("got %v, want %v", got, "1.2.3.4") + } + }) + + t.Run("IP ConvertToNative unsupported", func(t *testing.T) { + _, err := ipVal.ConvertToNative(reflect.TypeOf(0)) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("IP ConvertToType StringType", func(t *testing.T) { + got := ipVal.ConvertToType(types.StringType) + if got.Type() != types.StringType { + t.Errorf("got type %v, want %v", got.Type(), types.StringType) + } + if got.Value() != "1.2.3.4" { + t.Errorf("got value %v, want %v", got.Value(), "1.2.3.4") + } + }) + + t.Run("IP ConvertToType IPType", func(t *testing.T) { + got := ipVal.ConvertToType(IPType) + if got != ipVal { + t.Errorf("got %v, want %v", got, ipVal) + } + }) + + t.Run("IP ConvertToType TypeType", func(t *testing.T) { + got := ipVal.ConvertToType(types.TypeType) + if got != IPType { + t.Errorf("got %v, want %v", got, IPType) + } + }) + + // CIDR Conversions + t.Run("CIDR ConvertToNative netip.Prefix", func(t *testing.T) { + got, err := cidrVal.ConvertToNative(reflect.TypeOf(netip.Prefix{})) + if err != nil { + t.Fatalf("ConvertToNative failed: %v", err) + } + if got != prefix { + t.Errorf("got %v, want %v", got, prefix) + } + }) + + t.Run("CIDR ConvertToNative string", func(t *testing.T) { + got, err := cidrVal.ConvertToNative(reflect.TypeOf("")) + if err != nil { + t.Fatalf("ConvertToNative failed: %v", err) + } + if got != "10.0.0.0/8" { + t.Errorf("got %v, want %v", got, "10.0.0.0/8") + } + }) + + t.Run("CIDR ConvertToNative unsupported", func(t *testing.T) { + _, err := cidrVal.ConvertToNative(reflect.TypeOf(0)) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("CIDR ConvertToType StringType", func(t *testing.T) { + got := cidrVal.ConvertToType(types.StringType) + if got.Type() != types.StringType { + t.Errorf("got type %v, want %v", got.Type(), types.StringType) + } + if got.Value() != "10.0.0.0/8" { + t.Errorf("got value %v, want %v", got.Value(), "10.0.0.0/8") + } + }) + + t.Run("CIDR ConvertToType CIDRType", func(t *testing.T) { + got := cidrVal.ConvertToType(CIDRType) + if got != cidrVal { + t.Errorf("got %v, want %v", got, cidrVal) + } + }) + + t.Run("CIDR ConvertToType TypeType", func(t *testing.T) { + got := cidrVal.ConvertToType(types.TypeType) + if got != CIDRType { + t.Errorf("got %v, want %v", got, CIDRType) + } + }) +} + +func TestNetwork_CompileErrors(t *testing.T) { + tests := []struct { + name string + expr string + errContains string + }{ + { + name: "ip constructor invalid literal", + expr: "ip('999.999.999.999')", + errContains: "invalid ip argument", + }, + { + name: "cidr constructor invalid literal", + expr: "cidr('1.2.3.4')", + errContains: "invalid cidr argument", + }, + { + name: "cidr constructor invalid mask literal", + expr: "cidr('10.0.0.0/999')", + errContains: "invalid cidr argument", + }, + { + name: "ip constructor valid literal", + expr: "ip('127.0.0.1')", + errContains: "", + }, + { + name: "cidr constructor valid literal", + expr: "cidr('10.0.0.0/8')", + errContains: "", + }, + } + + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + _, iss := env.Compile(tst.expr) + if tst.errContains != "" { + if iss.Err() == nil { + t.Errorf("Expected compile error for %q, got nil", tst.expr) + return + } + gotErr := iss.Err().Error() + // Simple string contains check + found := false + for i := 0; i < len(gotErr)-len(tst.errContains)+1; i++ { + if gotErr[i:i+len(tst.errContains)] == tst.errContains { + found = true + break + } + } + if !found { + t.Errorf("Expected compile error containing %q, got %q", tst.errContains, gotErr) + } + } else { + if iss.Err() != nil { + t.Errorf("Compile(%q) failed unexpectedly: %v", tst.expr, iss.Err()) + } + } + }) + } +}