From 04c0490992cdb31a46ca585a9e3d64bbfb3a3863 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 3 Mar 2026 20:59:13 +0800 Subject: [PATCH 01/41] Add MAC and hostname rule items --- adapter/inbound.go | 3 + adapter/neighbor.go | 13 + adapter/router.go | 2 + docs/configuration/dns/rule.md | 31 ++ docs/configuration/dns/rule.zh.md | 31 ++ docs/configuration/inbound/tun.md | 30 ++ docs/configuration/inbound/tun.zh.md | 35 ++ docs/configuration/route/index.md | 31 ++ docs/configuration/route/index.zh.md | 31 ++ docs/configuration/route/rule.md | 31 ++ docs/configuration/route/rule.zh.md | 31 ++ go.mod | 6 +- go.sum | 4 +- option/route.go | 2 + option/rule.go | 2 + option/rule_dns.go | 2 + option/tun.go | 2 + protocol/tun/inbound.go | 18 + route/neighbor_resolver_linux.go | 596 +++++++++++++++++++++ route/neighbor_resolver_stub.go | 14 + route/route.go | 17 + route/router.go | 39 ++ route/rule/rule_default.go | 10 + route/rule/rule_dns.go | 10 + route/rule/rule_item_source_hostname.go | 42 ++ route/rule/rule_item_source_mac_address.go | 48 ++ route/rule_conds.go | 8 + 27 files changed, 1084 insertions(+), 5 deletions(-) create mode 100644 adapter/neighbor.go create mode 100644 route/neighbor_resolver_linux.go create mode 100644 route/neighbor_resolver_stub.go create mode 100644 route/rule/rule_item_source_hostname.go create mode 100644 route/rule/rule_item_source_mac_address.go diff --git a/adapter/inbound.go b/adapter/inbound.go index f047199e43..52af336e5b 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,6 +2,7 @@ package adapter import ( "context" + "net" "net/netip" "time" @@ -82,6 +83,8 @@ type InboundContext struct { SourceGeoIPCode string GeoIPCode string ProcessInfo *ConnectionOwner + SourceMACAddress net.HardwareAddr + SourceHostname string QueryType uint16 FakeIP bool diff --git a/adapter/neighbor.go b/adapter/neighbor.go new file mode 100644 index 0000000000..920398f674 --- /dev/null +++ b/adapter/neighbor.go @@ -0,0 +1,13 @@ +package adapter + +import ( + "net" + "net/netip" +) + +type NeighborResolver interface { + LookupMAC(address netip.Addr) (net.HardwareAddr, bool) + LookupHostname(address netip.Addr) (string, bool) + Start() error + Close() error +} diff --git a/adapter/router.go b/adapter/router.go index 3d5310c4ee..82e6881a60 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -26,6 +26,8 @@ type Router interface { RuleSet(tag string) (RuleSet, bool) Rules() []Rule NeedFindProcess() bool + NeedFindNeighbor() bool + NeighborResolver() NeighborResolver AppendTracker(tracker ConnectionTracker) ResetNetwork() } diff --git a/docs/configuration/dns/rule.md b/docs/configuration/dns/rule.md index 4348674847..f8a7ac4c37 100644 --- a/docs/configuration/dns/rule.md +++ b/docs/configuration/dns/rule.md @@ -2,6 +2,11 @@ icon: material/alert-decagram --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [interface_address](#interface_address) @@ -149,6 +154,12 @@ icon: material/alert-decagram "default_interface_address": [ "2000::/3" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "wifi_ssid": [ "My WIFI" ], @@ -408,6 +419,26 @@ Matches network interface (same values as `network_type`) address. Match default interface address. +#### source_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device MAC address. + +#### source_hostname + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device hostname from DHCP leases. + #### wifi_ssid !!! quote "" diff --git a/docs/configuration/dns/rule.zh.md b/docs/configuration/dns/rule.zh.md index f35cfc7e3e..421fdfb5c1 100644 --- a/docs/configuration/dns/rule.zh.md +++ b/docs/configuration/dns/rule.zh.md @@ -2,6 +2,11 @@ icon: material/alert-decagram --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [interface_address](#interface_address) @@ -149,6 +154,12 @@ icon: material/alert-decagram "default_interface_address": [ "2000::/3" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "wifi_ssid": [ "My WIFI" ], @@ -407,6 +418,26 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. 匹配默认接口地址。 +#### source_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备 MAC 地址。 + +#### source_hostname + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备从 DHCP 租约获取的主机名。 + #### wifi_ssid !!! quote "" diff --git a/docs/configuration/inbound/tun.md b/docs/configuration/inbound/tun.md index 74d02dc933..5a2f58d3db 100644 --- a/docs/configuration/inbound/tun.md +++ b/docs/configuration/inbound/tun.md @@ -134,6 +134,12 @@ icon: material/new-box "exclude_package": [ "com.android.captiveportallogin" ], + "include_mac_address": [ + "00:11:22:33:44:55" + ], + "exclude_mac_address": [ + "66:77:88:99:aa:bb" + ], "platform": { "http_proxy": { "enabled": false, @@ -560,6 +566,30 @@ Limit android packages in route. Exclude android packages in route. +#### include_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +Limit MAC addresses in route. Not limited by default. + +Conflict with `exclude_mac_address`. + +#### exclude_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +Exclude MAC addresses in route. + +Conflict with `include_mac_address`. + #### platform Platform-specific settings, provided by client applications. diff --git a/docs/configuration/inbound/tun.zh.md b/docs/configuration/inbound/tun.zh.md index eaf5ff49c3..a41e5ae9ff 100644 --- a/docs/configuration/inbound/tun.zh.md +++ b/docs/configuration/inbound/tun.zh.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [include_mac_address](#include_mac_address) + :material-plus: [exclude_mac_address](#exclude_mac_address) + !!! quote "sing-box 1.13.3 中的更改" :material-alert: [strict_route](#strict_route) @@ -130,6 +135,12 @@ icon: material/new-box "exclude_package": [ "com.android.captiveportallogin" ], + "include_mac_address": [ + "00:11:22:33:44:55" + ], + "exclude_mac_address": [ + "66:77:88:99:aa:bb" + ], "platform": { "http_proxy": { "enabled": false, @@ -543,6 +554,30 @@ TCP/IP 栈。 排除路由的 Android 应用包名。 +#### include_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。 + +限制被路由的 MAC 地址。默认不限制。 + +与 `exclude_mac_address` 冲突。 + +#### exclude_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `auto_route` 和 `auto_redirect` 已启用。 + +排除路由的 MAC 地址。 + +与 `include_mac_address` 冲突。 + #### platform 平台特定的设置,由客户端应用提供。 diff --git a/docs/configuration/route/index.md b/docs/configuration/route/index.md index 1fc9bfd231..01e405614e 100644 --- a/docs/configuration/route/index.md +++ b/docs/configuration/route/index.md @@ -4,6 +4,11 @@ icon: material/alert-decagram # Route +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [find_neighbor](#find_neighbor) + :material-plus: [dhcp_lease_files](#dhcp_lease_files) + !!! quote "Changes in sing-box 1.12.0" :material-plus: [default_domain_resolver](#default_domain_resolver) @@ -35,6 +40,8 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_neighbor": false, + "dhcp_lease_files": [], "default_domain_resolver": "", // or {} "default_network_strategy": "", "default_network_type": [], @@ -107,6 +114,30 @@ Set routing mark by default. Takes no effect if `outbound.routing_mark` is set. +#### find_neighbor + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux. + +Enable neighbor resolution for source MAC address and hostname lookup. + +Required for `source_mac_address` and `source_hostname` rule items. + +#### dhcp_lease_files + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux. + +Custom DHCP lease file paths for hostname and MAC address resolution. + +Automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea) if empty. + #### default_domain_resolver !!! question "Since sing-box 1.12.0" diff --git a/docs/configuration/route/index.zh.md b/docs/configuration/route/index.zh.md index 1a50d3e3b5..2c12a58eb3 100644 --- a/docs/configuration/route/index.zh.md +++ b/docs/configuration/route/index.zh.md @@ -4,6 +4,11 @@ icon: material/alert-decagram # 路由 +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [find_neighbor](#find_neighbor) + :material-plus: [dhcp_lease_files](#dhcp_lease_files) + !!! quote "sing-box 1.12.0 中的更改" :material-plus: [default_domain_resolver](#default_domain_resolver) @@ -37,6 +42,8 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_neighbor": false, + "dhcp_lease_files": [], "default_network_strategy": "", "default_fallback_delay": "" } @@ -106,6 +113,30 @@ icon: material/alert-decagram 如果设置了 `outbound.routing_mark` 设置,则不生效。 +#### find_neighbor + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux。 + +启用邻居解析以查找源 MAC 地址和主机名。 + +`source_mac_address` 和 `source_hostname` 规则项需要此选项。 + +#### dhcp_lease_files + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux。 + +用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。 + +为空时自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。 + #### default_domain_resolver !!! question "自 sing-box 1.12.0 起" diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index 925187261c..16c100c1c0 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [interface_address](#interface_address) @@ -159,6 +164,12 @@ icon: material/new-box "tailscale", "wireguard" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "rule_set": [ "geoip-cn", "geosite-cn" @@ -449,6 +460,26 @@ Match specified outbounds' preferred routes. | `tailscale` | Match MagicDNS domains and peers' allowed IPs | | `wireguard` | Match peers's allowed IPs | +#### source_mac_address + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device MAC address. + +#### source_hostname + +!!! question "Since sing-box 1.14.0" + +!!! quote "" + + Only supported on Linux with `route.find_neighbor` enabled. + +Match source device hostname from DHCP leases. + #### rule_set !!! question "Since sing-box 1.8.0" diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index 53da4475f1..f21e6677b8 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [source_mac_address](#source_mac_address) + :material-plus: [source_hostname](#source_hostname) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [interface_address](#interface_address) @@ -157,6 +162,12 @@ icon: material/new-box "tailscale", "wireguard" ], + "source_mac_address": [ + "00:11:22:33:44:55" + ], + "source_hostname": [ + "my-device" + ], "rule_set": [ "geoip-cn", "geosite-cn" @@ -447,6 +458,26 @@ icon: material/new-box | `tailscale` | 匹配 MagicDNS 域名和对端的 allowed IPs | | `wireguard` | 匹配对端的 allowed IPs | +#### source_mac_address + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备 MAC 地址。 + +#### source_hostname + +!!! question "自 sing-box 1.14.0 起" + +!!! quote "" + + 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + +匹配源设备从 DHCP 租约获取的主机名。 + #### rule_set !!! question "自 sing-box 1.8.0 起" diff --git a/go.mod b/go.mod index 7af87ee34b..405cc56444 100644 --- a/go.mod +++ b/go.mod @@ -14,11 +14,13 @@ require ( github.com/godbus/dbus/v5 v5.2.2 github.com/gofrs/uuid/v5 v5.4.0 github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91 + github.com/jsimonetti/rtnetlink v1.4.0 github.com/keybase/go-keychain v0.0.1 github.com/libdns/acmedns v0.5.0 github.com/libdns/alidns v1.0.6 github.com/libdns/cloudflare v0.2.2 github.com/logrusorgru/aurora v2.0.3+incompatible + github.com/mdlayher/netlink v1.9.0 github.com/metacubex/utls v1.8.4 github.com/mholt/acmez/v3 v3.1.6 github.com/miekg/dns v1.1.72 @@ -39,7 +41,7 @@ require ( github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 - github.com/sagernet/sing-tun v0.8.6 + github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 github.com/sagernet/smux v1.5.50-sing-box-mod.1 github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.7 @@ -92,11 +94,9 @@ require ( github.com/hashicorp/yamux v0.1.2 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/libdns/libdns v1.1.1 // indirect - github.com/mdlayher/netlink v1.9.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect diff --git a/go.sum b/go.sum index 9fd6ec8654..4d53eb40cd 100644 --- a/go.sum +++ b/go.sum @@ -248,8 +248,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w= github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA= -github.com/sagernet/sing-tun v0.8.6 h1:NydXFikSXhiKqhahHKtuZ90HQPZFzlOFVRONmkr4C7I= -github.com/sagernet/sing-tun v0.8.6/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs= +github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d h1:vi0j6301f6H8t2GYgAC2PA2AdnGdMwkP34B4+N03Qt4= +github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs= github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o= github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY= github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478= diff --git a/option/route.go b/option/route.go index f4b6539156..0c3e576d13 100644 --- a/option/route.go +++ b/option/route.go @@ -9,6 +9,8 @@ type RouteOptions struct { RuleSet []RuleSet `json:"rule_set,omitempty"` Final string `json:"final,omitempty"` FindProcess bool `json:"find_process,omitempty"` + FindNeighbor bool `json:"find_neighbor,omitempty"` + DHCPLeaseFiles badoption.Listable[string] `json:"dhcp_lease_files,omitempty"` AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` DefaultInterface string `json:"default_interface,omitempty"` diff --git a/option/rule.go b/option/rule.go index 3e7fd8771b..b792ccf4b2 100644 --- a/option/rule.go +++ b/option/rule.go @@ -103,6 +103,8 @@ type RawDefaultRule struct { InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"` NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"` DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"` + SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"` + SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"` PreferredBy badoption.Listable[string] `json:"preferred_by,omitempty"` RuleSet badoption.Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` diff --git a/option/rule_dns.go b/option/rule_dns.go index dbc1657898..880b96ac54 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -106,6 +106,8 @@ type RawDefaultDNSRule struct { InterfaceAddress *badjson.TypedMap[string, badoption.Listable[*badoption.Prefixable]] `json:"interface_address,omitempty"` NetworkInterfaceAddress *badjson.TypedMap[InterfaceType, badoption.Listable[*badoption.Prefixable]] `json:"network_interface_address,omitempty"` DefaultInterfaceAddress badoption.Listable[*badoption.Prefixable] `json:"default_interface_address,omitempty"` + SourceMACAddress badoption.Listable[string] `json:"source_mac_address,omitempty"` + SourceHostname badoption.Listable[string] `json:"source_hostname,omitempty"` RuleSet badoption.Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` RuleSetIPCIDRAcceptEmpty bool `json:"rule_set_ip_cidr_accept_empty,omitempty"` diff --git a/option/tun.go b/option/tun.go index 72b6e456ba..fda028b69e 100644 --- a/option/tun.go +++ b/option/tun.go @@ -39,6 +39,8 @@ type TunInboundOptions struct { IncludeAndroidUser badoption.Listable[int] `json:"include_android_user,omitempty"` IncludePackage badoption.Listable[string] `json:"include_package,omitempty"` ExcludePackage badoption.Listable[string] `json:"exclude_package,omitempty"` + IncludeMACAddress badoption.Listable[string] `json:"include_mac_address,omitempty"` + ExcludeMACAddress badoption.Listable[string] `json:"exclude_mac_address,omitempty"` UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"` Stack string `json:"stack,omitempty"` Platform *TunPlatformOptions `json:"platform,omitempty"` diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index 6820831a5c..4b113f4a78 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -160,6 +160,22 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if nfQueue == 0 { nfQueue = tun.DefaultAutoRedirectNFQueue } + var includeMACAddress []net.HardwareAddr + for i, macString := range options.IncludeMACAddress { + mac, macErr := net.ParseMAC(macString) + if macErr != nil { + return nil, E.Cause(macErr, "parse include_mac_address[", i, "]") + } + includeMACAddress = append(includeMACAddress, mac) + } + var excludeMACAddress []net.HardwareAddr + for i, macString := range options.ExcludeMACAddress { + mac, macErr := net.ParseMAC(macString) + if macErr != nil { + return nil, E.Cause(macErr, "parse exclude_mac_address[", i, "]") + } + excludeMACAddress = append(excludeMACAddress, mac) + } networkManager := service.FromContext[adapter.NetworkManager](ctx) multiPendingPackets := C.IsDarwin && ((options.Stack == "gvisor" && tunMTU < 32768) || (options.Stack != "gvisor" && options.MTU <= 9000)) inbound := &Inbound{ @@ -197,6 +213,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo IncludeAndroidUser: options.IncludeAndroidUser, IncludePackage: options.IncludePackage, ExcludePackage: options.ExcludePackage, + IncludeMACAddress: includeMACAddress, + ExcludeMACAddress: excludeMACAddress, InterfaceMonitor: networkManager.InterfaceMonitor(), EXP_MultiPendingPackets: multiPendingPackets, }, diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go new file mode 100644 index 0000000000..40db5766ad --- /dev/null +++ b/route/neighbor_resolver_linux.go @@ -0,0 +1,596 @@ +//go:build linux + +package route + +import ( + "bufio" + "encoding/binary" + "encoding/hex" + "net" + "net/netip" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var defaultLeaseFiles = []string{ + "/tmp/dhcp.leases", + "/var/lib/dhcp/dhcpd.leases", + "/var/lib/dhcpd/dhcpd.leases", + "/var/lib/kea/kea-leases4.csv", + "/var/lib/kea/kea-leases6.csv", +} + +type neighborResolver struct { + logger logger.ContextLogger + leaseFiles []string + access sync.RWMutex + neighborIPToMAC map[netip.Addr]net.HardwareAddr + leaseIPToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string + watcher *fswatch.Watcher + done chan struct{} +} + +func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) { + if len(leaseFiles) == 0 { + for _, path := range defaultLeaseFiles { + info, err := os.Stat(path) + if err == nil && info.Size() > 0 { + leaseFiles = append(leaseFiles, path) + } + } + } + return &neighborResolver{ + logger: resolverLogger, + leaseFiles: leaseFiles, + neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr), + leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + done: make(chan struct{}), + }, nil +} + +func (r *neighborResolver) Start() error { + err := r.loadNeighborTable() + if err != nil { + r.logger.Warn(E.Cause(err, "load neighbor table")) + } + r.reloadLeaseFiles() + go r.subscribeNeighborUpdates() + if len(r.leaseFiles) > 0 { + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: r.leaseFiles, + Logger: r.logger, + Callback: func(_ string) { + r.reloadLeaseFiles() + }, + }) + if err != nil { + r.logger.Warn(E.Cause(err, "create lease file watcher")) + } else { + r.watcher = watcher + err = watcher.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start lease file watcher")) + } + } + } + return nil +} + +func (r *neighborResolver) Close() error { + close(r.done) + if r.watcher != nil { + return r.watcher.Close() + } + return nil +} + +func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.neighborIPToMAC[address] + if found { + return mac, true + } + mac, found = r.leaseIPToMAC[address] + if found { + return mac, true + } + mac, found = extractMACFromEUI64(address) + if found { + return mac, true + } + return nil, false +} + +func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, macFound := r.neighborIPToMAC[address] + if !macFound { + mac, macFound = r.leaseIPToMAC[address] + } + if !macFound { + mac, macFound = extractMACFromEUI64(address) + } + if macFound { + hostname, found = r.macToHostname[mac.String()] + if found { + return hostname, true + } + } + return "", false +} + +func (r *neighborResolver) loadNeighborTable() error { + connection, err := rtnetlink.Dial(nil) + if err != nil { + return E.Cause(err, "dial rtnetlink") + } + defer connection.Close() + neighbors, err := connection.Neigh.List() + if err != nil { + return E.Cause(err, "list neighbors") + } + r.access.Lock() + defer r.access.Unlock() + for _, neigh := range neighbors { + if neigh.Attributes == nil { + continue + } + if neigh.Attributes.LLAddress == nil || len(neigh.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neigh.Attributes.Address) + if !ok { + continue + } + r.neighborIPToMAC[address] = slices.Clone(neigh.Attributes.LLAddress) + } + return nil +} + +func (r *neighborResolver) subscribeNeighborUpdates() { + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + r.logger.Warn(E.Cause(err, "subscribe neighbor updates")) + return + } + defer connection.Close() + for { + select { + case <-r.done: + return + default: + } + err = connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + r.logger.Warn(E.Cause(err, "set netlink read deadline")) + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-r.done: + return + default: + } + r.logger.Warn(E.Cause(err, "receive neighbor update")) + continue + } + for _, message := range messages { + switch message.Header.Type { + case unix.RTM_NEWNEIGH: + var neighMessage rtnetlink.NeighMessage + unmarshalErr := neighMessage.UnmarshalBinary(message.Data) + if unmarshalErr != nil { + continue + } + if neighMessage.Attributes == nil { + continue + } + if neighMessage.Attributes.LLAddress == nil || len(neighMessage.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + continue + } + r.access.Lock() + r.neighborIPToMAC[address] = slices.Clone(neighMessage.Attributes.LLAddress) + r.access.Unlock() + case unix.RTM_DELNEIGH: + var neighMessage rtnetlink.NeighMessage + unmarshalErr := neighMessage.UnmarshalBinary(message.Data) + if unmarshalErr != nil { + continue + } + if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + continue + } + r.access.Lock() + delete(r.neighborIPToMAC, address) + r.access.Unlock() + } + } + } +} + +func (r *neighborResolver) reloadLeaseFiles() { + leaseIPToMAC := make(map[netip.Addr]net.HardwareAddr) + ipToHostname := make(map[netip.Addr]string) + macToHostname := make(map[string]string) + for _, path := range r.leaseFiles { + r.parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) + } + r.access.Lock() + r.leaseIPToMAC = leaseIPToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() +} + +func (r *neighborResolver) parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + file, err := os.Open(path) + if err != nil { + return + } + defer file.Close() + if strings.HasSuffix(path, "kea-leases4.csv") { + r.parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases6.csv") { + r.parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "dhcpd.leases") { + r.parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) + return + } + r.parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) +} + +func (r *neighborResolver) parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "duid ") { + continue + } + if strings.HasPrefix(line, "# ") { + r.parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) + continue + } + fields := strings.Fields(line) + if len(fields) < 4 { + continue + } + expiry, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + continue + } + if expiry != 0 && expiry < now { + continue + } + if strings.Contains(fields[1], ":") { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + ipToMAC[address] = mac + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } else { + var mac net.HardwareAddr + if len(fields) >= 5 { + duid, duidErr := parseDUID(fields[4]) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } + } +} + +func (r *neighborResolver) parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + fields := strings.Fields(line) + if len(fields) < 5 { + return + } + validTime, err := strconv.ParseInt(fields[4], 10, 64) + if err != nil { + return + } + if validTime == 0 { + return + } + if validTime > 0 && validTime < time.Now().Unix() { + return + } + hostname := fields[3] + if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { + hostname = "" + } + if len(fields) >= 8 && fields[2] == "ipv4" { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + return + } + addressField := fields[7] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + return + } + address = address.Unmap() + ipToMAC[address] = mac + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + return + } + var mac net.HardwareAddr + duidHex := fields[1] + duidBytes, hexErr := hex.DecodeString(duidHex) + if hexErr == nil { + mac, _ = extractMACFromDUID(duidBytes) + } + for i := 7; i < len(fields); i++ { + addressField := fields[i] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func (r *neighborResolver) parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentHostname string + var currentActive bool + var inLease bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { + ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) + if addrOK { + currentIP = parsed.Unmap() + inLease = true + currentMAC = nil + currentHostname = "" + currentActive = false + } + continue + } + if line == "}" && inLease { + if currentActive && currentMAC != nil { + ipToMAC[currentIP] = currentMAC + if currentHostname != "" { + ipToHostname[currentIP] = currentHostname + macToHostname[currentMAC.String()] = currentHostname + } + } else { + delete(ipToMAC, currentIP) + delete(ipToHostname, currentIP) + } + inLease = false + continue + } + if !inLease { + continue + } + if strings.HasPrefix(line, "hardware ethernet ") { + macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") + parsed, macErr := net.ParseMAC(macString) + if macErr == nil { + currentMAC = parsed + } + } else if strings.HasPrefix(line, "client-hostname ") { + hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") + hostname = strings.Trim(hostname, "\"") + if hostname != "" { + currentHostname = hostname + } + } else if strings.HasPrefix(line, "binding state ") { + state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") + currentActive = state == "active" + } + } +} + +func (r *neighborResolver) parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 10 { + continue + } + if fields[9] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + ipToMAC[address] = mac + hostname := "" + if len(fields) > 8 { + hostname = fields[8] + } + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } +} + +func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 14 { + continue + } + if fields[13] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + var mac net.HardwareAddr + if fields[12] != "" { + mac, _ = net.ParseMAC(fields[12]) + } + if mac == nil { + duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + hostname := "" + if len(fields) > 11 { + hostname = fields[11] + } + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { + if len(duid) < 4 { + return nil, false + } + duidType := binary.BigEndian.Uint16(duid[0:2]) + hwType := binary.BigEndian.Uint16(duid[2:4]) + if hwType != 1 { + return nil, false + } + switch duidType { + case 1: + if len(duid) < 14 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[8:14])), true + case 3: + if len(duid) < 10 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[4:10])), true + } + return nil, false +} + +func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { + if !address.Is6() { + return nil, false + } + b := address.As16() + if b[11] != 0xff || b[12] != 0xfe { + return nil, false + } + return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true +} + +func parseDUID(s string) ([]byte, error) { + cleaned := strings.ReplaceAll(s, ":", "") + return hex.DecodeString(cleaned) +} diff --git a/route/neighbor_resolver_stub.go b/route/neighbor_resolver_stub.go new file mode 100644 index 0000000000..9288892a8d --- /dev/null +++ b/route/neighbor_resolver_stub.go @@ -0,0 +1,14 @@ +//go:build !linux + +package route + +import ( + "os" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/logger" +) + +func newNeighborResolver(_ logger.ContextLogger, _ []string) (adapter.NeighborResolver, error) { + return nil, os.ErrInvalid +} diff --git a/route/route.go b/route/route.go index 77b66ea409..8e449db463 100644 --- a/route/route.go +++ b/route/route.go @@ -438,6 +438,23 @@ func (r *Router) matchRule( metadata.ProcessInfo = processInfo } } + if r.neighborResolver != nil && metadata.SourceMACAddress == nil && metadata.Source.Addr.IsValid() { + mac, macFound := r.neighborResolver.LookupMAC(metadata.Source.Addr) + if macFound { + metadata.SourceMACAddress = mac + } + hostname, hostnameFound := r.neighborResolver.LookupHostname(metadata.Source.Addr) + if hostnameFound { + metadata.SourceHostname = hostname + if macFound { + r.logger.InfoContext(ctx, "found neighbor: ", mac, ", hostname: ", hostname) + } else { + r.logger.InfoContext(ctx, "found neighbor hostname: ", hostname) + } + } else if macFound { + r.logger.InfoContext(ctx, "found neighbor: ", mac) + } + } if metadata.Destination.Addr.IsValid() && r.dnsTransport.FakeIP() != nil && r.dnsTransport.FakeIP().Store().Contains(metadata.Destination.Addr) { domain, loaded := r.dnsTransport.FakeIP().Store().Lookup(metadata.Destination.Addr) if !loaded { diff --git a/route/router.go b/route/router.go index bc19b5d38f..52eb9e4362 100644 --- a/route/router.go +++ b/route/router.go @@ -35,10 +35,13 @@ type Router struct { network adapter.NetworkManager rules []adapter.Rule needFindProcess bool + needFindNeighbor bool + leaseFiles []string ruleSets []adapter.RuleSet ruleSetMap map[string]adapter.RuleSet processSearcher process.Searcher processCache freelru.Cache[processCacheKey, processCacheEntry] + neighborResolver adapter.NeighborResolver pauseManager pause.Manager trackers []adapter.ConnectionTracker platformInterface adapter.PlatformInterface @@ -58,6 +61,8 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route rules: make([]adapter.Rule, 0, len(options.Rules)), ruleSetMap: make(map[string]adapter.RuleSet), needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, + needFindNeighbor: hasRule(options.Rules, isNeighborRule) || hasDNSRule(dnsOptions.Rules, isNeighborDNSRule) || options.FindNeighbor, + leaseFiles: options.DHCPLeaseFiles, pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[adapter.PlatformInterface](ctx), } @@ -117,6 +122,7 @@ func (r *Router) Start(stage adapter.StartStage) error { } r.network.Initialize(r.ruleSets) needFindProcess := r.needFindProcess + needFindNeighbor := r.needFindNeighbor for _, ruleSet := range r.ruleSets { metadata := ruleSet.Metadata() if metadata.ContainsProcessRule { @@ -151,6 +157,24 @@ func (r *Router) Start(stage adapter.StartStage) error { processCache.SetLifetime(200 * time.Millisecond) r.processCache = processCache } + r.needFindNeighbor = needFindNeighbor + if needFindNeighbor { + monitor.Start("initialize neighbor resolver") + resolver, err := newNeighborResolver(r.logger, r.leaseFiles) + monitor.Finish() + if err != nil { + if err != os.ErrInvalid { + r.logger.Warn(E.Cause(err, "create neighbor resolver")) + } + } else { + err = resolver.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start neighbor resolver")) + } else { + r.neighborResolver = resolver + } + } + } case adapter.StartStatePostStart: for i, rule := range r.rules { monitor.Start("initialize rule[", i, "]") @@ -182,6 +206,13 @@ func (r *Router) Start(stage adapter.StartStage) error { func (r *Router) Close() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) var err error + if r.neighborResolver != nil { + monitor.Start("close neighbor resolver") + err = E.Append(err, r.neighborResolver.Close(), func(closeErr error) error { + return E.Cause(closeErr, "close neighbor resolver") + }) + monitor.Finish() + } for i, rule := range r.rules { monitor.Start("close rule[", i, "]") err = E.Append(err, rule.Close(), func(err error) error { @@ -223,6 +254,14 @@ func (r *Router) NeedFindProcess() bool { return r.needFindProcess } +func (r *Router) NeedFindNeighbor() bool { + return r.needFindNeighbor +} + +func (r *Router) NeighborResolver() adapter.NeighborResolver { + return r.neighborResolver +} + func (r *Router) ResetNetwork() { r.network.ResetNetwork() r.dns.ResetNetwork() diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index b921c8b286..5ce1f87d4a 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -264,6 +264,16 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourceMACAddress) > 0 { + item := NewSourceMACAddressItem(options.SourceMACAddress) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceHostname) > 0 { + item := NewSourceHostnameItem(options.SourceHostname) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.PreferredBy) > 0 { item := NewPreferredByItem(ctx, options.PreferredBy) rule.items = append(rule.items, item) diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 04f0f236b2..f33d6096ae 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -265,6 +265,16 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourceMACAddress) > 0 { + item := NewSourceMACAddressItem(options.SourceMACAddress) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceHostname) > 0 { + item := NewSourceHostnameItem(options.SourceHostname) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.RuleSet) > 0 { //nolint:staticcheck if options.Deprecated_RulesetIPCIDRMatchSource { diff --git a/route/rule/rule_item_source_hostname.go b/route/rule/rule_item_source_hostname.go new file mode 100644 index 0000000000..0df11c8c8a --- /dev/null +++ b/route/rule/rule_item_source_hostname.go @@ -0,0 +1,42 @@ +package rule + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*SourceHostnameItem)(nil) + +type SourceHostnameItem struct { + hostnames []string + hostnameMap map[string]bool +} + +func NewSourceHostnameItem(hostnameList []string) *SourceHostnameItem { + rule := &SourceHostnameItem{ + hostnames: hostnameList, + hostnameMap: make(map[string]bool), + } + for _, hostname := range hostnameList { + rule.hostnameMap[hostname] = true + } + return rule +} + +func (r *SourceHostnameItem) Match(metadata *adapter.InboundContext) bool { + if metadata.SourceHostname == "" { + return false + } + return r.hostnameMap[metadata.SourceHostname] +} + +func (r *SourceHostnameItem) String() string { + var description string + if len(r.hostnames) == 1 { + description = "source_hostname=" + r.hostnames[0] + } else { + description = "source_hostname=[" + strings.Join(r.hostnames, " ") + "]" + } + return description +} diff --git a/route/rule/rule_item_source_mac_address.go b/route/rule/rule_item_source_mac_address.go new file mode 100644 index 0000000000..feeadb1dbf --- /dev/null +++ b/route/rule/rule_item_source_mac_address.go @@ -0,0 +1,48 @@ +package rule + +import ( + "net" + "strings" + + "github.com/sagernet/sing-box/adapter" +) + +var _ RuleItem = (*SourceMACAddressItem)(nil) + +type SourceMACAddressItem struct { + addresses []string + addressMap map[string]bool +} + +func NewSourceMACAddressItem(addressList []string) *SourceMACAddressItem { + rule := &SourceMACAddressItem{ + addresses: addressList, + addressMap: make(map[string]bool), + } + for _, address := range addressList { + parsed, err := net.ParseMAC(address) + if err == nil { + rule.addressMap[parsed.String()] = true + } else { + rule.addressMap[address] = true + } + } + return rule +} + +func (r *SourceMACAddressItem) Match(metadata *adapter.InboundContext) bool { + if metadata.SourceMACAddress == nil { + return false + } + return r.addressMap[metadata.SourceMACAddress.String()] +} + +func (r *SourceMACAddressItem) String() string { + var description string + if len(r.addresses) == 1 { + description = "source_mac_address=" + r.addresses[0] + } else { + description = "source_mac_address=[" + strings.Join(r.addresses, " ") + "]" + } + return description +} diff --git a/route/rule_conds.go b/route/rule_conds.go index 55c4a058e2..22ce94fffd 100644 --- a/route/rule_conds.go +++ b/route/rule_conds.go @@ -45,6 +45,14 @@ func isProcessDNSRule(rule option.DefaultDNSRule) bool { return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.ProcessPathRegex) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 } +func isNeighborRule(rule option.DefaultRule) bool { + return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0 +} + +func isNeighborDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.SourceMACAddress) > 0 || len(rule.SourceHostname) > 0 +} + func isWIFIRule(rule option.DefaultRule) bool { return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0 } From 45339d101bd16cbbc7476dc08a96bd5be0ad28db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 5 Mar 2026 00:15:37 +0800 Subject: [PATCH 02/41] Add Android support for MAC and hostname rule items --- adapter/neighbor.go | 10 ++ adapter/platform.go | 4 + experimental/libbox/config.go | 12 +++ experimental/libbox/neighbor.go | 135 +++++++++++++++++++++++++++ experimental/libbox/neighbor_stub.go | 24 +++++ experimental/libbox/platform.go | 6 ++ experimental/libbox/service.go | 37 ++++++++ route/neighbor_resolver_linux.go | 85 ++--------------- route/neighbor_resolver_parse.go | 50 ++++++++++ route/neighbor_resolver_platform.go | 84 +++++++++++++++++ route/neighbor_table_linux.go | 68 ++++++++++++++ route/router.go | 33 +++++-- 12 files changed, 462 insertions(+), 86 deletions(-) create mode 100644 experimental/libbox/neighbor.go create mode 100644 experimental/libbox/neighbor_stub.go create mode 100644 route/neighbor_resolver_parse.go create mode 100644 route/neighbor_resolver_platform.go create mode 100644 route/neighbor_table_linux.go diff --git a/adapter/neighbor.go b/adapter/neighbor.go index 920398f674..d917db5b7a 100644 --- a/adapter/neighbor.go +++ b/adapter/neighbor.go @@ -5,9 +5,19 @@ import ( "net/netip" ) +type NeighborEntry struct { + Address netip.Addr + MACAddress net.HardwareAddr + Hostname string +} + type NeighborResolver interface { LookupMAC(address netip.Addr) (net.HardwareAddr, bool) LookupHostname(address netip.Addr) (string, bool) Start() error Close() error } + +type NeighborUpdateListener interface { + UpdateNeighborTable(entries []NeighborEntry) +} diff --git a/adapter/platform.go b/adapter/platform.go index fa4cbc2e45..fd96654811 100644 --- a/adapter/platform.go +++ b/adapter/platform.go @@ -36,6 +36,10 @@ type PlatformInterface interface { UsePlatformNotification() bool SendNotification(notification *Notification) error + + UsePlatformNeighborResolver() bool + StartNeighborMonitor(listener NeighborUpdateListener) error + CloseNeighborMonitor(listener NeighborUpdateListener) error } type FindConnectionOwnerRequest struct { diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 122425d293..54369bf770 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -144,6 +144,18 @@ func (s *platformInterfaceStub) SendNotification(notification *adapter.Notificat return nil } +func (s *platformInterfaceStub) UsePlatformNeighborResolver() bool { + return false +} + +func (s *platformInterfaceStub) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return os.ErrInvalid +} + +func (s *platformInterfaceStub) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return nil +} + func (s *platformInterfaceStub) UsePlatformLocalDNSTransport() bool { return false } diff --git a/experimental/libbox/neighbor.go b/experimental/libbox/neighbor.go new file mode 100644 index 0000000000..b2ded5f7a1 --- /dev/null +++ b/experimental/libbox/neighbor.go @@ -0,0 +1,135 @@ +//go:build linux + +package libbox + +import ( + "net" + "net/netip" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type NeighborEntry struct { + Address string + MACAddress string + Hostname string +} + +type NeighborEntryIterator interface { + Next() *NeighborEntry + HasNext() bool +} + +type NeighborSubscription struct { + done chan struct{} +} + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + return nil, E.Cause(err, "subscribe neighbor updates") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, connection, table) + return subscription, nil +} + +func (s *NeighborSubscription) Close() { + close(s.done) +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { + defer connection.Close() + for { + select { + case <-s.done: + return + default: + } + err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + changed := false + for _, message := range messages { + address, mac, isDelete, ok := route.ParseNeighborMessage(message) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} + +func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator { + entries := make([]*NeighborEntry, 0, len(table)) + for address, mac := range table { + entries = append(entries, &NeighborEntry{ + Address: address.String(), + MACAddress: mac.String(), + }) + } + return &neighborEntryIterator{entries} +} + +type neighborEntryIterator struct { + entries []*NeighborEntry +} + +func (i *neighborEntryIterator) HasNext() bool { + return len(i.entries) > 0 +} + +func (i *neighborEntryIterator) Next() *NeighborEntry { + if len(i.entries) == 0 { + return nil + } + entry := i.entries[0] + i.entries = i.entries[1:] + return entry +} diff --git a/experimental/libbox/neighbor_stub.go b/experimental/libbox/neighbor_stub.go new file mode 100644 index 0000000000..95f6dc7d6f --- /dev/null +++ b/experimental/libbox/neighbor_stub.go @@ -0,0 +1,24 @@ +//go:build !linux + +package libbox + +import "os" + +type NeighborEntry struct { + Address string + MACAddress string + Hostname string +} + +type NeighborEntryIterator interface { + Next() *NeighborEntry + HasNext() bool +} + +type NeighborSubscription struct{} + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + return nil, os.ErrInvalid +} + +func (s *NeighborSubscription) Close() {} diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 4db32a2226..759b14e88c 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -21,6 +21,12 @@ type PlatformInterface interface { SystemCertificates() StringIterator ClearDNSCache() SendNotification(notification *Notification) error + StartNeighborMonitor(listener NeighborUpdateListener) error + CloseNeighborMonitor(listener NeighborUpdateListener) error +} + +type NeighborUpdateListener interface { + UpdateNeighborTable(entries NeighborEntryIterator) } type ConnectionOwner struct { diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 0a841a1b20..58c6de41c1 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -220,6 +220,43 @@ func (w *platformInterfaceWrapper) SendNotification(notification *adapter.Notifi return w.iif.SendNotification((*Notification)(notification)) } +func (w *platformInterfaceWrapper) UsePlatformNeighborResolver() bool { + return true +} + +func (w *platformInterfaceWrapper) StartNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return w.iif.StartNeighborMonitor(&neighborUpdateListenerWrapper{listener: listener}) +} + +func (w *platformInterfaceWrapper) CloseNeighborMonitor(listener adapter.NeighborUpdateListener) error { + return w.iif.CloseNeighborMonitor(nil) +} + +type neighborUpdateListenerWrapper struct { + listener adapter.NeighborUpdateListener +} + +func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntryIterator) { + var result []adapter.NeighborEntry + for entries.HasNext() { + entry := entries.Next() + address, err := netip.ParseAddr(entry.Address) + if err != nil { + continue + } + macAddress, err := net.ParseMAC(entry.MACAddress) + if err != nil { + continue + } + result = append(result, adapter.NeighborEntry{ + Address: address, + MACAddress: macAddress, + Hostname: entry.Hostname, + }) + } + w.listener.UpdateNeighborTable(result) +} + func AvailablePort(startPort int32) (int32, error) { for port := int(startPort); ; port++ { if port > 65535 { diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go index 40db5766ad..111cc6f040 100644 --- a/route/neighbor_resolver_linux.go +++ b/route/neighbor_resolver_linux.go @@ -4,7 +4,6 @@ package route import ( "bufio" - "encoding/binary" "encoding/hex" "net" "net/netip" @@ -204,43 +203,17 @@ func (r *neighborResolver) subscribeNeighborUpdates() { continue } for _, message := range messages { - switch message.Header.Type { - case unix.RTM_NEWNEIGH: - var neighMessage rtnetlink.NeighMessage - unmarshalErr := neighMessage.UnmarshalBinary(message.Data) - if unmarshalErr != nil { - continue - } - if neighMessage.Attributes == nil { - continue - } - if neighMessage.Attributes.LLAddress == nil || len(neighMessage.Attributes.Address) == 0 { - continue - } - address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) - if !ok { - continue - } - r.access.Lock() - r.neighborIPToMAC[address] = slices.Clone(neighMessage.Attributes.LLAddress) - r.access.Unlock() - case unix.RTM_DELNEIGH: - var neighMessage rtnetlink.NeighMessage - unmarshalErr := neighMessage.UnmarshalBinary(message.Data) - if unmarshalErr != nil { - continue - } - if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { - continue - } - address, ok := netip.AddrFromSlice(neighMessage.Attributes.Address) - if !ok { - continue - } - r.access.Lock() + address, mac, isDelete, ok := ParseNeighborMessage(message) + if !ok { + continue + } + r.access.Lock() + if isDelete { delete(r.neighborIPToMAC, address) - r.access.Unlock() + } else { + r.neighborIPToMAC[address] = mac } + r.access.Unlock() } } } @@ -554,43 +527,3 @@ func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]ne } } } - -func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { - if len(duid) < 4 { - return nil, false - } - duidType := binary.BigEndian.Uint16(duid[0:2]) - hwType := binary.BigEndian.Uint16(duid[2:4]) - if hwType != 1 { - return nil, false - } - switch duidType { - case 1: - if len(duid) < 14 { - return nil, false - } - return net.HardwareAddr(slices.Clone(duid[8:14])), true - case 3: - if len(duid) < 10 { - return nil, false - } - return net.HardwareAddr(slices.Clone(duid[4:10])), true - } - return nil, false -} - -func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { - if !address.Is6() { - return nil, false - } - b := address.As16() - if b[11] != 0xff || b[12] != 0xfe { - return nil, false - } - return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true -} - -func parseDUID(s string) ([]byte, error) { - cleaned := strings.ReplaceAll(s, ":", "") - return hex.DecodeString(cleaned) -} diff --git a/route/neighbor_resolver_parse.go b/route/neighbor_resolver_parse.go new file mode 100644 index 0000000000..1979b7eabc --- /dev/null +++ b/route/neighbor_resolver_parse.go @@ -0,0 +1,50 @@ +package route + +import ( + "encoding/binary" + "encoding/hex" + "net" + "net/netip" + "slices" + "strings" +) + +func extractMACFromDUID(duid []byte) (net.HardwareAddr, bool) { + if len(duid) < 4 { + return nil, false + } + duidType := binary.BigEndian.Uint16(duid[0:2]) + hwType := binary.BigEndian.Uint16(duid[2:4]) + if hwType != 1 { + return nil, false + } + switch duidType { + case 1: + if len(duid) < 14 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[8:14])), true + case 3: + if len(duid) < 10 { + return nil, false + } + return net.HardwareAddr(slices.Clone(duid[4:10])), true + } + return nil, false +} + +func extractMACFromEUI64(address netip.Addr) (net.HardwareAddr, bool) { + if !address.Is6() { + return nil, false + } + b := address.As16() + if b[11] != 0xff || b[12] != 0xfe { + return nil, false + } + return net.HardwareAddr{b[8] ^ 0x02, b[9], b[10], b[13], b[14], b[15]}, true +} + +func parseDUID(s string) ([]byte, error) { + cleaned := strings.ReplaceAll(s, ":", "") + return hex.DecodeString(cleaned) +} diff --git a/route/neighbor_resolver_platform.go b/route/neighbor_resolver_platform.go new file mode 100644 index 0000000000..ddb9a99592 --- /dev/null +++ b/route/neighbor_resolver_platform.go @@ -0,0 +1,84 @@ +package route + +import ( + "net" + "net/netip" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/logger" +) + +type platformNeighborResolver struct { + logger logger.ContextLogger + platform adapter.PlatformInterface + access sync.RWMutex + ipToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string +} + +func newPlatformNeighborResolver(resolverLogger logger.ContextLogger, platform adapter.PlatformInterface) adapter.NeighborResolver { + return &platformNeighborResolver{ + logger: resolverLogger, + platform: platform, + ipToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + } +} + +func (r *platformNeighborResolver) Start() error { + return r.platform.StartNeighborMonitor(r) +} + +func (r *platformNeighborResolver) Close() error { + return r.platform.CloseNeighborMonitor(r) +} + +func (r *platformNeighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.ipToMAC[address] + if found { + return mac, true + } + return extractMACFromEUI64(address) +} + +func (r *platformNeighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, found := r.ipToMAC[address] + if !found { + mac, found = extractMACFromEUI64(address) + } + if !found { + return "", false + } + hostname, found = r.macToHostname[mac.String()] + return hostname, found +} + +func (r *platformNeighborResolver) UpdateNeighborTable(entries []adapter.NeighborEntry) { + ipToMAC := make(map[netip.Addr]net.HardwareAddr) + ipToHostname := make(map[netip.Addr]string) + macToHostname := make(map[string]string) + for _, entry := range entries { + ipToMAC[entry.Address] = entry.MACAddress + if entry.Hostname != "" { + ipToHostname[entry.Address] = entry.Hostname + macToHostname[entry.MACAddress.String()] = entry.Hostname + } + } + r.access.Lock() + r.ipToMAC = ipToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() + r.logger.Info("updated neighbor table: ", len(entries), " entries") +} diff --git a/route/neighbor_table_linux.go b/route/neighbor_table_linux.go new file mode 100644 index 0000000000..61a214fd3a --- /dev/null +++ b/route/neighbor_table_linux.go @@ -0,0 +1,68 @@ +//go:build linux + +package route + +import ( + "net" + "net/netip" + "slices" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func ReadNeighborEntries() ([]adapter.NeighborEntry, error) { + connection, err := rtnetlink.Dial(nil) + if err != nil { + return nil, E.Cause(err, "dial rtnetlink") + } + defer connection.Close() + neighbors, err := connection.Neigh.List() + if err != nil { + return nil, E.Cause(err, "list neighbors") + } + var entries []adapter.NeighborEntry + for _, neighbor := range neighbors { + if neighbor.Attributes == nil { + continue + } + if neighbor.Attributes.LLAddress == nil || len(neighbor.Attributes.Address) == 0 { + continue + } + address, ok := netip.AddrFromSlice(neighbor.Attributes.Address) + if !ok { + continue + } + entries = append(entries, adapter.NeighborEntry{ + Address: address, + MACAddress: slices.Clone(neighbor.Attributes.LLAddress), + }) + } + return entries, nil +} + +func ParseNeighborMessage(message netlink.Message) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) { + var neighMessage rtnetlink.NeighMessage + err := neighMessage.UnmarshalBinary(message.Data) + if err != nil { + return + } + if neighMessage.Attributes == nil || len(neighMessage.Attributes.Address) == 0 { + return + } + address, ok = netip.AddrFromSlice(neighMessage.Attributes.Address) + if !ok { + return + } + isDelete = message.Header.Type == unix.RTM_DELNEIGH + if !isDelete && neighMessage.Attributes.LLAddress == nil { + ok = false + return + } + macAddress = slices.Clone(neighMessage.Attributes.LLAddress) + return +} diff --git a/route/router.go b/route/router.go index 52eb9e4362..c6677d20f9 100644 --- a/route/router.go +++ b/route/router.go @@ -159,21 +159,34 @@ func (r *Router) Start(stage adapter.StartStage) error { } r.needFindNeighbor = needFindNeighbor if needFindNeighbor { - monitor.Start("initialize neighbor resolver") - resolver, err := newNeighborResolver(r.logger, r.leaseFiles) - monitor.Finish() - if err != nil { - if err != os.ErrInvalid { - r.logger.Warn(E.Cause(err, "create neighbor resolver")) - } - } else { - err = resolver.Start() + if r.platformInterface != nil && r.platformInterface.UsePlatformNeighborResolver() { + monitor.Start("initialize neighbor resolver") + resolver := newPlatformNeighborResolver(r.logger, r.platformInterface) + err := resolver.Start() + monitor.Finish() if err != nil { - r.logger.Warn(E.Cause(err, "start neighbor resolver")) + r.logger.Error(E.Cause(err, "start neighbor resolver")) } else { r.neighborResolver = resolver } } + if r.neighborResolver == nil { + monitor.Start("initialize neighbor resolver") + resolver, err := newNeighborResolver(r.logger, r.leaseFiles) + monitor.Finish() + if err != nil { + if err != os.ErrInvalid { + r.logger.Error(E.Cause(err, "create neighbor resolver")) + } + } else { + err = resolver.Start() + if err != nil { + r.logger.Error(E.Cause(err, "start neighbor resolver")) + } else { + r.neighborResolver = resolver + } + } + } } case adapter.StartStatePostStart: for i, rule := range r.rules { From eeb5dead2af3363704fb5c308bac96286d561754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Mar 2026 08:47:37 +0800 Subject: [PATCH 03/41] Add macOS support for MAC and hostname rule items --- experimental/libbox/neighbor.go | 86 +----- experimental/libbox/neighbor_darwin.go | 123 ++++++++ experimental/libbox/neighbor_linux.go | 88 ++++++ experimental/libbox/neighbor_stub.go | 19 +- experimental/libbox/platform.go | 1 + experimental/libbox/service.go | 6 +- route/neighbor_resolver_darwin.go | 239 +++++++++++++++ route/neighbor_resolver_lease.go | 386 +++++++++++++++++++++++++ route/neighbor_resolver_linux.go | 313 +------------------- route/neighbor_resolver_stub.go | 2 +- route/neighbor_table_darwin.go | 104 +++++++ route/router.go | 3 +- 12 files changed, 956 insertions(+), 414 deletions(-) create mode 100644 experimental/libbox/neighbor_darwin.go create mode 100644 experimental/libbox/neighbor_linux.go create mode 100644 route/neighbor_resolver_darwin.go create mode 100644 route/neighbor_resolver_lease.go create mode 100644 route/neighbor_table_darwin.go diff --git a/experimental/libbox/neighbor.go b/experimental/libbox/neighbor.go index b2ded5f7a1..e38aa8023f 100644 --- a/experimental/libbox/neighbor.go +++ b/experimental/libbox/neighbor.go @@ -1,23 +1,13 @@ -//go:build linux - package libbox import ( "net" "net/netip" - "slices" - "time" - - "github.com/sagernet/sing-box/route" - E "github.com/sagernet/sing/common/exceptions" - - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" ) type NeighborEntry struct { Address string - MACAddress string + MacAddress string Hostname string } @@ -30,88 +20,16 @@ type NeighborSubscription struct { done chan struct{} } -func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { - entries, err := route.ReadNeighborEntries() - if err != nil { - return nil, E.Cause(err, "initial neighbor dump") - } - table := make(map[netip.Addr]net.HardwareAddr) - for _, entry := range entries { - table[entry.Address] = entry.MACAddress - } - listener.UpdateNeighborTable(tableToIterator(table)) - connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - Groups: 1 << (unix.RTNLGRP_NEIGH - 1), - }) - if err != nil { - return nil, E.Cause(err, "subscribe neighbor updates") - } - subscription := &NeighborSubscription{ - done: make(chan struct{}), - } - go subscription.loop(listener, connection, table) - return subscription, nil -} - func (s *NeighborSubscription) Close() { close(s.done) } -func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { - defer connection.Close() - for { - select { - case <-s.done: - return - default: - } - err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) - if err != nil { - return - } - messages, err := connection.Receive() - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - continue - } - select { - case <-s.done: - return - default: - } - continue - } - changed := false - for _, message := range messages { - address, mac, isDelete, ok := route.ParseNeighborMessage(message) - if !ok { - continue - } - if isDelete { - if _, exists := table[address]; exists { - delete(table, address) - changed = true - } - } else { - existing, exists := table[address] - if !exists || !slices.Equal(existing, mac) { - table[address] = mac - changed = true - } - } - } - if changed { - listener.UpdateNeighborTable(tableToIterator(table)) - } - } -} - func tableToIterator(table map[netip.Addr]net.HardwareAddr) NeighborEntryIterator { entries := make([]*NeighborEntry, 0, len(table)) for address, mac := range table { entries = append(entries, &NeighborEntry{ Address: address.String(), - MACAddress: mac.String(), + MacAddress: mac.String(), }) } return &neighborEntryIterator{entries} diff --git a/experimental/libbox/neighbor_darwin.go b/experimental/libbox/neighbor_darwin.go new file mode 100644 index 0000000000..d7484a69b4 --- /dev/null +++ b/experimental/libbox/neighbor_darwin.go @@ -0,0 +1,123 @@ +//go:build darwin + +package libbox + +import ( + "net" + "net/netip" + "os" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + xroute "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) + if err != nil { + return nil, E.Cause(err, "open route socket") + } + err = unix.SetNonblock(routeSocket, true) + if err != nil { + unix.Close(routeSocket) + return nil, E.Cause(err, "set route socket nonblock") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, routeSocket, table) + return subscription, nil +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, routeSocket int, table map[netip.Addr]net.HardwareAddr) { + routeSocketFile := os.NewFile(uintptr(routeSocket), "route") + defer routeSocketFile.Close() + buffer := buf.NewPacket() + defer buffer.Release() + for { + select { + case <-s.done: + return + default: + } + tv := unix.NsecToTimeval(int64(3 * time.Second)) + _ = unix.SetsockoptTimeval(routeSocket, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + n, err := routeSocketFile.Read(buffer.FreeBytes()) + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + messages, err := xroute.ParseRIB(xroute.RIBTypeRoute, buffer.FreeBytes()[:n]) + if err != nil { + continue + } + changed := false + for _, message := range messages { + routeMessage, isRouteMessage := message.(*xroute.RouteMessage) + if !isRouteMessage { + continue + } + if routeMessage.Flags&unix.RTF_LLINFO == 0 { + continue + } + address, mac, isDelete, ok := route.ParseRouteNeighborMessage(routeMessage) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} + +func ReadBootpdLeases() NeighborEntryIterator { + leaseIPToMAC, ipToHostname, macToHostname := route.ReloadLeaseFiles([]string{"/var/db/dhcpd_leases"}) + entries := make([]*NeighborEntry, 0, len(leaseIPToMAC)) + for address, mac := range leaseIPToMAC { + entry := &NeighborEntry{ + Address: address.String(), + MacAddress: mac.String(), + } + hostname, found := ipToHostname[address] + if !found { + hostname = macToHostname[mac.String()] + } + entry.Hostname = hostname + entries = append(entries, entry) + } + return &neighborEntryIterator{entries} +} diff --git a/experimental/libbox/neighbor_linux.go b/experimental/libbox/neighbor_linux.go new file mode 100644 index 0000000000..ae10bdd2ee --- /dev/null +++ b/experimental/libbox/neighbor_linux.go @@ -0,0 +1,88 @@ +//go:build linux + +package libbox + +import ( + "net" + "net/netip" + "slices" + "time" + + "github.com/sagernet/sing-box/route" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { + entries, err := route.ReadNeighborEntries() + if err != nil { + return nil, E.Cause(err, "initial neighbor dump") + } + table := make(map[netip.Addr]net.HardwareAddr) + for _, entry := range entries { + table[entry.Address] = entry.MACAddress + } + listener.UpdateNeighborTable(tableToIterator(table)) + connection, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + Groups: 1 << (unix.RTNLGRP_NEIGH - 1), + }) + if err != nil { + return nil, E.Cause(err, "subscribe neighbor updates") + } + subscription := &NeighborSubscription{ + done: make(chan struct{}), + } + go subscription.loop(listener, connection, table) + return subscription, nil +} + +func (s *NeighborSubscription) loop(listener NeighborUpdateListener, connection *netlink.Conn, table map[netip.Addr]net.HardwareAddr) { + defer connection.Close() + for { + select { + case <-s.done: + return + default: + } + err := connection.SetReadDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + messages, err := connection.Receive() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-s.done: + return + default: + } + continue + } + changed := false + for _, message := range messages { + address, mac, isDelete, ok := route.ParseNeighborMessage(message) + if !ok { + continue + } + if isDelete { + if _, exists := table[address]; exists { + delete(table, address) + changed = true + } + } else { + existing, exists := table[address] + if !exists || !slices.Equal(existing, mac) { + table[address] = mac + changed = true + } + } + } + if changed { + listener.UpdateNeighborTable(tableToIterator(table)) + } + } +} diff --git a/experimental/libbox/neighbor_stub.go b/experimental/libbox/neighbor_stub.go index 95f6dc7d6f..d465bc7bb0 100644 --- a/experimental/libbox/neighbor_stub.go +++ b/experimental/libbox/neighbor_stub.go @@ -1,24 +1,9 @@ -//go:build !linux +//go:build !linux && !darwin package libbox import "os" -type NeighborEntry struct { - Address string - MACAddress string - Hostname string -} - -type NeighborEntryIterator interface { - Next() *NeighborEntry - HasNext() bool -} - -type NeighborSubscription struct{} - -func SubscribeNeighborTable(listener NeighborUpdateListener) (*NeighborSubscription, error) { +func SubscribeNeighborTable(_ NeighborUpdateListener) (*NeighborSubscription, error) { return nil, os.ErrInvalid } - -func (s *NeighborSubscription) Close() {} diff --git a/experimental/libbox/platform.go b/experimental/libbox/platform.go index 759b14e88c..e65d08184b 100644 --- a/experimental/libbox/platform.go +++ b/experimental/libbox/platform.go @@ -23,6 +23,7 @@ type PlatformInterface interface { SendNotification(notification *Notification) error StartNeighborMonitor(listener NeighborUpdateListener) error CloseNeighborMonitor(listener NeighborUpdateListener) error + RegisterMyInterface(name string) } type NeighborUpdateListener interface { diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 58c6de41c1..7becf9fac3 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -78,6 +78,7 @@ func (w *platformInterfaceWrapper) OpenInterface(options *tun.Options, platformO } options.FileDescriptor = dupFd w.myTunName = options.Name + w.iif.RegisterMyInterface(options.Name) return tun.New(*options) } @@ -240,11 +241,14 @@ func (w *neighborUpdateListenerWrapper) UpdateNeighborTable(entries NeighborEntr var result []adapter.NeighborEntry for entries.HasNext() { entry := entries.Next() + if entry == nil { + continue + } address, err := netip.ParseAddr(entry.Address) if err != nil { continue } - macAddress, err := net.ParseMAC(entry.MACAddress) + macAddress, err := net.ParseMAC(entry.MacAddress) if err != nil { continue } diff --git a/route/neighbor_resolver_darwin.go b/route/neighbor_resolver_darwin.go new file mode 100644 index 0000000000..a8884ae628 --- /dev/null +++ b/route/neighbor_resolver_darwin.go @@ -0,0 +1,239 @@ +//go:build darwin + +package route + +import ( + "net" + "net/netip" + "os" + "sync" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +var defaultLeaseFiles = []string{ + "/var/db/dhcpd_leases", + "/tmp/dhcp.leases", +} + +type neighborResolver struct { + logger logger.ContextLogger + leaseFiles []string + access sync.RWMutex + neighborIPToMAC map[netip.Addr]net.HardwareAddr + leaseIPToMAC map[netip.Addr]net.HardwareAddr + ipToHostname map[netip.Addr]string + macToHostname map[string]string + watcher *fswatch.Watcher + done chan struct{} +} + +func newNeighborResolver(resolverLogger logger.ContextLogger, leaseFiles []string) (adapter.NeighborResolver, error) { + if len(leaseFiles) == 0 { + for _, path := range defaultLeaseFiles { + info, err := os.Stat(path) + if err == nil && info.Size() > 0 { + leaseFiles = append(leaseFiles, path) + } + } + } + return &neighborResolver{ + logger: resolverLogger, + leaseFiles: leaseFiles, + neighborIPToMAC: make(map[netip.Addr]net.HardwareAddr), + leaseIPToMAC: make(map[netip.Addr]net.HardwareAddr), + ipToHostname: make(map[netip.Addr]string), + macToHostname: make(map[string]string), + done: make(chan struct{}), + }, nil +} + +func (r *neighborResolver) Start() error { + err := r.loadNeighborTable() + if err != nil { + r.logger.Warn(E.Cause(err, "load neighbor table")) + } + r.doReloadLeaseFiles() + go r.subscribeNeighborUpdates() + if len(r.leaseFiles) > 0 { + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: r.leaseFiles, + Logger: r.logger, + Callback: func(_ string) { + r.doReloadLeaseFiles() + }, + }) + if err != nil { + r.logger.Warn(E.Cause(err, "create lease file watcher")) + } else { + r.watcher = watcher + err = watcher.Start() + if err != nil { + r.logger.Warn(E.Cause(err, "start lease file watcher")) + } + } + } + return nil +} + +func (r *neighborResolver) Close() error { + close(r.done) + if r.watcher != nil { + return r.watcher.Close() + } + return nil +} + +func (r *neighborResolver) LookupMAC(address netip.Addr) (net.HardwareAddr, bool) { + r.access.RLock() + defer r.access.RUnlock() + mac, found := r.neighborIPToMAC[address] + if found { + return mac, true + } + mac, found = r.leaseIPToMAC[address] + if found { + return mac, true + } + mac, found = extractMACFromEUI64(address) + if found { + return mac, true + } + return nil, false +} + +func (r *neighborResolver) LookupHostname(address netip.Addr) (string, bool) { + r.access.RLock() + defer r.access.RUnlock() + hostname, found := r.ipToHostname[address] + if found { + return hostname, true + } + mac, macFound := r.neighborIPToMAC[address] + if !macFound { + mac, macFound = r.leaseIPToMAC[address] + } + if !macFound { + mac, macFound = extractMACFromEUI64(address) + } + if macFound { + hostname, found = r.macToHostname[mac.String()] + if found { + return hostname, true + } + } + return "", false +} + +func (r *neighborResolver) loadNeighborTable() error { + entries, err := ReadNeighborEntries() + if err != nil { + return err + } + r.access.Lock() + defer r.access.Unlock() + for _, entry := range entries { + r.neighborIPToMAC[entry.Address] = entry.MACAddress + } + return nil +} + +func (r *neighborResolver) subscribeNeighborUpdates() { + routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) + if err != nil { + r.logger.Warn(E.Cause(err, "subscribe neighbor updates")) + return + } + err = unix.SetNonblock(routeSocket, true) + if err != nil { + unix.Close(routeSocket) + r.logger.Warn(E.Cause(err, "set route socket nonblock")) + return + } + routeSocketFile := os.NewFile(uintptr(routeSocket), "route") + defer routeSocketFile.Close() + buffer := buf.NewPacket() + defer buffer.Release() + for { + select { + case <-r.done: + return + default: + } + err = setReadDeadline(routeSocketFile, 3*time.Second) + if err != nil { + r.logger.Warn(E.Cause(err, "set route socket read deadline")) + return + } + n, err := routeSocketFile.Read(buffer.FreeBytes()) + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + select { + case <-r.done: + return + default: + } + r.logger.Warn(E.Cause(err, "receive neighbor update")) + continue + } + messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.FreeBytes()[:n]) + if err != nil { + continue + } + for _, message := range messages { + routeMessage, isRouteMessage := message.(*route.RouteMessage) + if !isRouteMessage { + continue + } + if routeMessage.Flags&unix.RTF_LLINFO == 0 { + continue + } + address, mac, isDelete, ok := ParseRouteNeighborMessage(routeMessage) + if !ok { + continue + } + r.access.Lock() + if isDelete { + delete(r.neighborIPToMAC, address) + } else { + r.neighborIPToMAC[address] = mac + } + r.access.Unlock() + } + } +} + +func (r *neighborResolver) doReloadLeaseFiles() { + leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles) + r.access.Lock() + r.leaseIPToMAC = leaseIPToMAC + r.ipToHostname = ipToHostname + r.macToHostname = macToHostname + r.access.Unlock() +} + +func setReadDeadline(file *os.File, timeout time.Duration) error { + rawConn, err := file.SyscallConn() + if err != nil { + return err + } + var controlErr error + err = rawConn.Control(func(fd uintptr) { + tv := unix.NsecToTimeval(int64(timeout)) + controlErr = unix.SetsockoptTimeval(int(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + }) + if err != nil { + return err + } + return controlErr +} diff --git a/route/neighbor_resolver_lease.go b/route/neighbor_resolver_lease.go new file mode 100644 index 0000000000..e3f9c0b464 --- /dev/null +++ b/route/neighbor_resolver_lease.go @@ -0,0 +1,386 @@ +package route + +import ( + "bufio" + "encoding/hex" + "net" + "net/netip" + "os" + "strconv" + "strings" + "time" +) + +func parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + file, err := os.Open(path) + if err != nil { + return + } + defer file.Close() + if strings.HasSuffix(path, "dhcpd_leases") { + parseBootpdLeases(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases4.csv") { + parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "kea-leases6.csv") { + parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) + return + } + if strings.HasSuffix(path, "dhcpd.leases") { + parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) + return + } + parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) +} + +func ReloadLeaseFiles(leaseFiles []string) (leaseIPToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + leaseIPToMAC = make(map[netip.Addr]net.HardwareAddr) + ipToHostname = make(map[netip.Addr]string) + macToHostname = make(map[string]string) + for _, path := range leaseFiles { + parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) + } + return +} + +func parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "duid ") { + continue + } + if strings.HasPrefix(line, "# ") { + parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) + continue + } + fields := strings.Fields(line) + if len(fields) < 4 { + continue + } + expiry, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + continue + } + if expiry != 0 && expiry < now { + continue + } + if strings.Contains(fields[1], ":") { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + ipToMAC[address] = mac + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } else { + var mac net.HardwareAddr + if len(fields) >= 5 { + duid, duidErr := parseDUID(fields[4]) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + hostname := fields[3] + if hostname != "*" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } + } +} + +func parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + fields := strings.Fields(line) + if len(fields) < 5 { + return + } + validTime, err := strconv.ParseInt(fields[4], 10, 64) + if err != nil { + return + } + if validTime == 0 { + return + } + if validTime > 0 && validTime < time.Now().Unix() { + return + } + hostname := fields[3] + if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { + hostname = "" + } + if len(fields) >= 8 && fields[2] == "ipv4" { + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + return + } + addressField := fields[7] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + return + } + address = address.Unmap() + ipToMAC[address] = mac + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + return + } + var mac net.HardwareAddr + duidHex := fields[1] + duidBytes, hexErr := hex.DecodeString(duidHex) + if hexErr == nil { + mac, _ = extractMACFromDUID(duidBytes) + } + for i := 7; i < len(fields); i++ { + addressField := fields[i] + slashIndex := strings.IndexByte(addressField, '/') + if slashIndex >= 0 { + addressField = addressField[:slashIndex] + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) + if !addrOK { + continue + } + address = address.Unmap() + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentHostname string + var currentActive bool + var inLease bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { + ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) + if addrOK { + currentIP = parsed.Unmap() + inLease = true + currentMAC = nil + currentHostname = "" + currentActive = false + } + continue + } + if line == "}" && inLease { + if currentActive && currentMAC != nil { + ipToMAC[currentIP] = currentMAC + if currentHostname != "" { + ipToHostname[currentIP] = currentHostname + macToHostname[currentMAC.String()] = currentHostname + } + } else { + delete(ipToMAC, currentIP) + delete(ipToHostname, currentIP) + } + inLease = false + continue + } + if !inLease { + continue + } + if strings.HasPrefix(line, "hardware ethernet ") { + macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") + parsed, macErr := net.ParseMAC(macString) + if macErr == nil { + currentMAC = parsed + } + } else if strings.HasPrefix(line, "client-hostname ") { + hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") + hostname = strings.Trim(hostname, "\"") + if hostname != "" { + currentHostname = hostname + } + } else if strings.HasPrefix(line, "binding state ") { + state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") + currentActive = state == "active" + } + } +} + +func parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 10 { + continue + } + if fields[9] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + mac, macErr := net.ParseMAC(fields[1]) + if macErr != nil { + continue + } + ipToMAC[address] = mac + hostname := "" + if len(fields) > 8 { + hostname = fields[8] + } + if hostname != "" { + ipToHostname[address] = hostname + macToHostname[mac.String()] = hostname + } + } +} + +func parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + scanner := bufio.NewScanner(file) + firstLine := true + for scanner.Scan() { + if firstLine { + firstLine = false + continue + } + fields := strings.Split(scanner.Text(), ",") + if len(fields) < 14 { + continue + } + if fields[13] != "0" { + continue + } + address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) + if !addrOK { + continue + } + address = address.Unmap() + var mac net.HardwareAddr + if fields[12] != "" { + mac, _ = net.ParseMAC(fields[12]) + } + if mac == nil { + duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) + if duidErr == nil { + mac, _ = extractMACFromDUID(duid) + } + } + hostname := "" + if len(fields) > 11 { + hostname = fields[11] + } + if mac != nil { + ipToMAC[address] = mac + } + if hostname != "" { + ipToHostname[address] = hostname + if mac != nil { + macToHostname[mac.String()] = hostname + } + } + } +} + +func parseBootpdLeases(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { + now := time.Now().Unix() + scanner := bufio.NewScanner(file) + var currentName string + var currentIP netip.Addr + var currentMAC net.HardwareAddr + var currentLease int64 + var inBlock bool + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "{" { + inBlock = true + currentName = "" + currentIP = netip.Addr{} + currentMAC = nil + currentLease = 0 + continue + } + if line == "}" && inBlock { + if currentMAC != nil && currentIP.IsValid() { + if currentLease == 0 || currentLease >= now { + ipToMAC[currentIP] = currentMAC + if currentName != "" { + ipToHostname[currentIP] = currentName + macToHostname[currentMAC.String()] = currentName + } + } + } + inBlock = false + continue + } + if !inBlock { + continue + } + key, value, found := strings.Cut(line, "=") + if !found { + continue + } + switch key { + case "name": + currentName = value + case "ip_address": + parsed, addrOK := netip.AddrFromSlice(net.ParseIP(value)) + if addrOK { + currentIP = parsed.Unmap() + } + case "hw_address": + typeAndMAC, hasSep := strings.CutPrefix(value, "1,") + if hasSep { + mac, macErr := net.ParseMAC(typeAndMAC) + if macErr == nil { + currentMAC = mac + } + } + case "lease": + leaseHex := strings.TrimPrefix(value, "0x") + parsed, parseErr := strconv.ParseInt(leaseHex, 16, 64) + if parseErr == nil { + currentLease = parsed + } + } + } +} diff --git a/route/neighbor_resolver_linux.go b/route/neighbor_resolver_linux.go index 111cc6f040..b7991b4c89 100644 --- a/route/neighbor_resolver_linux.go +++ b/route/neighbor_resolver_linux.go @@ -3,14 +3,10 @@ package route import ( - "bufio" - "encoding/hex" "net" "net/netip" "os" "slices" - "strconv" - "strings" "sync" "time" @@ -69,14 +65,14 @@ func (r *neighborResolver) Start() error { if err != nil { r.logger.Warn(E.Cause(err, "load neighbor table")) } - r.reloadLeaseFiles() + r.doReloadLeaseFiles() go r.subscribeNeighborUpdates() if len(r.leaseFiles) > 0 { watcher, err := fswatch.NewWatcher(fswatch.Options{ Path: r.leaseFiles, Logger: r.logger, Callback: func(_ string) { - r.reloadLeaseFiles() + r.doReloadLeaseFiles() }, }) if err != nil { @@ -218,312 +214,11 @@ func (r *neighborResolver) subscribeNeighborUpdates() { } } -func (r *neighborResolver) reloadLeaseFiles() { - leaseIPToMAC := make(map[netip.Addr]net.HardwareAddr) - ipToHostname := make(map[netip.Addr]string) - macToHostname := make(map[string]string) - for _, path := range r.leaseFiles { - r.parseLeaseFile(path, leaseIPToMAC, ipToHostname, macToHostname) - } +func (r *neighborResolver) doReloadLeaseFiles() { + leaseIPToMAC, ipToHostname, macToHostname := ReloadLeaseFiles(r.leaseFiles) r.access.Lock() r.leaseIPToMAC = leaseIPToMAC r.ipToHostname = ipToHostname r.macToHostname = macToHostname r.access.Unlock() } - -func (r *neighborResolver) parseLeaseFile(path string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - file, err := os.Open(path) - if err != nil { - return - } - defer file.Close() - if strings.HasSuffix(path, "kea-leases4.csv") { - r.parseKeaCSV4(file, ipToMAC, ipToHostname, macToHostname) - return - } - if strings.HasSuffix(path, "kea-leases6.csv") { - r.parseKeaCSV6(file, ipToMAC, ipToHostname, macToHostname) - return - } - if strings.HasSuffix(path, "dhcpd.leases") { - r.parseISCDhcpd(file, ipToMAC, ipToHostname, macToHostname) - return - } - r.parseDnsmasqOdhcpd(file, ipToMAC, ipToHostname, macToHostname) -} - -func (r *neighborResolver) parseDnsmasqOdhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - now := time.Now().Unix() - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "duid ") { - continue - } - if strings.HasPrefix(line, "# ") { - r.parseOdhcpdLine(line[2:], ipToMAC, ipToHostname, macToHostname) - continue - } - fields := strings.Fields(line) - if len(fields) < 4 { - continue - } - expiry, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - continue - } - if expiry != 0 && expiry < now { - continue - } - if strings.Contains(fields[1], ":") { - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) - if !addrOK { - continue - } - address = address.Unmap() - ipToMAC[address] = mac - hostname := fields[3] - if hostname != "*" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - } else { - var mac net.HardwareAddr - if len(fields) >= 5 { - duid, duidErr := parseDUID(fields[4]) - if duidErr == nil { - mac, _ = extractMACFromDUID(duid) - } - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[2])) - if !addrOK { - continue - } - address = address.Unmap() - if mac != nil { - ipToMAC[address] = mac - } - hostname := fields[3] - if hostname != "*" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } - } -} - -func (r *neighborResolver) parseOdhcpdLine(line string, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - fields := strings.Fields(line) - if len(fields) < 5 { - return - } - validTime, err := strconv.ParseInt(fields[4], 10, 64) - if err != nil { - return - } - if validTime == 0 { - return - } - if validTime > 0 && validTime < time.Now().Unix() { - return - } - hostname := fields[3] - if hostname == "-" || strings.HasPrefix(hostname, `broken\x20`) { - hostname = "" - } - if len(fields) >= 8 && fields[2] == "ipv4" { - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - return - } - addressField := fields[7] - slashIndex := strings.IndexByte(addressField, '/') - if slashIndex >= 0 { - addressField = addressField[:slashIndex] - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) - if !addrOK { - return - } - address = address.Unmap() - ipToMAC[address] = mac - if hostname != "" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - return - } - var mac net.HardwareAddr - duidHex := fields[1] - duidBytes, hexErr := hex.DecodeString(duidHex) - if hexErr == nil { - mac, _ = extractMACFromDUID(duidBytes) - } - for i := 7; i < len(fields); i++ { - addressField := fields[i] - slashIndex := strings.IndexByte(addressField, '/') - if slashIndex >= 0 { - addressField = addressField[:slashIndex] - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(addressField)) - if !addrOK { - continue - } - address = address.Unmap() - if mac != nil { - ipToMAC[address] = mac - } - if hostname != "" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } -} - -func (r *neighborResolver) parseISCDhcpd(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - var currentIP netip.Addr - var currentMAC net.HardwareAddr - var currentHostname string - var currentActive bool - var inLease bool - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "lease ") && strings.HasSuffix(line, "{") { - ipString := strings.TrimSuffix(strings.TrimPrefix(line, "lease "), " {") - parsed, addrOK := netip.AddrFromSlice(net.ParseIP(ipString)) - if addrOK { - currentIP = parsed.Unmap() - inLease = true - currentMAC = nil - currentHostname = "" - currentActive = false - } - continue - } - if line == "}" && inLease { - if currentActive && currentMAC != nil { - ipToMAC[currentIP] = currentMAC - if currentHostname != "" { - ipToHostname[currentIP] = currentHostname - macToHostname[currentMAC.String()] = currentHostname - } - } else { - delete(ipToMAC, currentIP) - delete(ipToHostname, currentIP) - } - inLease = false - continue - } - if !inLease { - continue - } - if strings.HasPrefix(line, "hardware ethernet ") { - macString := strings.TrimSuffix(strings.TrimPrefix(line, "hardware ethernet "), ";") - parsed, macErr := net.ParseMAC(macString) - if macErr == nil { - currentMAC = parsed - } - } else if strings.HasPrefix(line, "client-hostname ") { - hostname := strings.TrimSuffix(strings.TrimPrefix(line, "client-hostname "), ";") - hostname = strings.Trim(hostname, "\"") - if hostname != "" { - currentHostname = hostname - } - } else if strings.HasPrefix(line, "binding state ") { - state := strings.TrimSuffix(strings.TrimPrefix(line, "binding state "), ";") - currentActive = state == "active" - } - } -} - -func (r *neighborResolver) parseKeaCSV4(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - firstLine := true - for scanner.Scan() { - if firstLine { - firstLine = false - continue - } - fields := strings.Split(scanner.Text(), ",") - if len(fields) < 10 { - continue - } - if fields[9] != "0" { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) - if !addrOK { - continue - } - address = address.Unmap() - mac, macErr := net.ParseMAC(fields[1]) - if macErr != nil { - continue - } - ipToMAC[address] = mac - hostname := "" - if len(fields) > 8 { - hostname = fields[8] - } - if hostname != "" { - ipToHostname[address] = hostname - macToHostname[mac.String()] = hostname - } - } -} - -func (r *neighborResolver) parseKeaCSV6(file *os.File, ipToMAC map[netip.Addr]net.HardwareAddr, ipToHostname map[netip.Addr]string, macToHostname map[string]string) { - scanner := bufio.NewScanner(file) - firstLine := true - for scanner.Scan() { - if firstLine { - firstLine = false - continue - } - fields := strings.Split(scanner.Text(), ",") - if len(fields) < 14 { - continue - } - if fields[13] != "0" { - continue - } - address, addrOK := netip.AddrFromSlice(net.ParseIP(fields[0])) - if !addrOK { - continue - } - address = address.Unmap() - var mac net.HardwareAddr - if fields[12] != "" { - mac, _ = net.ParseMAC(fields[12]) - } - if mac == nil { - duid, duidErr := hex.DecodeString(strings.ReplaceAll(fields[1], ":", "")) - if duidErr == nil { - mac, _ = extractMACFromDUID(duid) - } - } - hostname := "" - if len(fields) > 11 { - hostname = fields[11] - } - if mac != nil { - ipToMAC[address] = mac - } - if hostname != "" { - ipToHostname[address] = hostname - if mac != nil { - macToHostname[mac.String()] = hostname - } - } - } -} diff --git a/route/neighbor_resolver_stub.go b/route/neighbor_resolver_stub.go index 9288892a8d..177a1fccbc 100644 --- a/route/neighbor_resolver_stub.go +++ b/route/neighbor_resolver_stub.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux && !darwin package route diff --git a/route/neighbor_table_darwin.go b/route/neighbor_table_darwin.go new file mode 100644 index 0000000000..8ca2d0f0b7 --- /dev/null +++ b/route/neighbor_table_darwin.go @@ -0,0 +1,104 @@ +//go:build darwin + +package route + +import ( + "net" + "net/netip" + "syscall" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +func ReadNeighborEntries() ([]adapter.NeighborEntry, error) { + var entries []adapter.NeighborEntry + ipv4Entries, err := readNeighborEntriesAF(syscall.AF_INET) + if err != nil { + return nil, E.Cause(err, "read IPv4 neighbors") + } + entries = append(entries, ipv4Entries...) + ipv6Entries, err := readNeighborEntriesAF(syscall.AF_INET6) + if err != nil { + return nil, E.Cause(err, "read IPv6 neighbors") + } + entries = append(entries, ipv6Entries...) + return entries, nil +} + +func readNeighborEntriesAF(addressFamily int) ([]adapter.NeighborEntry, error) { + rib, err := route.FetchRIB(addressFamily, route.RIBType(syscall.NET_RT_FLAGS), syscall.RTF_LLINFO) + if err != nil { + return nil, err + } + messages, err := route.ParseRIB(route.RIBType(syscall.NET_RT_FLAGS), rib) + if err != nil { + return nil, err + } + var entries []adapter.NeighborEntry + for _, message := range messages { + routeMessage, isRouteMessage := message.(*route.RouteMessage) + if !isRouteMessage { + continue + } + address, macAddress, ok := parseRouteNeighborEntry(routeMessage) + if !ok { + continue + } + entries = append(entries, adapter.NeighborEntry{ + Address: address, + MACAddress: macAddress, + }) + } + return entries, nil +} + +func parseRouteNeighborEntry(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, ok bool) { + if len(message.Addrs) <= unix.RTAX_GATEWAY { + return + } + gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr) + if !isLinkAddr || len(gateway.Addr) < 6 { + return + } + switch destination := message.Addrs[unix.RTAX_DST].(type) { + case *route.Inet4Addr: + address = netip.AddrFrom4(destination.IP) + case *route.Inet6Addr: + address = netip.AddrFrom16(destination.IP) + default: + return + } + macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr))) + copy(macAddress, gateway.Addr) + ok = true + return +} + +func ParseRouteNeighborMessage(message *route.RouteMessage) (address netip.Addr, macAddress net.HardwareAddr, isDelete bool, ok bool) { + isDelete = message.Type == unix.RTM_DELETE + if len(message.Addrs) <= unix.RTAX_GATEWAY { + return + } + switch destination := message.Addrs[unix.RTAX_DST].(type) { + case *route.Inet4Addr: + address = netip.AddrFrom4(destination.IP) + case *route.Inet6Addr: + address = netip.AddrFrom16(destination.IP) + default: + return + } + if !isDelete { + gateway, isLinkAddr := message.Addrs[unix.RTAX_GATEWAY].(*route.LinkAddr) + if !isLinkAddr || len(gateway.Addr) < 6 { + return + } + macAddress = net.HardwareAddr(make([]byte, len(gateway.Addr))) + copy(macAddress, gateway.Addr) + } + ok = true + return +} diff --git a/route/router.go b/route/router.go index c6677d20f9..2815d5095b 100644 --- a/route/router.go +++ b/route/router.go @@ -169,8 +169,7 @@ func (r *Router) Start(stage adapter.StartStage) error { } else { r.neighborResolver = resolver } - } - if r.neighborResolver == nil { + } else { monitor.Start("initialize neighbor resolver") resolver, err := newNeighborResolver(r.logger, r.leaseFiles) monitor.Finish() From 77e51035bdd1c497fd7531714c9bffb0c4dc5b1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Mar 2026 21:43:21 +0800 Subject: [PATCH 04/41] documentation: Update descriptions for neighbor rules --- docs/configuration/dns/rule.md | 4 +- docs/configuration/dns/rule.zh.md | 4 +- docs/configuration/route/index.md | 17 ++++++-- docs/configuration/route/index.zh.md | 17 ++++++-- docs/configuration/route/rule.md | 4 +- docs/configuration/route/rule.zh.md | 4 +- docs/configuration/shared/neighbor.md | 49 ++++++++++++++++++++++++ docs/configuration/shared/neighbor.zh.md | 49 ++++++++++++++++++++++++ mkdocs.yml | 1 + 9 files changed, 133 insertions(+), 16 deletions(-) create mode 100644 docs/configuration/shared/neighbor.md create mode 100644 docs/configuration/shared/neighbor.zh.md diff --git a/docs/configuration/dns/rule.md b/docs/configuration/dns/rule.md index f8a7ac4c37..0b3e56da69 100644 --- a/docs/configuration/dns/rule.md +++ b/docs/configuration/dns/rule.md @@ -425,7 +425,7 @@ Match default interface address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device MAC address. @@ -435,7 +435,7 @@ Match source device MAC address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device hostname from DHCP leases. diff --git a/docs/configuration/dns/rule.zh.md b/docs/configuration/dns/rule.zh.md index 421fdfb5c1..82f85648f0 100644 --- a/docs/configuration/dns/rule.zh.md +++ b/docs/configuration/dns/rule.zh.md @@ -424,7 +424,7 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备 MAC 地址。 @@ -434,7 +434,7 @@ Available values: `wifi`, `cellular`, `ethernet` and `other`. !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备从 DHCP 租约获取的主机名。 diff --git a/docs/configuration/route/index.md b/docs/configuration/route/index.md index 01e405614e..40104b619e 100644 --- a/docs/configuration/route/index.md +++ b/docs/configuration/route/index.md @@ -40,6 +40,7 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_process": false, "find_neighbor": false, "dhcp_lease_files": [], "default_domain_resolver": "", // or {} @@ -114,17 +115,25 @@ Set routing mark by default. Takes no effect if `outbound.routing_mark` is set. +#### find_process + +!!! quote "" + + Only supported on Linux, Windows, and macOS. + +Enable process search for logging when no `process_name`, `process_path`, `package_name`, `user` or `user_id` rules exist. + #### find_neighbor !!! question "Since sing-box 1.14.0" !!! quote "" - Only supported on Linux. + Only supported on Linux and macOS. -Enable neighbor resolution for source MAC address and hostname lookup. +Enable neighbor resolution for logging when no `source_mac_address` or `source_hostname` rules exist. -Required for `source_mac_address` and `source_hostname` rule items. +See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. #### dhcp_lease_files @@ -132,7 +141,7 @@ Required for `source_mac_address` and `source_hostname` rule items. !!! quote "" - Only supported on Linux. + Only supported on Linux and macOS. Custom DHCP lease file paths for hostname and MAC address resolution. diff --git a/docs/configuration/route/index.zh.md b/docs/configuration/route/index.zh.md index 2c12a58eb3..4977b084e2 100644 --- a/docs/configuration/route/index.zh.md +++ b/docs/configuration/route/index.zh.md @@ -42,6 +42,7 @@ icon: material/alert-decagram "override_android_vpn": false, "default_interface": "", "default_mark": 0, + "find_process": false, "find_neighbor": false, "dhcp_lease_files": [], "default_network_strategy": "", @@ -113,17 +114,25 @@ icon: material/alert-decagram 如果设置了 `outbound.routing_mark` 设置,则不生效。 +#### find_process + +!!! quote "" + + 仅支持 Linux、Windows 和 macOS。 + +在没有 `process_name`、`process_path`、`package_name`、`user` 或 `user_id` 规则时启用进程搜索以输出日志。 + #### find_neighbor !!! question "自 sing-box 1.14.0 起" !!! quote "" - 仅支持 Linux。 + 仅支持 Linux 和 macOS。 -启用邻居解析以查找源 MAC 地址和主机名。 +在没有 `source_mac_address` 或 `source_hostname` 规则时启用邻居解析以输出日志。 -`source_mac_address` 和 `source_hostname` 规则项需要此选项。 +参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 #### dhcp_lease_files @@ -131,7 +140,7 @@ icon: material/alert-decagram !!! quote "" - 仅支持 Linux。 + 仅支持 Linux 和 macOS。 用于主机名和 MAC 地址解析的自定义 DHCP 租约文件路径。 diff --git a/docs/configuration/route/rule.md b/docs/configuration/route/rule.md index 16c100c1c0..37e651c924 100644 --- a/docs/configuration/route/rule.md +++ b/docs/configuration/route/rule.md @@ -466,7 +466,7 @@ Match specified outbounds' preferred routes. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device MAC address. @@ -476,7 +476,7 @@ Match source device MAC address. !!! quote "" - Only supported on Linux with `route.find_neighbor` enabled. + Only supported on Linux, macOS, or in graphical clients on Android and macOS. See [Neighbor Resolution](/configuration/shared/neighbor/) for setup. Match source device hostname from DHCP leases. diff --git a/docs/configuration/route/rule.zh.md b/docs/configuration/route/rule.zh.md index f21e6677b8..181a57398d 100644 --- a/docs/configuration/route/rule.zh.md +++ b/docs/configuration/route/rule.zh.md @@ -464,7 +464,7 @@ icon: material/new-box !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备 MAC 地址。 @@ -474,7 +474,7 @@ icon: material/new-box !!! quote "" - 仅支持 Linux,且需要 `route.find_neighbor` 已启用。 + 仅支持 Linux、macOS,或在 Android 和 macOS 图形客户端中支持。参阅 [邻居解析](/configuration/shared/neighbor/) 了解设置方法。 匹配源设备从 DHCP 租约获取的主机名。 diff --git a/docs/configuration/shared/neighbor.md b/docs/configuration/shared/neighbor.md new file mode 100644 index 0000000000..c67d995ebe --- /dev/null +++ b/docs/configuration/shared/neighbor.md @@ -0,0 +1,49 @@ +--- +icon: material/lan +--- + +# Neighbor Resolution + +Match LAN devices by MAC address and hostname using +[`source_mac_address`](/configuration/route/rule/#source_mac_address) and +[`source_hostname`](/configuration/route/rule/#source_hostname) rule items. + +Neighbor resolution is automatically enabled when these rule items exist. +Use [`route.find_neighbor`](/configuration/route/#find_neighbor) to force enable it for logging without rules. + +## Linux + +Works natively. No special setup required. + +Hostname resolution requires DHCP lease files, +automatically detected from common DHCP servers (dnsmasq, odhcpd, ISC dhcpd, Kea). +Custom paths can be set via [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files). + +## Android + +!!! quote "" + + Only supported in graphical clients. + +Requires Android 11 or above and ROOT. + +Must use [VPNHotspot](https://github.com/Mygod/VPNHotspot) to share the VPN connection. +ROM built-in features like "Use VPN for connected devices" can share VPN +but cannot provide MAC address or hostname information. + +Set **IP Masquerade Mode** to **None** in VPNHotspot settings. + +Only route/DNS rules are supported. TUN include/exclude routes are not supported. + +### Hostname Visibility + +Hostname is only visible in sing-box if it is visible in VPNHotspot. +For Apple devices, change **Private Wi-Fi Address** from **Rotating** to **Fixed** in the Wi-Fi settings +of the connected network. Non-Apple devices are always visible. + +## macOS + +Requires the standalone version (macOS system extension). +The App Store version can share the VPN as a hotspot but does not support MAC address or hostname reading. + +See [VPN Hotspot](/manual/misc/vpn-hotspot/#macos) for Internet Sharing setup. diff --git a/docs/configuration/shared/neighbor.zh.md b/docs/configuration/shared/neighbor.zh.md new file mode 100644 index 0000000000..96297fcb57 --- /dev/null +++ b/docs/configuration/shared/neighbor.zh.md @@ -0,0 +1,49 @@ +--- +icon: material/lan +--- + +# 邻居解析 + +通过 +[`source_mac_address`](/configuration/route/rule/#source_mac_address) 和 +[`source_hostname`](/configuration/route/rule/#source_hostname) 规则项匹配局域网设备的 MAC 地址和主机名。 + +当这些规则项存在时,邻居解析自动启用。 +使用 [`route.find_neighbor`](/configuration/route/#find_neighbor) 可在没有规则时强制启用以输出日志。 + +## Linux + +原生支持,无需特殊设置。 + +主机名解析需要 DHCP 租约文件, +自动从常见 DHCP 服务器(dnsmasq、odhcpd、ISC dhcpd、Kea)检测。 +可通过 [`route.dhcp_lease_files`](/configuration/route/#dhcp_lease_files) 设置自定义路径。 + +## Android + +!!! quote "" + + 仅在图形客户端中支持。 + +需要 Android 11 或以上版本和 ROOT。 + +必须使用 [VPNHotspot](https://github.com/Mygod/VPNHotspot) 共享 VPN 连接。 +ROM 自带的「通过 VPN 共享连接」等功能可以共享 VPN, +但无法提供 MAC 地址或主机名信息。 + +在 VPNHotspot 设置中将 **IP 遮掩模式** 设为 **无**。 + +仅支持路由/DNS 规则。不支持 TUN 的 include/exclude 路由。 + +### 设备可见性 + +MAC 地址和主机名仅在 VPNHotspot 中可见时 sing-box 才能读取。 +对于 Apple 设备,需要在所连接网络的 Wi-Fi 设置中将**私有无线局域网地址**从**轮替**改为**固定**。 +非 Apple 设备始终可见。 + +## macOS + +需要独立版本(macOS 系统扩展)。 +App Store 版本可以共享 VPN 热点但不支持 MAC 地址或主机名读取。 + +参阅 [VPN 热点](/manual/misc/vpn-hotspot/#macos) 了解互联网共享设置。 diff --git a/mkdocs.yml b/mkdocs.yml index e295926610..5f95842a5d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -129,6 +129,7 @@ nav: - UDP over TCP: configuration/shared/udp-over-tcp.md - TCP Brutal: configuration/shared/tcp-brutal.md - Wi-Fi State: configuration/shared/wifi-state.md + - Neighbor Resolution: configuration/shared/neighbor.md - Endpoint: - configuration/endpoint/index.md - WireGuard: configuration/endpoint/wireguard.md From 47742abe93b4a9a962c37e98ac32e78d3793359b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 11 Mar 2026 18:35:51 +0800 Subject: [PATCH 05/41] cronet-go: Update chromium to 145.0.7632.159 --- .github/CRONET_GO_VERSION | 2 +- go.mod | 4 ++-- go.sum | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/CRONET_GO_VERSION b/.github/CRONET_GO_VERSION index 47b09f9b6b..40dfcd0d14 100644 --- a/.github/CRONET_GO_VERSION +++ b/.github/CRONET_GO_VERSION @@ -1 +1 @@ -2fef65f9dba90ddb89a87d00a6eb6165487c10c1 +ea7cd33752aed62603775af3df946c1b83f4b0b3 diff --git a/go.mod b/go.mod index 405cc56444..4726fd753e 100644 --- a/go.mod +++ b/go.mod @@ -29,8 +29,8 @@ require ( github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cors v1.2.1 - github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 - github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 + github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc + github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gomobile v0.1.12 github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1 diff --git a/go.sum b/go.sum index 4d53eb40cd..a297e3a9b2 100644 --- a/go.sum +++ b/go.sum @@ -162,10 +162,10 @@ github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkk github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM= github.com/sagernet/cors v1.2.1 h1:Cv5Z8y9YSD6Gm+qSpNrL3LO4lD3eQVvbFYJSG7JCMHQ= github.com/sagernet/cors v1.2.1/go.mod h1:O64VyOjjhrkLmQIjF4KGRrJO/5dVXFdpEmCW/eISRAI= -github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9 h1:xq5Yr10jXEppD3cnGjE3WENaB6D0YsZu6KptZ8d3054= -github.com/sagernet/cronet-go v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw= -github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9 h1:uxQyy6Y/boOuecVA66tf79JgtoRGfeDJcfYZZLKVA5E= -github.com/sagernet/cronet-go/all v0.0.0-20260309102448-2fef65f9dba9/go.mod h1:Xm6cCvs0/twozC1JYNq0sVlOVmcSGzV7YON1XGcD97w= +github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc h1:YK7PwJT0irRAEui9ASdXSxcE2BOVQipWMF/A1Ogt+7c= +github.com/sagernet/cronet-go v0.0.0-20260309100020-c128886ff3fc/go.mod h1:hwFHBEjjthyEquDULbr4c4ucMedp8Drb6Jvm2kt/0Bw= +github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc h1:EJPHOqk23IuBsTjXK9OXqkNxPbKOBWKRmviQoCcriAs= +github.com/sagernet/cronet-go/all v0.0.0-20260309100020-c128886ff3fc/go.mod h1:8aty0RW96DrJSMWXO6bRPMBJEjuqq5JWiOIi4bCRzFA= github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9 h1:Qi0IKBpoPP3qZqIXuOKMsT2dv+l/MLWMyBHDMLRw2EA= github.com/sagernet/cronet-go/lib/android_386 v0.0.0-20260309101654-0cbdcfddded9/go.mod h1:XXDwdjX/T8xftoeJxQmbBoYXZp8MAPFR2CwbFuTpEtw= github.com/sagernet/cronet-go/lib/android_amd64 v0.0.0-20260309101654-0cbdcfddded9 h1:p+wCMjOhj46SpSD/AJeTGgkCcbyA76FyH631XZatyU8= From 2132e68d3a79ef6f853e994d3d624daf906411dc Mon Sep 17 00:00:00 2001 From: nekohasekai Date: Mon, 23 Mar 2026 20:04:36 +0800 Subject: [PATCH 06/41] Refactor ACME support to certificate provider --- adapter/certificate/adapter.go | 21 + adapter/certificate/manager.go | 158 +++++ adapter/certificate/registry.go | 72 ++ adapter/certificate_provider.go | 38 ++ box.go | 123 ++-- common/tls/acme.go | 37 +- common/tls/acme_logger.go | 41 ++ common/tls/reality_server.go | 4 + common/tls/std_server.go | 179 ++++- constant/proxy.go | 62 +- .../tls/acme_contstant.go => constant/tls.go | 2 +- docs/configuration/inbound/tun.md | 2 +- docs/configuration/index.md | 5 +- docs/configuration/index.zh.md | 5 +- .../shared/certificate-provider/acme.md | 150 +++++ .../shared/certificate-provider/acme.zh.md | 145 ++++ .../cloudflare-origin-ca.md | 82 +++ .../cloudflare-origin-ca.zh.md | 82 +++ .../shared/certificate-provider/index.md | 32 + .../shared/certificate-provider/index.zh.md | 32 + .../shared/certificate-provider/tailscale.md | 27 + .../certificate-provider/tailscale.zh.md | 27 + docs/configuration/shared/dns01_challenge.md | 53 ++ .../shared/dns01_challenge.zh.md | 53 ++ docs/configuration/shared/tls.md | 31 +- docs/configuration/shared/tls.zh.md | 29 +- docs/deprecated.md | 30 +- docs/deprecated.zh.md | 32 +- docs/migration.md | 77 +++ docs/migration.zh.md | 77 +++ experimental/deprecated/constants.go | 10 + experimental/libbox/config.go | 2 +- include/acme.go | 12 + include/acme_stub.go | 20 + include/registry.go | 14 +- include/tailscale.go | 5 + include/tailscale_stub.go | 7 + mkdocs.yml | 7 + option/acme.go | 106 +++ option/certificate_provider.go | 100 +++ option/options.go | 44 +- option/origin_ca.go | 76 +++ option/tailscale.go | 4 + option/tls.go | 10 +- protocol/tailscale/certificate_provider.go | 98 +++ service/acme/service.go | 411 ++++++++++++ service/acme/stub.go | 3 + service/origin_ca/service.go | 618 ++++++++++++++++++ 48 files changed, 3083 insertions(+), 172 deletions(-) create mode 100644 adapter/certificate/adapter.go create mode 100644 adapter/certificate/manager.go create mode 100644 adapter/certificate/registry.go create mode 100644 adapter/certificate_provider.go create mode 100644 common/tls/acme_logger.go rename common/tls/acme_contstant.go => constant/tls.go (69%) create mode 100644 docs/configuration/shared/certificate-provider/acme.md create mode 100644 docs/configuration/shared/certificate-provider/acme.zh.md create mode 100644 docs/configuration/shared/certificate-provider/cloudflare-origin-ca.md create mode 100644 docs/configuration/shared/certificate-provider/cloudflare-origin-ca.zh.md create mode 100644 docs/configuration/shared/certificate-provider/index.md create mode 100644 docs/configuration/shared/certificate-provider/index.zh.md create mode 100644 docs/configuration/shared/certificate-provider/tailscale.md create mode 100644 docs/configuration/shared/certificate-provider/tailscale.zh.md create mode 100644 include/acme.go create mode 100644 include/acme_stub.go create mode 100644 option/acme.go create mode 100644 option/certificate_provider.go create mode 100644 option/origin_ca.go create mode 100644 protocol/tailscale/certificate_provider.go create mode 100644 service/acme/service.go create mode 100644 service/acme/stub.go create mode 100644 service/origin_ca/service.go diff --git a/adapter/certificate/adapter.go b/adapter/certificate/adapter.go new file mode 100644 index 0000000000..802020c1e4 --- /dev/null +++ b/adapter/certificate/adapter.go @@ -0,0 +1,21 @@ +package certificate + +type Adapter struct { + providerType string + providerTag string +} + +func NewAdapter(providerType string, providerTag string) Adapter { + return Adapter{ + providerType: providerType, + providerTag: providerTag, + } +} + +func (a *Adapter) Type() string { + return a.providerType +} + +func (a *Adapter) Tag() string { + return a.providerTag +} diff --git a/adapter/certificate/manager.go b/adapter/certificate/manager.go new file mode 100644 index 0000000000..e4b9b535bb --- /dev/null +++ b/adapter/certificate/manager.go @@ -0,0 +1,158 @@ +package certificate + +import ( + "context" + "os" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" +) + +var _ adapter.CertificateProviderManager = (*Manager)(nil) + +type Manager struct { + logger log.ContextLogger + registry adapter.CertificateProviderRegistry + access sync.Mutex + started bool + stage adapter.StartStage + providers []adapter.CertificateProviderService + providerByTag map[string]adapter.CertificateProviderService +} + +func NewManager(logger log.ContextLogger, registry adapter.CertificateProviderRegistry) *Manager { + return &Manager{ + logger: logger, + registry: registry, + providerByTag: make(map[string]adapter.CertificateProviderService), + } +} + +func (m *Manager) Start(stage adapter.StartStage) error { + m.access.Lock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + providers := m.providers + m.access.Unlock() + for _, provider := range providers { + name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]" + m.logger.Trace(stage, " ", name) + startTime := time.Now() + err := adapter.LegacyStart(provider, stage) + if err != nil { + return E.Cause(err, stage, " ", name) + } + m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)") + } + return nil +} + +func (m *Manager) Close() error { + m.access.Lock() + defer m.access.Unlock() + if !m.started { + return nil + } + m.started = false + providers := m.providers + m.providers = nil + monitor := taskmonitor.New(m.logger, C.StopTimeout) + var err error + for _, provider := range providers { + name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]" + m.logger.Trace("close ", name) + startTime := time.Now() + monitor.Start("close ", name) + err = E.Append(err, provider.Close(), func(err error) error { + return E.Cause(err, "close ", name) + }) + monitor.Finish() + m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)") + } + return err +} + +func (m *Manager) CertificateProviders() []adapter.CertificateProviderService { + m.access.Lock() + defer m.access.Unlock() + return m.providers +} + +func (m *Manager) Get(tag string) (adapter.CertificateProviderService, bool) { + m.access.Lock() + provider, found := m.providerByTag[tag] + m.access.Unlock() + return provider, found +} + +func (m *Manager) Remove(tag string) error { + m.access.Lock() + provider, found := m.providerByTag[tag] + if !found { + m.access.Unlock() + return os.ErrInvalid + } + delete(m.providerByTag, tag) + index := common.Index(m.providers, func(it adapter.CertificateProviderService) bool { + return it == provider + }) + if index == -1 { + panic("invalid certificate provider index") + } + m.providers = append(m.providers[:index], m.providers[index+1:]...) + started := m.started + m.access.Unlock() + if started { + return provider.Close() + } + return nil +} + +func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error { + provider, err := m.registry.Create(ctx, logger, tag, providerType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]" + for _, stage := range adapter.ListStartStages { + m.logger.Trace(stage, " ", name) + startTime := time.Now() + err = adapter.LegacyStart(provider, stage) + if err != nil { + return E.Cause(err, stage, " ", name) + } + m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)") + } + } + if existsProvider, loaded := m.providerByTag[tag]; loaded { + if m.started { + err = existsProvider.Close() + if err != nil { + return E.Cause(err, "close certificate-provider/", existsProvider.Type(), "[", existsProvider.Tag(), "]") + } + } + existsIndex := common.Index(m.providers, func(it adapter.CertificateProviderService) bool { + return it == existsProvider + }) + if existsIndex == -1 { + panic("invalid certificate provider index") + } + m.providers = append(m.providers[:existsIndex], m.providers[existsIndex+1:]...) + } + m.providers = append(m.providers, provider) + m.providerByTag[tag] = provider + return nil +} diff --git a/adapter/certificate/registry.go b/adapter/certificate/registry.go new file mode 100644 index 0000000000..5a080f2ccc --- /dev/null +++ b/adapter/certificate/registry.go @@ -0,0 +1,72 @@ +package certificate + +import ( + "context" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.CertificateProviderService, error) + +func Register[Options any](registry *Registry, providerType string, constructor ConstructorFunc[Options]) { + registry.register(providerType, func() any { + return new(Options) + }, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.CertificateProviderService, error) { + var options *Options + if rawOptions != nil { + options = rawOptions.(*Options) + } + return constructor(ctx, logger, tag, common.PtrValueOrDefault(options)) + }) +} + +var _ adapter.CertificateProviderRegistry = (*Registry)(nil) + +type ( + optionsConstructorFunc func() any + constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.CertificateProviderService, error) +) + +type Registry struct { + access sync.Mutex + optionsType map[string]optionsConstructorFunc + constructor map[string]constructorFunc +} + +func NewRegistry() *Registry { + return &Registry{ + optionsType: make(map[string]optionsConstructorFunc), + constructor: make(map[string]constructorFunc), + } +} + +func (m *Registry) CreateOptions(providerType string) (any, bool) { + m.access.Lock() + defer m.access.Unlock() + optionsConstructor, loaded := m.optionsType[providerType] + if !loaded { + return nil, false + } + return optionsConstructor(), true +} + +func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (adapter.CertificateProviderService, error) { + m.access.Lock() + defer m.access.Unlock() + constructor, loaded := m.constructor[providerType] + if !loaded { + return nil, E.New("certificate provider type not found: " + providerType) + } + return constructor(ctx, logger, tag, options) +} + +func (m *Registry) register(providerType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { + m.access.Lock() + defer m.access.Unlock() + m.optionsType[providerType] = optionsConstructor + m.constructor[providerType] = constructor +} diff --git a/adapter/certificate_provider.go b/adapter/certificate_provider.go new file mode 100644 index 0000000000..70bdeb8838 --- /dev/null +++ b/adapter/certificate_provider.go @@ -0,0 +1,38 @@ +package adapter + +import ( + "context" + "crypto/tls" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +type CertificateProvider interface { + GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) +} + +type ACMECertificateProvider interface { + CertificateProvider + GetACMENextProtos() []string +} + +type CertificateProviderService interface { + Lifecycle + Type() string + Tag() string + CertificateProvider +} + +type CertificateProviderRegistry interface { + option.CertificateProviderOptionsRegistry + Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (CertificateProviderService, error) +} + +type CertificateProviderManager interface { + Lifecycle + CertificateProviders() []CertificateProviderService + Get(tag string) (CertificateProviderService, bool) + Remove(tag string) error + Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error +} diff --git a/box.go b/box.go index fe116b3175..a765e21d8f 100644 --- a/box.go +++ b/box.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + boxCertificate "github.com/sagernet/sing-box/adapter/certificate" "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" @@ -37,20 +38,21 @@ import ( var _ adapter.SimpleLifecycle = (*Box)(nil) type Box struct { - createdAt time.Time - logFactory log.Factory - logger log.ContextLogger - network *route.NetworkManager - endpoint *endpoint.Manager - inbound *inbound.Manager - outbound *outbound.Manager - service *boxService.Manager - dnsTransport *dns.TransportManager - dnsRouter *dns.Router - connection *route.ConnectionManager - router *route.Router - internalService []adapter.LifecycleService - done chan struct{} + createdAt time.Time + logFactory log.Factory + logger log.ContextLogger + network *route.NetworkManager + endpoint *endpoint.Manager + inbound *inbound.Manager + outbound *outbound.Manager + service *boxService.Manager + certificateProvider *boxCertificate.Manager + dnsTransport *dns.TransportManager + dnsRouter *dns.Router + connection *route.ConnectionManager + router *route.Router + internalService []adapter.LifecycleService + done chan struct{} } type Options struct { @@ -66,6 +68,7 @@ func Context( endpointRegistry adapter.EndpointRegistry, dnsTransportRegistry adapter.DNSTransportRegistry, serviceRegistry adapter.ServiceRegistry, + certificateProviderRegistry adapter.CertificateProviderRegistry, ) context.Context { if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || service.FromContext[adapter.InboundRegistry](ctx) == nil { @@ -90,6 +93,10 @@ func Context( ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry) ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry) } + if service.FromContext[adapter.CertificateProviderRegistry](ctx) == nil { + ctx = service.ContextWith[option.CertificateProviderOptionsRegistry](ctx, certificateProviderRegistry) + ctx = service.ContextWith[adapter.CertificateProviderRegistry](ctx, certificateProviderRegistry) + } return ctx } @@ -106,6 +113,7 @@ func New(options Options) (*Box, error) { outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx) serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx) + certificateProviderRegistry := service.FromContext[adapter.CertificateProviderRegistry](ctx) if endpointRegistry == nil { return nil, E.New("missing endpoint registry in context") @@ -122,6 +130,9 @@ func New(options Options) (*Box, error) { if serviceRegistry == nil { return nil, E.New("missing service registry in context") } + if certificateProviderRegistry == nil { + return nil, E.New("missing certificate provider registry in context") + } ctx = pause.WithDefaultManager(ctx) experimentalOptions := common.PtrValueOrDefault(options.Experimental) @@ -179,11 +190,13 @@ func New(options Options) (*Box, error) { outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final) serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry) + certificateProviderManager := boxCertificate.NewManager(logFactory.NewLogger("certificate-provider"), certificateProviderRegistry) service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager) service.MustRegister[adapter.ServiceManager](ctx, serviceManager) + service.MustRegister[adapter.CertificateProviderManager](ctx, certificateProviderManager) dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions) service.MustRegister[adapter.DNSRouter](ctx, dnsRouter) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions, dnsOptions) @@ -272,6 +285,24 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize inbound[", i, "]") } } + for i, serviceOptions := range options.Services { + var tag string + if serviceOptions.Tag != "" { + tag = serviceOptions.Tag + } else { + tag = F.ToString(i) + } + err = serviceManager.Create( + ctx, + logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")), + tag, + serviceOptions.Type, + serviceOptions.Options, + ) + if err != nil { + return nil, E.Cause(err, "initialize service[", i, "]") + } + } for i, outboundOptions := range options.Outbounds { var tag string if outboundOptions.Tag != "" { @@ -298,22 +329,22 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize outbound[", i, "]") } } - for i, serviceOptions := range options.Services { + for i, certificateProviderOptions := range options.CertificateProviders { var tag string - if serviceOptions.Tag != "" { - tag = serviceOptions.Tag + if certificateProviderOptions.Tag != "" { + tag = certificateProviderOptions.Tag } else { tag = F.ToString(i) } - err = serviceManager.Create( + err = certificateProviderManager.Create( ctx, - logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")), + logFactory.NewLogger(F.ToString("certificate-provider/", certificateProviderOptions.Type, "[", tag, "]")), tag, - serviceOptions.Type, - serviceOptions.Options, + certificateProviderOptions.Type, + certificateProviderOptions.Options, ) if err != nil { - return nil, E.Cause(err, "initialize service[", i, "]") + return nil, E.Cause(err, "initialize certificate provider[", i, "]") } } outboundManager.Initialize(func() (adapter.Outbound, error) { @@ -383,20 +414,21 @@ func New(options Options) (*Box, error) { internalServices = append(internalServices, adapter.NewLifecycleService(ntpService, "ntp service")) } return &Box{ - network: networkManager, - endpoint: endpointManager, - inbound: inboundManager, - outbound: outboundManager, - dnsTransport: dnsTransportManager, - service: serviceManager, - dnsRouter: dnsRouter, - connection: connectionManager, - router: router, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - internalService: internalServices, - done: make(chan struct{}), + network: networkManager, + endpoint: endpointManager, + inbound: inboundManager, + outbound: outboundManager, + dnsTransport: dnsTransportManager, + service: serviceManager, + certificateProvider: certificateProviderManager, + dnsRouter: dnsRouter, + connection: connectionManager, + router: router, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + internalService: internalServices, + done: make(chan struct{}), }, nil } @@ -450,7 +482,7 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service, s.certificateProvider) if err != nil { return err } @@ -470,11 +502,19 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateStart, s.endpoint) if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateStart, s.certificateProvider) + if err != nil { + return err + } + err = adapter.Start(s.logger, adapter.StartStateStart, s.inbound, s.service) + if err != nil { + return err + } + err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.endpoint, s.certificateProvider, s.inbound, s.service) if err != nil { return err } @@ -482,7 +522,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.endpoint, s.certificateProvider, s.inbound, s.service) if err != nil { return err } @@ -506,8 +546,9 @@ func (s *Box) Close() error { service adapter.Lifecycle }{ {"service", s.service}, - {"endpoint", s.endpoint}, {"inbound", s.inbound}, + {"certificate-provider", s.certificateProvider}, + {"endpoint", s.endpoint}, {"outbound", s.outbound}, {"router", s.router}, {"connection", s.connection}, diff --git a/common/tls/acme.go b/common/tls/acme.go index c96e002c8a..d576fc6b1e 100644 --- a/common/tls/acme.go +++ b/common/tls/acme.go @@ -38,37 +38,6 @@ func (w *acmeWrapper) Close() error { return nil } -type acmeLogWriter struct { - logger logger.Logger -} - -func (w *acmeLogWriter) Write(p []byte) (n int, err error) { - logLine := strings.ReplaceAll(string(p), " ", ": ") - switch { - case strings.HasPrefix(logLine, "error: "): - w.logger.Error(logLine[7:]) - case strings.HasPrefix(logLine, "warn: "): - w.logger.Warn(logLine[6:]) - case strings.HasPrefix(logLine, "info: "): - w.logger.Info(logLine[6:]) - case strings.HasPrefix(logLine, "debug: "): - w.logger.Debug(logLine[7:]) - default: - w.logger.Debug(logLine) - } - return len(p), nil -} - -func (w *acmeLogWriter) Sync() error { - return nil -} - -func encoderConfig() zapcore.EncoderConfig { - config := zap.NewProductionEncoderConfig() - config.TimeKey = zapcore.OmitKey - return config -} - func startACME(ctx context.Context, logger logger.Logger, options option.InboundACMEOptions) (*tls.Config, adapter.SimpleLifecycle, error) { var acmeServer string switch options.Provider { @@ -91,8 +60,8 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound storage = certmagic.Default.Storage } zapLogger := zap.New(zapcore.NewCore( - zapcore.NewConsoleEncoder(encoderConfig()), - &acmeLogWriter{logger: logger}, + zapcore.NewConsoleEncoder(ACMEEncoderConfig()), + &ACMELogWriter{Logger: logger}, zap.DebugLevel, )) config := &certmagic.Config{ @@ -158,7 +127,7 @@ func startACME(ctx context.Context, logger logger.Logger, options option.Inbound } else { tlsConfig = &tls.Config{ GetCertificate: config.GetCertificate, - NextProtos: []string{ACMETLS1Protocol}, + NextProtos: []string{C.ACMETLS1Protocol}, } } return tlsConfig, &acmeWrapper{ctx: ctx, cfg: config, cache: cache, domain: options.Domain}, nil diff --git a/common/tls/acme_logger.go b/common/tls/acme_logger.go new file mode 100644 index 0000000000..cb3a1e3ce3 --- /dev/null +++ b/common/tls/acme_logger.go @@ -0,0 +1,41 @@ +package tls + +import ( + "strings" + + "github.com/sagernet/sing/common/logger" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type ACMELogWriter struct { + Logger logger.Logger +} + +func (w *ACMELogWriter) Write(p []byte) (n int, err error) { + logLine := strings.ReplaceAll(string(p), " ", ": ") + switch { + case strings.HasPrefix(logLine, "error: "): + w.Logger.Error(logLine[7:]) + case strings.HasPrefix(logLine, "warn: "): + w.Logger.Warn(logLine[6:]) + case strings.HasPrefix(logLine, "info: "): + w.Logger.Info(logLine[6:]) + case strings.HasPrefix(logLine, "debug: "): + w.Logger.Debug(logLine[7:]) + default: + w.Logger.Debug(logLine) + } + return len(p), nil +} + +func (w *ACMELogWriter) Sync() error { + return nil +} + +func ACMEEncoderConfig() zapcore.EncoderConfig { + config := zap.NewProductionEncoderConfig() + config.TimeKey = zapcore.OmitKey + return config +} diff --git a/common/tls/reality_server.go b/common/tls/reality_server.go index 5fc684756b..c2e70733a3 100644 --- a/common/tls/reality_server.go +++ b/common/tls/reality_server.go @@ -32,6 +32,10 @@ type RealityServerConfig struct { func NewRealityServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) { var tlsConfig utls.RealityConfig + if options.CertificateProvider != nil { + return nil, E.New("certificate_provider is unavailable in reality") + } + //nolint:staticcheck if options.ACME != nil && len(options.ACME.Domain) > 0 { return nil, E.New("acme is unavailable in reality") } diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 760c4b3a7f..86584cd482 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -13,19 +13,87 @@ import ( "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/service" ) var errInsecureUnused = E.New("tls: insecure unused") +type managedCertificateProvider interface { + adapter.CertificateProvider + adapter.SimpleLifecycle +} + +type sharedCertificateProvider struct { + tag string + manager adapter.CertificateProviderManager + provider adapter.CertificateProviderService +} + +func (p *sharedCertificateProvider) Start() error { + provider, found := p.manager.Get(p.tag) + if !found { + return E.New("certificate provider not found: ", p.tag) + } + p.provider = provider + return nil +} + +func (p *sharedCertificateProvider) Close() error { + return nil +} + +func (p *sharedCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return p.provider.GetCertificate(hello) +} + +func (p *sharedCertificateProvider) GetACMENextProtos() []string { + return getACMENextProtos(p.provider) +} + +type inlineCertificateProvider struct { + provider adapter.CertificateProviderService +} + +func (p *inlineCertificateProvider) Start() error { + for _, stage := range adapter.ListStartStages { + err := adapter.LegacyStart(p.provider, stage) + if err != nil { + return err + } + } + return nil +} + +func (p *inlineCertificateProvider) Close() error { + return p.provider.Close() +} + +func (p *inlineCertificateProvider) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return p.provider.GetCertificate(hello) +} + +func (p *inlineCertificateProvider) GetACMENextProtos() []string { + return getACMENextProtos(p.provider) +} + +func getACMENextProtos(provider adapter.CertificateProvider) []string { + if acmeProvider, isACME := provider.(adapter.ACMECertificateProvider); isACME { + return acmeProvider.GetACMENextProtos() + } + return nil +} + type STDServerConfig struct { access sync.RWMutex config *tls.Config logger log.Logger + certificateProvider managedCertificateProvider acmeService adapter.SimpleLifecycle certificate []byte key []byte @@ -53,18 +121,17 @@ func (c *STDServerConfig) SetServerName(serverName string) { func (c *STDServerConfig) NextProtos() []string { c.access.RLock() defer c.access.RUnlock() - if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { + if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol { return c.config.NextProtos[1:] - } else { - return c.config.NextProtos } + return c.config.NextProtos } func (c *STDServerConfig) SetNextProtos(nextProto []string) { c.access.Lock() defer c.access.Unlock() config := c.config.Clone() - if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { + if c.hasACMEALPN() && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == C.ACMETLS1Protocol { config.NextProtos = append(c.config.NextProtos[:1], nextProto...) } else { config.NextProtos = nextProto @@ -72,6 +139,18 @@ func (c *STDServerConfig) SetNextProtos(nextProto []string) { c.config = config } +func (c *STDServerConfig) hasACMEALPN() bool { + if c.acmeService != nil { + return true + } + if c.certificateProvider != nil { + if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME { + return len(acmeProvider.GetACMENextProtos()) > 0 + } + } + return false +} + func (c *STDServerConfig) STDConfig() (*STDConfig, error) { return c.config, nil } @@ -91,15 +170,39 @@ func (c *STDServerConfig) Clone() Config { } func (c *STDServerConfig) Start() error { + if c.certificateProvider != nil { + err := c.certificateProvider.Start() + if err != nil { + return err + } + if acmeProvider, isACME := c.certificateProvider.(adapter.ACMECertificateProvider); isACME { + nextProtos := acmeProvider.GetACMENextProtos() + if len(nextProtos) > 0 { + c.access.Lock() + config := c.config.Clone() + mergedNextProtos := append([]string{}, nextProtos...) + for _, nextProto := range config.NextProtos { + if !common.Contains(mergedNextProtos, nextProto) { + mergedNextProtos = append(mergedNextProtos, nextProto) + } + } + config.NextProtos = mergedNextProtos + c.config = config + c.access.Unlock() + } + } + } if c.acmeService != nil { - return c.acmeService.Start() - } else { - err := c.startWatcher() + err := c.acmeService.Start() if err != nil { - c.logger.Warn("create fsnotify watcher: ", err) + return err } - return nil } + err := c.startWatcher() + if err != nil { + c.logger.Warn("create fsnotify watcher: ", err) + } + return nil } func (c *STDServerConfig) startWatcher() error { @@ -203,23 +306,34 @@ func (c *STDServerConfig) certificateUpdated(path string) error { } func (c *STDServerConfig) Close() error { - if c.acmeService != nil { - return c.acmeService.Close() - } - if c.watcher != nil { - return c.watcher.Close() - } - return nil + return common.Close(c.certificateProvider, c.acmeService, c.watcher) } func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.InboundTLSOptions) (ServerConfig, error) { if !options.Enabled { return nil, nil } + //nolint:staticcheck + if options.CertificateProvider != nil && options.ACME != nil { + return nil, E.New("certificate_provider and acme are mutually exclusive") + } var tlsConfig *tls.Config + var certificateProvider managedCertificateProvider var acmeService adapter.SimpleLifecycle var err error - if options.ACME != nil && len(options.ACME.Domain) > 0 { + if options.CertificateProvider != nil { + certificateProvider, err = newCertificateProvider(ctx, logger, options.CertificateProvider) + if err != nil { + return nil, err + } + tlsConfig = &tls.Config{ + GetCertificate: certificateProvider.GetCertificate, + } + if options.Insecure { + return nil, errInsecureUnused + } + } else if options.ACME != nil && len(options.ACME.Domain) > 0 { //nolint:staticcheck + deprecated.Report(ctx, deprecated.OptionInlineACME) //nolint:staticcheck tlsConfig, acmeService, err = startACME(ctx, logger, common.PtrValueOrDefault(options.ACME)) if err != nil { @@ -272,7 +386,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option. certificate []byte key []byte ) - if acmeService == nil { + if certificateProvider == nil && acmeService == nil { if len(options.Certificate) > 0 { certificate = []byte(strings.Join(options.Certificate, "\n")) } else if options.CertificatePath != "" { @@ -360,6 +474,7 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option. serverConfig := &STDServerConfig{ config: tlsConfig, logger: logger, + certificateProvider: certificateProvider, acmeService: acmeService, certificate: certificate, key: key, @@ -369,8 +484,8 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option. echKeyPath: echKeyPath, } serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - serverConfig.access.Lock() - defer serverConfig.access.Unlock() + serverConfig.access.RLock() + defer serverConfig.access.RUnlock() return serverConfig.config, nil } var config ServerConfig = serverConfig @@ -387,3 +502,27 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option. } return config, nil } + +func newCertificateProvider(ctx context.Context, logger log.ContextLogger, options *option.CertificateProviderOptions) (managedCertificateProvider, error) { + if options.IsShared() { + manager := service.FromContext[adapter.CertificateProviderManager](ctx) + if manager == nil { + return nil, E.New("missing certificate provider manager in context") + } + return &sharedCertificateProvider{ + tag: options.Tag, + manager: manager, + }, nil + } + registry := service.FromContext[adapter.CertificateProviderRegistry](ctx) + if registry == nil { + return nil, E.New("missing certificate provider registry in context") + } + provider, err := registry.Create(ctx, logger, "", options.Type, options.Options) + if err != nil { + return nil, E.Cause(err, "create inline certificate provider") + } + return &inlineCertificateProvider{ + provider: provider, + }, nil +} diff --git a/constant/proxy.go b/constant/proxy.go index 278a46c2f6..add66c95e5 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -1,36 +1,38 @@ package constant const ( - TypeTun = "tun" - TypeRedirect = "redirect" - TypeTProxy = "tproxy" - TypeDirect = "direct" - TypeBlock = "block" - TypeDNS = "dns" - TypeSOCKS = "socks" - TypeHTTP = "http" - TypeMixed = "mixed" - TypeShadowsocks = "shadowsocks" - TypeVMess = "vmess" - TypeTrojan = "trojan" - TypeNaive = "naive" - TypeWireGuard = "wireguard" - TypeHysteria = "hysteria" - TypeTor = "tor" - TypeSSH = "ssh" - TypeShadowTLS = "shadowtls" - TypeAnyTLS = "anytls" - TypeShadowsocksR = "shadowsocksr" - TypeVLESS = "vless" - TypeTUIC = "tuic" - TypeHysteria2 = "hysteria2" - TypeTailscale = "tailscale" - TypeDERP = "derp" - TypeResolved = "resolved" - TypeSSMAPI = "ssm-api" - TypeCCM = "ccm" - TypeOCM = "ocm" - TypeOOMKiller = "oom-killer" + TypeTun = "tun" + TypeRedirect = "redirect" + TypeTProxy = "tproxy" + TypeDirect = "direct" + TypeBlock = "block" + TypeDNS = "dns" + TypeSOCKS = "socks" + TypeHTTP = "http" + TypeMixed = "mixed" + TypeShadowsocks = "shadowsocks" + TypeVMess = "vmess" + TypeTrojan = "trojan" + TypeNaive = "naive" + TypeWireGuard = "wireguard" + TypeHysteria = "hysteria" + TypeTor = "tor" + TypeSSH = "ssh" + TypeShadowTLS = "shadowtls" + TypeAnyTLS = "anytls" + TypeShadowsocksR = "shadowsocksr" + TypeVLESS = "vless" + TypeTUIC = "tuic" + TypeHysteria2 = "hysteria2" + TypeTailscale = "tailscale" + TypeDERP = "derp" + TypeResolved = "resolved" + TypeSSMAPI = "ssm-api" + TypeCCM = "ccm" + TypeOCM = "ocm" + TypeOOMKiller = "oom-killer" + TypeACME = "acme" + TypeCloudflareOriginCA = "cloudflare-origin-ca" ) const ( diff --git a/common/tls/acme_contstant.go b/constant/tls.go similarity index 69% rename from common/tls/acme_contstant.go rename to constant/tls.go index c5cd2ff164..2d4f64bc3a 100644 --- a/common/tls/acme_contstant.go +++ b/constant/tls.go @@ -1,3 +1,3 @@ -package tls +package constant const ACMETLS1Protocol = "acme-tls/1" diff --git a/docs/configuration/inbound/tun.md b/docs/configuration/inbound/tun.md index 5a2f58d3db..6dae06e18a 100644 --- a/docs/configuration/inbound/tun.md +++ b/docs/configuration/inbound/tun.md @@ -4,7 +4,7 @@ icon: material/new-box !!! quote "Changes in sing-box 1.14.0" - :material-plus: [include_mac_address](#include_mac_address) + :material-plus: [include_mac_address](#include_mac_address) :material-plus: [exclude_mac_address](#exclude_mac_address) !!! quote "Changes in sing-box 1.13.3" diff --git a/docs/configuration/index.md b/docs/configuration/index.md index 1f6eec1375..81cb8f3863 100644 --- a/docs/configuration/index.md +++ b/docs/configuration/index.md @@ -1,7 +1,6 @@ # Introduction sing-box uses JSON for configuration files. - ### Structure ```json @@ -10,6 +9,7 @@ sing-box uses JSON for configuration files. "dns": {}, "ntp": {}, "certificate": {}, + "certificate_providers": [], "endpoints": [], "inbounds": [], "outbounds": [], @@ -27,6 +27,7 @@ sing-box uses JSON for configuration files. | `dns` | [DNS](./dns/) | | `ntp` | [NTP](./ntp/) | | `certificate` | [Certificate](./certificate/) | +| `certificate_providers` | [Certificate Provider](./shared/certificate-provider/) | | `endpoints` | [Endpoint](./endpoint/) | | `inbounds` | [Inbound](./inbound/) | | `outbounds` | [Outbound](./outbound/) | @@ -50,4 +51,4 @@ sing-box format -w -c config.json -D config_directory ```bash sing-box merge output.json -c config.json -D config_directory -``` \ No newline at end of file +``` diff --git a/docs/configuration/index.zh.md b/docs/configuration/index.zh.md index 3bdc352187..350db5d4c4 100644 --- a/docs/configuration/index.zh.md +++ b/docs/configuration/index.zh.md @@ -1,7 +1,6 @@ # 引言 sing-box 使用 JSON 作为配置文件格式。 - ### 结构 ```json @@ -10,6 +9,7 @@ sing-box 使用 JSON 作为配置文件格式。 "dns": {}, "ntp": {}, "certificate": {}, + "certificate_providers": [], "endpoints": [], "inbounds": [], "outbounds": [], @@ -27,6 +27,7 @@ sing-box 使用 JSON 作为配置文件格式。 | `dns` | [DNS](./dns/) | | `ntp` | [NTP](./ntp/) | | `certificate` | [证书](./certificate/) | +| `certificate_providers` | [证书提供者](./shared/certificate-provider/) | | `endpoints` | [端点](./endpoint/) | | `inbounds` | [入站](./inbound/) | | `outbounds` | [出站](./outbound/) | @@ -50,4 +51,4 @@ sing-box format -w -c config.json -D config_directory ```bash sing-box merge output.json -c config.json -D config_directory -``` \ No newline at end of file +``` diff --git a/docs/configuration/shared/certificate-provider/acme.md b/docs/configuration/shared/certificate-provider/acme.md new file mode 100644 index 0000000000..440ed1568d --- /dev/null +++ b/docs/configuration/shared/certificate-provider/acme.md @@ -0,0 +1,150 @@ +--- +icon: material/new-box +--- + +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [account_key](#account_key) + :material-plus: [key_type](#key_type) + :material-plus: [detour](#detour) + +# ACME + +!!! quote "" + + `with_acme` build tag required. + +### Structure + +```json +{ + "type": "acme", + "tag": "", + + "domain": [], + "data_directory": "", + "default_server_name": "", + "email": "", + "provider": "", + "account_key": "", + "disable_http_challenge": false, + "disable_tls_alpn_challenge": false, + "alternative_http_port": 0, + "alternative_tls_port": 0, + "external_account": { + "key_id": "", + "mac_key": "" + }, + "dns01_challenge": {}, + "key_type": "", + "detour": "" +} +``` + +### Fields + +#### domain + +==Required== + +List of domains. + +#### data_directory + +The directory to store ACME data. + +`$XDG_DATA_HOME/certmagic|$HOME/.local/share/certmagic` will be used if empty. + +#### default_server_name + +Server name to use when choosing a certificate if the ClientHello's ServerName field is empty. + +#### email + +The email address to use when creating or selecting an existing ACME server account. + +#### provider + +The ACME CA provider to use. + +| Value | Provider | +|-------------------------|---------------| +| `letsencrypt (default)` | Let's Encrypt | +| `zerossl` | ZeroSSL | +| `https://...` | Custom | + +When `provider` is `zerossl`, sing-box will automatically request ZeroSSL EAB credentials if `email` is set and +`external_account` is empty. + +When `provider` is `zerossl`, at least one of `external_account`, `email`, or `account_key` is required. + +#### account_key + +!!! question "Since sing-box 1.14.0" + +The PEM-encoded private key of an existing ACME account. + +#### disable_http_challenge + +Disable all HTTP challenges. + +#### disable_tls_alpn_challenge + +Disable all TLS-ALPN challenges + +#### alternative_http_port + +The alternate port to use for the ACME HTTP challenge; if non-empty, this port will be used instead of 80 to spin up a +listener for the HTTP challenge. + +#### alternative_tls_port + +The alternate port to use for the ACME TLS-ALPN challenge; the system must forward 443 to this port for challenge to +succeed. + +#### external_account + +EAB (External Account Binding) contains information necessary to bind or map an ACME account to some other account known +by the CA. + +External account bindings are used to associate an ACME account with an existing account in a non-ACME system, such as +a CA customer database. + +To enable ACME account binding, the CA operating the ACME server needs to provide the ACME client with a MAC key and a +key identifier, using some mechanism outside of ACME. §7.3.4 + +#### external_account.key_id + +The key identifier. + +#### external_account.mac_key + +The MAC key. + +#### dns01_challenge + +ACME DNS01 challenge field. If configured, other challenge methods will be disabled. + +See [DNS01 Challenge Fields](/configuration/shared/dns01_challenge/) for details. + +#### key_type + +!!! question "Since sing-box 1.14.0" + +The private key type to generate for new certificates. + +| Value | Type | +|------------|---------| +| `ed25519` | Ed25519 | +| `p256` | P-256 | +| `p384` | P-384 | +| `rsa2048` | RSA | +| `rsa4096` | RSA | + +#### detour + +!!! question "Since sing-box 1.14.0" + +The tag of the upstream outbound. + +All provider HTTP requests will use this outbound. diff --git a/docs/configuration/shared/certificate-provider/acme.zh.md b/docs/configuration/shared/certificate-provider/acme.zh.md new file mode 100644 index 0000000000..d95930a550 --- /dev/null +++ b/docs/configuration/shared/certificate-provider/acme.zh.md @@ -0,0 +1,145 @@ +--- +icon: material/new-box +--- + +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [account_key](#account_key) + :material-plus: [key_type](#key_type) + :material-plus: [detour](#detour) + +# ACME + +!!! quote "" + + 需要 `with_acme` 构建标签。 + +### 结构 + +```json +{ + "type": "acme", + "tag": "", + + "domain": [], + "data_directory": "", + "default_server_name": "", + "email": "", + "provider": "", + "account_key": "", + "disable_http_challenge": false, + "disable_tls_alpn_challenge": false, + "alternative_http_port": 0, + "alternative_tls_port": 0, + "external_account": { + "key_id": "", + "mac_key": "" + }, + "dns01_challenge": {}, + "key_type": "", + "detour": "" +} +``` + +### 字段 + +#### domain + +==必填== + +域名列表。 + +#### data_directory + +ACME 数据存储目录。 + +如果为空则使用 `$XDG_DATA_HOME/certmagic|$HOME/.local/share/certmagic`。 + +#### default_server_name + +如果 ClientHello 的 ServerName 字段为空,则选择证书时要使用的服务器名称。 + +#### email + +创建或选择现有 ACME 服务器帐户时使用的电子邮件地址。 + +#### provider + +要使用的 ACME CA 提供商。 + +| 值 | 提供商 | +|--------------------|---------------| +| `letsencrypt (默认)` | Let's Encrypt | +| `zerossl` | ZeroSSL | +| `https://...` | 自定义 | + +当 `provider` 为 `zerossl` 时,如果设置了 `email` 且未设置 `external_account`, +sing-box 会自动向 ZeroSSL 请求 EAB 凭据。 + +当 `provider` 为 `zerossl` 时,必须至少设置 `external_account`、`email` 或 `account_key` 之一。 + +#### account_key + +!!! question "自 sing-box 1.14.0 起" + +现有 ACME 帐户的 PEM 编码私钥。 + +#### disable_http_challenge + +禁用所有 HTTP 质询。 + +#### disable_tls_alpn_challenge + +禁用所有 TLS-ALPN 质询。 + +#### alternative_http_port + +用于 ACME HTTP 质询的备用端口;如果非空,将使用此端口而不是 80 来启动 HTTP 质询的侦听器。 + +#### alternative_tls_port + +用于 ACME TLS-ALPN 质询的备用端口; 系统必须将 443 转发到此端口以使质询成功。 + +#### external_account + +EAB(外部帐户绑定)包含将 ACME 帐户绑定或映射到 CA 已知的其他帐户所需的信息。 + +外部帐户绑定用于将 ACME 帐户与非 ACME 系统中的现有帐户相关联,例如 CA 客户数据库。 + +为了启用 ACME 帐户绑定,运行 ACME 服务器的 CA 需要使用 ACME 之外的某种机制向 ACME 客户端提供 MAC 密钥和密钥标识符。§7.3.4 + +#### external_account.key_id + +密钥标识符。 + +#### external_account.mac_key + +MAC 密钥。 + +#### dns01_challenge + +ACME DNS01 质询字段。如果配置,将禁用其他质询方法。 + +参阅 [DNS01 质询字段](/zh/configuration/shared/dns01_challenge/)。 + +#### key_type + +!!! question "自 sing-box 1.14.0 起" + +为新证书生成的私钥类型。 + +| 值 | 类型 | +|-----------|----------| +| `ed25519` | Ed25519 | +| `p256` | P-256 | +| `p384` | P-384 | +| `rsa2048` | RSA | +| `rsa4096` | RSA | + +#### detour + +!!! question "自 sing-box 1.14.0 起" + +上游出站的标签。 + +所有提供者 HTTP 请求将使用此出站。 diff --git a/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.md b/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.md new file mode 100644 index 0000000000..cfd2da4fe1 --- /dev/null +++ b/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.md @@ -0,0 +1,82 @@ +--- +icon: material/new-box +--- + +!!! question "Since sing-box 1.14.0" + +# Cloudflare Origin CA + +### Structure + +```json +{ + "type": "cloudflare-origin-ca", + "tag": "", + + "domain": [], + "data_directory": "", + "api_token": "", + "origin_ca_key": "", + "request_type": "", + "requested_validity": 0, + "detour": "" +} +``` + +### Fields + +#### domain + +==Required== + +List of domain names or wildcard domain names to include in the certificate. + +#### data_directory + +Root directory used to store the issued certificate, private key, and metadata. + +If empty, sing-box uses the same default data directory as the ACME certificate provider: +`$XDG_DATA_HOME/certmagic` or `$HOME/.local/share/certmagic`. + +#### api_token + +Cloudflare API token used to create the certificate. + +Get or create one in [Cloudflare Dashboard > My Profile > API Tokens](https://dash.cloudflare.com/profile/api-tokens). + +Requires the `Zone / SSL and Certificates / Edit` permission. + +Conflict with `origin_ca_key`. + +#### origin_ca_key + +Cloudflare Origin CA Key. + +Get it in [Cloudflare Dashboard > My Profile > API Tokens > API Keys > Origin CA Key](https://dash.cloudflare.com/profile/api-tokens). + +Conflict with `api_token`. + +#### request_type + +The signature type to request from Cloudflare. + +| Value | Type | +|----------------------|-------------| +| `origin-rsa` | RSA | +| `origin-ecc` | ECDSA P-256 | + +`origin-rsa` is used if empty. + +#### requested_validity + +The requested certificate validity in days. + +Available values: `7`, `30`, `90`, `365`, `730`, `1095`, `5475`. + +`5475` days (15 years) is used if empty. + +#### detour + +The tag of the upstream outbound. + +All provider HTTP requests will use this outbound. diff --git a/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.zh.md b/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.zh.md new file mode 100644 index 0000000000..85036268df --- /dev/null +++ b/docs/configuration/shared/certificate-provider/cloudflare-origin-ca.zh.md @@ -0,0 +1,82 @@ +--- +icon: material/new-box +--- + +!!! question "自 sing-box 1.14.0 起" + +# Cloudflare Origin CA + +### 结构 + +```json +{ + "type": "cloudflare-origin-ca", + "tag": "", + + "domain": [], + "data_directory": "", + "api_token": "", + "origin_ca_key": "", + "request_type": "", + "requested_validity": 0, + "detour": "" +} +``` + +### 字段 + +#### domain + +==必填== + +要写入证书的域名或通配符域名列表。 + +#### data_directory + +保存签发证书、私钥和元数据的根目录。 + +如果为空,sing-box 会使用与 ACME 证书提供者相同的默认数据目录: +`$XDG_DATA_HOME/certmagic` 或 `$HOME/.local/share/certmagic`。 + +#### api_token + +用于创建证书的 Cloudflare API Token。 + +可在 [Cloudflare Dashboard > My Profile > API Tokens](https://dash.cloudflare.com/profile/api-tokens) 获取或创建。 + +需要 `Zone / SSL and Certificates / Edit` 权限。 + +与 `origin_ca_key` 冲突。 + +#### origin_ca_key + +Cloudflare Origin CA Key。 + +可在 [Cloudflare Dashboard > My Profile > API Tokens > API Keys > Origin CA Key](https://dash.cloudflare.com/profile/api-tokens) 获取。 + +与 `api_token` 冲突。 + +#### request_type + +向 Cloudflare 请求的签名类型。 + +| 值 | 类型 | +|----------------------|-------------| +| `origin-rsa` | RSA | +| `origin-ecc` | ECDSA P-256 | + +如果为空,使用 `origin-rsa`。 + +#### requested_validity + +请求的证书有效期,单位为天。 + +可用值:`7`、`30`、`90`、`365`、`730`、`1095`、`5475`。 + +如果为空,使用 `5475` 天(15 年)。 + +#### detour + +上游出站的标签。 + +所有提供者 HTTP 请求将使用此出站。 diff --git a/docs/configuration/shared/certificate-provider/index.md b/docs/configuration/shared/certificate-provider/index.md new file mode 100644 index 0000000000..c493550aaa --- /dev/null +++ b/docs/configuration/shared/certificate-provider/index.md @@ -0,0 +1,32 @@ +--- +icon: material/new-box +--- + +!!! question "Since sing-box 1.14.0" + +# Certificate Provider + +### Structure + +```json +{ + "certificate_providers": [ + { + "type": "", + "tag": "" + } + ] +} +``` + +### Fields + +| Type | Format | +|--------|------------------| +| `acme` | [ACME](/configuration/shared/certificate-provider/acme) | +| `tailscale` | [Tailscale](/configuration/shared/certificate-provider/tailscale) | +| `cloudflare-origin-ca` | [Cloudflare Origin CA](/configuration/shared/certificate-provider/cloudflare-origin-ca) | + +#### tag + +The tag of the certificate provider. diff --git a/docs/configuration/shared/certificate-provider/index.zh.md b/docs/configuration/shared/certificate-provider/index.zh.md new file mode 100644 index 0000000000..2df4b36387 --- /dev/null +++ b/docs/configuration/shared/certificate-provider/index.zh.md @@ -0,0 +1,32 @@ +--- +icon: material/new-box +--- + +!!! question "自 sing-box 1.14.0 起" + +# 证书提供者 + +### 结构 + +```json +{ + "certificate_providers": [ + { + "type": "", + "tag": "" + } + ] +} +``` + +### 字段 + +| 类型 | 格式 | +|--------|------------------| +| `acme` | [ACME](/zh/configuration/shared/certificate-provider/acme) | +| `tailscale` | [Tailscale](/zh/configuration/shared/certificate-provider/tailscale) | +| `cloudflare-origin-ca` | [Cloudflare Origin CA](/zh/configuration/shared/certificate-provider/cloudflare-origin-ca) | + +#### tag + +证书提供者的标签。 diff --git a/docs/configuration/shared/certificate-provider/tailscale.md b/docs/configuration/shared/certificate-provider/tailscale.md new file mode 100644 index 0000000000..045f2c5ec5 --- /dev/null +++ b/docs/configuration/shared/certificate-provider/tailscale.md @@ -0,0 +1,27 @@ +--- +icon: material/new-box +--- + +!!! question "Since sing-box 1.14.0" + +# Tailscale + +### Structure + +```json +{ + "type": "tailscale", + "tag": "ts-cert", + "endpoint": "ts-ep" +} +``` + +### Fields + +#### endpoint + +==Required== + +The tag of the [Tailscale endpoint](/configuration/endpoint/tailscale/) to reuse. + +[MagicDNS and HTTPS](https://tailscale.com/kb/1153/enabling-https) must be enabled in the Tailscale admin console. diff --git a/docs/configuration/shared/certificate-provider/tailscale.zh.md b/docs/configuration/shared/certificate-provider/tailscale.zh.md new file mode 100644 index 0000000000..1987da5084 --- /dev/null +++ b/docs/configuration/shared/certificate-provider/tailscale.zh.md @@ -0,0 +1,27 @@ +--- +icon: material/new-box +--- + +!!! question "自 sing-box 1.14.0 起" + +# Tailscale + +### 结构 + +```json +{ + "type": "tailscale", + "tag": "ts-cert", + "endpoint": "ts-ep" +} +``` + +### 字段 + +#### endpoint + +==必填== + +要复用的 [Tailscale 端点](/zh/configuration/endpoint/tailscale/) 的标签。 + +必须在 Tailscale 管理控制台中启用 [MagicDNS 和 HTTPS](https://tailscale.com/kb/1153/enabling-https)。 diff --git a/docs/configuration/shared/dns01_challenge.md b/docs/configuration/shared/dns01_challenge.md index 8bdbfc97a7..0157cb4596 100644 --- a/docs/configuration/shared/dns01_challenge.md +++ b/docs/configuration/shared/dns01_challenge.md @@ -2,6 +2,14 @@ icon: material/new-box --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [ttl](#ttl) + :material-plus: [propagation_delay](#propagation_delay) + :material-plus: [propagation_timeout](#propagation_timeout) + :material-plus: [resolvers](#resolvers) + :material-plus: [override_domain](#override_domain) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [alidns.security_token](#security_token) @@ -12,12 +20,57 @@ icon: material/new-box ```json { + "ttl": "", + "propagation_delay": "", + "propagation_timeout": "", + "resolvers": [], + "override_domain": "", "provider": "", ... // Provider Fields } ``` +### Fields + +#### ttl + +!!! question "Since sing-box 1.14.0" + +The TTL of the temporary TXT record used for the DNS challenge. + +#### propagation_delay + +!!! question "Since sing-box 1.14.0" + +How long to wait after creating the challenge record before starting propagation checks. + +#### propagation_timeout + +!!! question "Since sing-box 1.14.0" + +The maximum time to wait for the challenge record to propagate. + +Set to `-1` to disable propagation checks. + +#### resolvers + +!!! question "Since sing-box 1.14.0" + +Preferred DNS resolvers to use for DNS propagation checks. + +#### override_domain + +!!! question "Since sing-box 1.14.0" + +Override the domain name used for the DNS challenge record. + +Useful when `_acme-challenge` is delegated to a different zone. + +#### provider + +The DNS provider. See below for provider-specific fields. + ### Provider Fields #### Alibaba Cloud DNS diff --git a/docs/configuration/shared/dns01_challenge.zh.md b/docs/configuration/shared/dns01_challenge.zh.md index e6919338cd..8c582bb544 100644 --- a/docs/configuration/shared/dns01_challenge.zh.md +++ b/docs/configuration/shared/dns01_challenge.zh.md @@ -2,6 +2,14 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [ttl](#ttl) + :material-plus: [propagation_delay](#propagation_delay) + :material-plus: [propagation_timeout](#propagation_timeout) + :material-plus: [resolvers](#resolvers) + :material-plus: [override_domain](#override_domain) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [alidns.security_token](#security_token) @@ -12,12 +20,57 @@ icon: material/new-box ```json { + "ttl": "", + "propagation_delay": "", + "propagation_timeout": "", + "resolvers": [], + "override_domain": "", "provider": "", ... // 提供商字段 } ``` +### 字段 + +#### ttl + +!!! question "自 sing-box 1.14.0 起" + +DNS 质询临时 TXT 记录的 TTL。 + +#### propagation_delay + +!!! question "自 sing-box 1.14.0 起" + +创建质询记录后,在开始传播检查前要等待的时间。 + +#### propagation_timeout + +!!! question "自 sing-box 1.14.0 起" + +等待质询记录传播完成的最长时间。 + +设为 `-1` 可禁用传播检查。 + +#### resolvers + +!!! question "自 sing-box 1.14.0 起" + +进行 DNS 传播检查时优先使用的 DNS 解析器。 + +#### override_domain + +!!! question "自 sing-box 1.14.0 起" + +覆盖 DNS 质询记录使用的域名。 + +适用于将 `_acme-challenge` 委托到其他 zone 的场景。 + +#### provider + +DNS 提供商。提供商专有字段见下文。 + ### 提供商字段 #### Alibaba Cloud DNS diff --git a/docs/configuration/shared/tls.md b/docs/configuration/shared/tls.md index 73ceffccef..518b2f9176 100644 --- a/docs/configuration/shared/tls.md +++ b/docs/configuration/shared/tls.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [certificate_provider](#certificate_provider) + :material-delete-clock: [acme](#acme-fields) + !!! quote "Changes in sing-box 1.13.0" :material-plus: [kernel_tx](#kernel_tx) @@ -49,6 +54,10 @@ icon: material/new-box "key_path": "", "kernel_tx": false, "kernel_rx": false, + "certificate_provider": "", + + // Deprecated + "acme": { "domain": [], "data_directory": "", @@ -408,6 +417,18 @@ Enable kernel TLS transmit support. Enable kernel TLS receive support. +#### certificate_provider + +!!! question "Since sing-box 1.14.0" + +==Server only== + +A string or an object. + +When string, the tag of a shared [Certificate Provider](/configuration/shared/certificate-provider/). + +When object, an inline certificate provider. See [Certificate Provider](/configuration/shared/certificate-provider/) for available types and fields. + ## Custom TLS support !!! info "QUIC support" @@ -469,7 +490,7 @@ The ECH key and configuration can be generated by `sing-box generate ech-keypair !!! failure "Deprecated in sing-box 1.12.0" - ECH support has been migrated to use stdlib in sing-box 1.12.0, which does not come with support for PQ signature schemes, so `pq_signature_schemes_enabled` has been deprecated and no longer works. + `pq_signature_schemes_enabled` is deprecated in sing-box 1.12.0 and removed in sing-box 1.13.0. Enable support for post-quantum peer certificate signature schemes. @@ -477,7 +498,7 @@ Enable support for post-quantum peer certificate signature schemes. !!! failure "Deprecated in sing-box 1.12.0" - `dynamic_record_sizing_disabled` has nothing to do with ECH, was added by mistake, has been deprecated and no longer works. + `dynamic_record_sizing_disabled` is deprecated in sing-box 1.12.0 and removed in sing-box 1.13.0. Disables adaptive sizing of TLS records. @@ -566,6 +587,10 @@ Fragment TLS handshake into multiple TLS records to bypass firewalls. ### ACME Fields +!!! failure "Deprecated in sing-box 1.14.0" + + Inline ACME options are deprecated in sing-box 1.14.0 and will be removed in sing-box 1.16.0, check [Migration](/migration/#migrate-inline-acme-to-certificate-provider). + #### domain List of domain. @@ -677,4 +702,4 @@ A hexadecimal string with zero to eight digits. The maximum time difference between the server and the client. -Check disabled if empty. \ No newline at end of file +Check disabled if empty. diff --git a/docs/configuration/shared/tls.zh.md b/docs/configuration/shared/tls.zh.md index 0b47189bc6..56b90d33f1 100644 --- a/docs/configuration/shared/tls.zh.md +++ b/docs/configuration/shared/tls.zh.md @@ -2,6 +2,11 @@ icon: material/new-box --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [certificate_provider](#certificate_provider) + :material-delete-clock: [acme](#acme-字段) + !!! quote "sing-box 1.13.0 中的更改" :material-plus: [kernel_tx](#kernel_tx) @@ -49,6 +54,10 @@ icon: material/new-box "key_path": "", "kernel_tx": false, "kernel_rx": false, + "certificate_provider": "", + + // 废弃的 + "acme": { "domain": [], "data_directory": "", @@ -407,6 +416,18 @@ echo | openssl s_client -servername example.com -connect example.com:443 2>/dev/ 启用内核 TLS 接收支持。 +#### certificate_provider + +!!! question "自 sing-box 1.14.0 起" + +==仅服务器== + +字符串或对象。 + +为字符串时,共享[证书提供者](/zh/configuration/shared/certificate-provider/)的标签。 + +为对象时,内联的证书提供者。可用类型和字段参阅[证书提供者](/zh/configuration/shared/certificate-provider/)。 + ## 自定义 TLS 支持 !!! info "QUIC 支持" @@ -465,7 +486,7 @@ ECH 密钥和配置可以通过 `sing-box generate ech-keypair` 生成。 !!! failure "已在 sing-box 1.12.0 废弃" - ECH 支持已在 sing-box 1.12.0 迁移至使用标准库,但标准库不支持后量子对等证书签名方案,因此 `pq_signature_schemes_enabled` 已被弃用且不再工作。 + `pq_signature_schemes_enabled` 已在 sing-box 1.12.0 废弃且已在 sing-box 1.13.0 中被移除。 启用对后量子对等证书签名方案的支持。 @@ -473,7 +494,7 @@ ECH 密钥和配置可以通过 `sing-box generate ech-keypair` 生成。 !!! failure "已在 sing-box 1.12.0 废弃" - `dynamic_record_sizing_disabled` 与 ECH 无关,是错误添加的,现已弃用且不再工作。 + `dynamic_record_sizing_disabled` 已在 sing-box 1.12.0 废弃且已在 sing-box 1.13.0 中被移除。 禁用 TLS 记录的自适应大小调整。 @@ -561,6 +582,10 @@ ECH 配置路径,PEM 格式。 ### ACME 字段 +!!! failure "已在 sing-box 1.14.0 废弃" + + 内联 ACME 选项已在 sing-box 1.14.0 废弃且将在 sing-box 1.16.0 中被移除,参阅 [迁移指南](/zh/migration/#迁移内联-acme-到证书提供者)。 + #### domain 域名列表。 diff --git a/docs/deprecated.md b/docs/deprecated.md index 8e53bda6db..3faf986e08 100644 --- a/docs/deprecated.md +++ b/docs/deprecated.md @@ -4,6 +4,16 @@ icon: material/delete-alert # Deprecated Feature List +## 1.14.0 + +#### Inline ACME options in TLS + +Inline ACME options (`tls.acme`) are deprecated +and can be replaced by the ACME certificate provider, +check [Migration](../migration/#migrate-inline-acme-to-certificate-provider). + +Old fields will be removed in sing-box 1.16.0. + ## 1.12.0 #### Legacy DNS server formats @@ -28,7 +38,7 @@ so `pq_signature_schemes_enabled` has been deprecated and no longer works. Also, `dynamic_record_sizing_disabled` has nothing to do with ECH, was added by mistake, has been deprecated and no longer works. -These fields will be removed in sing-box 1.13.0. +These fields were removed in sing-box 1.13.0. ## 1.11.0 @@ -38,7 +48,7 @@ Legacy special outbounds (`block` / `dns`) are deprecated and can be replaced by rule actions, check [Migration](../migration/#migrate-legacy-special-outbounds-to-rule-actions). -Old fields will be removed in sing-box 1.13.0. +Old fields were removed in sing-box 1.13.0. #### Legacy inbound fields @@ -46,7 +56,7 @@ Legacy inbound fields (`inbound.` are deprecated and can be replaced by rule actions, check [Migration](../migration/#migrate-legacy-inbound-fields-to-rule-actions). -Old fields will be removed in sing-box 1.13.0. +Old fields were removed in sing-box 1.13.0. #### Destination override fields in direct outbound @@ -54,18 +64,20 @@ Destination override fields (`override_address` / `override_port`) in direct out and can be replaced by rule actions, check [Migration](../migration/#migrate-destination-override-fields-to-route-options). +Old fields were removed in sing-box 1.13.0. + #### WireGuard outbound WireGuard outbound is deprecated and can be replaced by endpoint, check [Migration](../migration/#migrate-wireguard-outbound-to-endpoint). -Old outbound will be removed in sing-box 1.13.0. +Old outbound was removed in sing-box 1.13.0. #### GSO option in TUN GSO has no advantages for transparent proxy scenarios, is deprecated and no longer works in TUN. -Old fields will be removed in sing-box 1.13.0. +Old fields were removed in sing-box 1.13.0. ## 1.10.0 @@ -75,12 +87,12 @@ Old fields will be removed in sing-box 1.13.0. `inet4_route_address` and `inet6_route_address` are merged into `route_address`, `inet4_route_exclude_address` and `inet6_route_exclude_address` are merged into `route_exclude_address`. -Old fields will be removed in sing-box 1.12.0. +Old fields were removed in sing-box 1.12.0. #### Match source rule items are renamed `rule_set_ipcidr_match_source` route and DNS rule items are renamed to -`rule_set_ip_cidr_match_source` and will be remove in sing-box 1.11.0. +`rule_set_ip_cidr_match_source` and were removed in sing-box 1.11.0. #### Drop support for go1.18 and go1.19 @@ -95,7 +107,7 @@ check [Migration](/migration/#migrate-cache-file-from-clash-api-to-independent-o #### GeoIP -GeoIP is deprecated and will be removed in sing-box 1.12.0. +GeoIP is deprecated and was removed in sing-box 1.12.0. The maxmind GeoIP National Database, as an IP classification database, is not entirely suitable for traffic bypassing, @@ -106,7 +118,7 @@ check [Migration](/migration/#migrate-geoip-to-rule-sets). #### Geosite -Geosite is deprecated and will be removed in sing-box 1.12.0. +Geosite is deprecated and was removed in sing-box 1.12.0. Geosite, the `domain-list-community` project maintained by V2Ray as an early traffic bypassing solution, suffers from a number of problems, including lack of maintenance, inaccurate rules, and difficult management. diff --git a/docs/deprecated.zh.md b/docs/deprecated.zh.md index 82b6db042f..e710e78ce7 100644 --- a/docs/deprecated.zh.md +++ b/docs/deprecated.zh.md @@ -4,6 +4,18 @@ icon: material/delete-alert # 废弃功能列表 +## 1.14.0 + +#### TLS 中的内联 ACME 选项 + +TLS 中的内联 ACME 选项(`tls.acme`)已废弃, +且可以通过 ACME 证书提供者替代, +参阅 [迁移指南](/zh/migration/#迁移内联-acme-到证书提供者)。 + +旧字段将在 sing-box 1.16.0 中被移除。 + +## 1.12.0 + #### 旧的 DNS 服务器格式 DNS 服务器已重构, @@ -24,7 +36,7 @@ ECH 支持已在 sing-box 1.12.0 迁移至使用标准库,但标准库不支 另外,`dynamic_record_sizing_disabled` 与 ECH 无关,是错误添加的,现已弃用且不再工作。 -相关字段将在 sing-box 1.13.0 中被移除。 +相关字段已在 sing-box 1.13.0 中被移除。 ## 1.11.0 @@ -33,41 +45,41 @@ ECH 支持已在 sing-box 1.12.0 迁移至使用标准库,但标准库不支 旧的特殊出站(`block` / `dns`)已废弃且可以通过规则动作替代, 参阅 [迁移指南](/zh/migration/#迁移旧的特殊出站到规则动作)。 -旧字段将在 sing-box 1.13.0 中被移除。 +旧字段已在 sing-box 1.13.0 中被移除。 #### 旧的入站字段 旧的入站字段(`inbound.`)已废弃且可以通过规则动作替代, 参阅 [迁移指南](/zh/migration/#迁移旧的入站字段到规则动作)。 -旧字段将在 sing-box 1.13.0 中被移除。 +旧字段已在 sing-box 1.13.0 中被移除。 #### direct 出站中的目标地址覆盖字段 direct 出站中的目标地址覆盖字段(`override_address` / `override_port`)已废弃且可以通过规则动作替代, 参阅 [迁移指南](/zh/migration/#迁移-direct-出站中的目标地址覆盖字段到路由字段)。 -旧字段将在 sing-box 1.13.0 中被移除。 +旧字段已在 sing-box 1.13.0 中被移除。 #### WireGuard 出站 WireGuard 出站已废弃且可以通过端点替代, 参阅 [迁移指南](/zh/migration/#迁移-wireguard-出站到端点)。 -旧出站将在 sing-box 1.13.0 中被移除。 +旧出站已在 sing-box 1.13.0 中被移除。 #### TUN 的 GSO 字段 GSO 对透明代理场景没有优势,已废弃且在 TUN 中不再起作用。 -旧字段将在 sing-box 1.13.0 中被移除。 +旧字段已在 sing-box 1.13.0 中被移除。 ## 1.10.0 #### Match source 规则项已重命名 `rule_set_ipcidr_match_source` 路由和 DNS 规则项已被重命名为 -`rule_set_ip_cidr_match_source` 且将在 sing-box 1.11.0 中被移除。 +`rule_set_ip_cidr_match_source` 且已在 sing-box 1.11.0 中被移除。 #### TUN 地址字段已合并 @@ -75,7 +87,7 @@ GSO 对透明代理场景没有优势,已废弃且在 TUN 中不再起作用 `inet4_route_address` 和 `inet6_route_address` 已合并为 `route_address`, `inet4_route_exclude_address` 和 `inet6_route_exclude_address` 已合并为 `route_exclude_address`。 -旧字段将在 sing-box 1.11.0 中被移除。 +旧字段已在 sing-box 1.12.0 中被移除。 #### 移除对 go1.18 和 go1.19 的支持 @@ -90,7 +102,7 @@ Clash API 中的 `cache_file` 及相关功能已废弃且已迁移到独立的 ` #### GeoIP -GeoIP 已废弃且将在 sing-box 1.12.0 中被移除。 +GeoIP 已废弃且已在 sing-box 1.12.0 中被移除。 maxmind GeoIP 国家数据库作为 IP 分类数据库,不完全适合流量绕过, 且现有的实现均存在内存使用大与管理困难的问题。 @@ -100,7 +112,7 @@ sing-box 1.8.0 引入了[规则集](/zh/configuration/rule-set/), #### Geosite -Geosite 已废弃且将在 sing-box 1.12.0 中被移除。 +Geosite 已废弃且已在 sing-box 1.12.0 中被移除。 Geosite,即由 V2Ray 维护的 domain-list-community 项目,作为早期流量绕过解决方案, 存在着包括缺少维护、规则不准确和管理困难内的大量问题。 diff --git a/docs/migration.md b/docs/migration.md index 86074ac712..810bae190a 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -2,6 +2,83 @@ icon: material/arrange-bring-forward --- +## 1.14.0 + +### Migrate inline ACME to certificate provider + +Inline ACME options in TLS are deprecated and can be replaced by certificate providers. + +Most `tls.acme` fields can be moved into the ACME certificate provider unchanged. +See [ACME](/configuration/shared/certificate-provider/acme/) for fields newly added in sing-box 1.14.0. + +!!! info "References" + + [TLS](/configuration/shared/tls/#certificate_provider) / + [Certificate Provider](/configuration/shared/certificate-provider/) + +=== ":material-card-remove: Deprecated" + + ```json + { + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "acme": { + "domain": ["example.com"], + "email": "admin@example.com" + } + } + } + ] + } + ``` + +=== ":material-card-multiple: Inline" + + ```json + { + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "certificate_provider": { + "type": "acme", + "domain": ["example.com"], + "email": "admin@example.com" + } + } + } + ] + } + ``` + +=== ":material-card-multiple: Shared" + + ```json + { + "certificate_providers": [ + { + "type": "acme", + "tag": "my-cert", + "domain": ["example.com"], + "email": "admin@example.com" + } + ], + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "certificate_provider": "my-cert" + } + } + ] + } + ``` + ## 1.12.0 ### Migrate to new DNS server formats diff --git a/docs/migration.zh.md b/docs/migration.zh.md index c08be78f5c..18e2872613 100644 --- a/docs/migration.zh.md +++ b/docs/migration.zh.md @@ -2,6 +2,83 @@ icon: material/arrange-bring-forward --- +## 1.14.0 + +### 迁移内联 ACME 到证书提供者 + +TLS 中的内联 ACME 选项已废弃,且可以被证书提供者替代。 + +`tls.acme` 的大多数字段都可以原样迁移到 ACME 证书提供者中。 +sing-box 1.14.0 新增字段参阅 [ACME](/zh/configuration/shared/certificate-provider/acme/) 页面。 + +!!! info "参考" + + [TLS](/zh/configuration/shared/tls/#certificate_provider) / + [证书提供者](/zh/configuration/shared/certificate-provider/) + +=== ":material-card-remove: 弃用的" + + ```json + { + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "acme": { + "domain": ["example.com"], + "email": "admin@example.com" + } + } + } + ] + } + ``` + +=== ":material-card-multiple: 内联" + + ```json + { + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "certificate_provider": { + "type": "acme", + "domain": ["example.com"], + "email": "admin@example.com" + } + } + } + ] + } + ``` + +=== ":material-card-multiple: 共享" + + ```json + { + "certificate_providers": [ + { + "type": "acme", + "tag": "my-cert", + "domain": ["example.com"], + "email": "admin@example.com" + } + ], + "inbounds": [ + { + "type": "trojan", + "tls": { + "enabled": true, + "certificate_provider": "my-cert" + } + } + ] + } + ``` + ## 1.12.0 ### 迁移到新的 DNS 服务器格式 diff --git a/experimental/deprecated/constants.go b/experimental/deprecated/constants.go index 385105d383..3526cda831 100644 --- a/experimental/deprecated/constants.go +++ b/experimental/deprecated/constants.go @@ -102,10 +102,20 @@ var OptionLegacyDomainStrategyOptions = Note{ MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-domain-strategy-options", } +var OptionInlineACME = Note{ + Name: "inline-acme-options", + Description: "inline ACME options in TLS", + DeprecatedVersion: "1.14.0", + ScheduledVersion: "1.16.0", + EnvName: "INLINE_ACME_OPTIONS", + MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-inline-acme-to-certificate-provider", +} + var Options = []Note{ OptionLegacyDNSTransport, OptionLegacyDNSFakeIPOptions, OptionOutboundDNSRuleItem, OptionMissingDomainResolver, OptionLegacyDomainStrategyOptions, + OptionInlineACME, } diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 54369bf770..eca3fdf94d 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -33,7 +33,7 @@ func baseContext(platformInterface PlatformInterface) context.Context { } ctx := context.Background() ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) - return box.Context(ctx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), dnsRegistry, include.ServiceRegistry()) + return box.Context(ctx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), dnsRegistry, include.ServiceRegistry(), include.CertificateProviderRegistry()) } func parseConfig(ctx context.Context, configContent string) (option.Options, error) { diff --git a/include/acme.go b/include/acme.go new file mode 100644 index 0000000000..093fd50823 --- /dev/null +++ b/include/acme.go @@ -0,0 +1,12 @@ +//go:build with_acme + +package include + +import ( + "github.com/sagernet/sing-box/adapter/certificate" + "github.com/sagernet/sing-box/service/acme" +) + +func registerACMECertificateProvider(registry *certificate.Registry) { + acme.RegisterCertificateProvider(registry) +} diff --git a/include/acme_stub.go b/include/acme_stub.go new file mode 100644 index 0000000000..bceab3d731 --- /dev/null +++ b/include/acme_stub.go @@ -0,0 +1,20 @@ +//go:build !with_acme + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerACMECertificateProvider(registry *certificate.Registry) { + certificate.Register[option.ACMECertificateProviderOptions](registry, C.TypeACME, func(ctx context.Context, logger log.ContextLogger, tag string, options option.ACMECertificateProviderOptions) (adapter.CertificateProviderService, error) { + return nil, E.New(`ACME is not included in this build, rebuild with -tags with_acme`) + }) +} diff --git a/include/registry.go b/include/registry.go index f090845b51..eb22cce1fe 100644 --- a/include/registry.go +++ b/include/registry.go @@ -5,6 +5,7 @@ import ( "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" @@ -34,13 +35,14 @@ import ( "github.com/sagernet/sing-box/protocol/tun" "github.com/sagernet/sing-box/protocol/vless" "github.com/sagernet/sing-box/protocol/vmess" + originca "github.com/sagernet/sing-box/service/origin_ca" "github.com/sagernet/sing-box/service/resolved" "github.com/sagernet/sing-box/service/ssmapi" E "github.com/sagernet/sing/common/exceptions" ) func Context(ctx context.Context) context.Context { - return box.Context(ctx, InboundRegistry(), OutboundRegistry(), EndpointRegistry(), DNSTransportRegistry(), ServiceRegistry()) + return box.Context(ctx, InboundRegistry(), OutboundRegistry(), EndpointRegistry(), DNSTransportRegistry(), ServiceRegistry(), CertificateProviderRegistry()) } func InboundRegistry() *inbound.Registry { @@ -139,6 +141,16 @@ func ServiceRegistry() *service.Registry { return registry } +func CertificateProviderRegistry() *certificate.Registry { + registry := certificate.NewRegistry() + + registerACMECertificateProvider(registry) + registerTailscaleCertificateProvider(registry) + originca.RegisterCertificateProvider(registry) + + return registry +} + func registerStubForRemovedInbounds(registry *inbound.Registry) { inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0") diff --git a/include/tailscale.go b/include/tailscale.go index 1757283b07..6f85aaac14 100644 --- a/include/tailscale.go +++ b/include/tailscale.go @@ -3,6 +3,7 @@ package include import ( + "github.com/sagernet/sing-box/adapter/certificate" "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/service" "github.com/sagernet/sing-box/dns" @@ -18,6 +19,10 @@ func registerTailscaleTransport(registry *dns.TransportRegistry) { tailscale.RegistryTransport(registry) } +func registerTailscaleCertificateProvider(registry *certificate.Registry) { + tailscale.RegisterCertificateProvider(registry) +} + func registerDERPService(registry *service.Registry) { derp.Register(registry) } diff --git a/include/tailscale_stub.go b/include/tailscale_stub.go index 78398875f8..e6f97f1eab 100644 --- a/include/tailscale_stub.go +++ b/include/tailscale_stub.go @@ -6,6 +6,7 @@ import ( "context" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/service" C "github.com/sagernet/sing-box/constant" @@ -27,6 +28,12 @@ func registerTailscaleTransport(registry *dns.TransportRegistry) { }) } +func registerTailscaleCertificateProvider(registry *certificate.Registry) { + certificate.Register[option.TailscaleCertificateProviderOptions](registry, C.TypeTailscale, func(ctx context.Context, logger log.ContextLogger, tag string, options option.TailscaleCertificateProviderOptions) (adapter.CertificateProviderService, error) { + return nil, E.New(`Tailscale is not included in this build, rebuild with -tags with_tailscale`) + }) +} + func registerDERPService(registry *service.Registry) { service.Register[option.DERPServiceOptions](registry, C.TypeDERP, func(ctx context.Context, logger log.ContextLogger, tag string, options option.DERPServiceOptions) (adapter.Service, error) { return nil, E.New(`DERP is not included in this build, rebuild with -tags with_tailscale`) diff --git a/mkdocs.yml b/mkdocs.yml index 5f95842a5d..65c9db71f4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -122,6 +122,11 @@ nav: - Listen Fields: configuration/shared/listen.md - Dial Fields: configuration/shared/dial.md - TLS: configuration/shared/tls.md + - Certificate Provider: + - configuration/shared/certificate-provider/index.md + - ACME: configuration/shared/certificate-provider/acme.md + - Tailscale: configuration/shared/certificate-provider/tailscale.md + - Cloudflare Origin CA: configuration/shared/certificate-provider/cloudflare-origin-ca.md - DNS01 Challenge Fields: configuration/shared/dns01_challenge.md - Pre-match: configuration/shared/pre-match.md - Multiplex: configuration/shared/multiplex.md @@ -273,6 +278,7 @@ plugins: Shared: 通用 Listen Fields: 监听字段 Dial Fields: 拨号字段 + Certificate Provider Fields: 证书提供者字段 DNS01 Challenge Fields: DNS01 验证字段 Multiplex: 多路复用 V2Ray Transport: V2Ray 传输层 @@ -281,6 +287,7 @@ plugins: Endpoint: 端点 Inbound: 入站 Outbound: 出站 + Certificate Provider: 证书提供者 Manual: 手册 reconfigure_material: true diff --git a/option/acme.go b/option/acme.go new file mode 100644 index 0000000000..ea9349b724 --- /dev/null +++ b/option/acme.go @@ -0,0 +1,106 @@ +package option + +import ( + "strings" + + C "github.com/sagernet/sing-box/constant" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/common/json/badoption" +) + +type ACMECertificateProviderOptions struct { + Domain badoption.Listable[string] `json:"domain,omitempty"` + DataDirectory string `json:"data_directory,omitempty"` + DefaultServerName string `json:"default_server_name,omitempty"` + Email string `json:"email,omitempty"` + Provider string `json:"provider,omitempty"` + AccountKey string `json:"account_key,omitempty"` + DisableHTTPChallenge bool `json:"disable_http_challenge,omitempty"` + DisableTLSALPNChallenge bool `json:"disable_tls_alpn_challenge,omitempty"` + AlternativeHTTPPort uint16 `json:"alternative_http_port,omitempty"` + AlternativeTLSPort uint16 `json:"alternative_tls_port,omitempty"` + ExternalAccount *ACMEExternalAccountOptions `json:"external_account,omitempty"` + DNS01Challenge *ACMEProviderDNS01ChallengeOptions `json:"dns01_challenge,omitempty"` + KeyType ACMEKeyType `json:"key_type,omitempty"` + Detour string `json:"detour,omitempty"` +} + +type _ACMEProviderDNS01ChallengeOptions struct { + TTL badoption.Duration `json:"ttl,omitempty"` + PropagationDelay badoption.Duration `json:"propagation_delay,omitempty"` + PropagationTimeout badoption.Duration `json:"propagation_timeout,omitempty"` + Resolvers badoption.Listable[string] `json:"resolvers,omitempty"` + OverrideDomain string `json:"override_domain,omitempty"` + Provider string `json:"provider,omitempty"` + AliDNSOptions ACMEDNS01AliDNSOptions `json:"-"` + CloudflareOptions ACMEDNS01CloudflareOptions `json:"-"` + ACMEDNSOptions ACMEDNS01ACMEDNSOptions `json:"-"` +} + +type ACMEProviderDNS01ChallengeOptions _ACMEProviderDNS01ChallengeOptions + +func (o ACMEProviderDNS01ChallengeOptions) MarshalJSON() ([]byte, error) { + var v any + switch o.Provider { + case C.DNSProviderAliDNS: + v = o.AliDNSOptions + case C.DNSProviderCloudflare: + v = o.CloudflareOptions + case C.DNSProviderACMEDNS: + v = o.ACMEDNSOptions + case "": + return nil, E.New("missing provider type") + default: + return nil, E.New("unknown provider type: ", o.Provider) + } + return badjson.MarshallObjects((_ACMEProviderDNS01ChallengeOptions)(o), v) +} + +func (o *ACMEProviderDNS01ChallengeOptions) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_ACMEProviderDNS01ChallengeOptions)(o)) + if err != nil { + return err + } + var v any + switch o.Provider { + case C.DNSProviderAliDNS: + v = &o.AliDNSOptions + case C.DNSProviderCloudflare: + v = &o.CloudflareOptions + case C.DNSProviderACMEDNS: + v = &o.ACMEDNSOptions + case "": + return E.New("missing provider type") + default: + return E.New("unknown provider type: ", o.Provider) + } + return badjson.UnmarshallExcluded(bytes, (*_ACMEProviderDNS01ChallengeOptions)(o), v) +} + +type ACMEKeyType string + +const ( + ACMEKeyTypeED25519 = ACMEKeyType("ed25519") + ACMEKeyTypeP256 = ACMEKeyType("p256") + ACMEKeyTypeP384 = ACMEKeyType("p384") + ACMEKeyTypeRSA2048 = ACMEKeyType("rsa2048") + ACMEKeyTypeRSA4096 = ACMEKeyType("rsa4096") +) + +func (t *ACMEKeyType) UnmarshalJSON(data []byte) error { + var value string + err := json.Unmarshal(data, &value) + if err != nil { + return err + } + value = strings.ToLower(value) + switch ACMEKeyType(value) { + case "", ACMEKeyTypeED25519, ACMEKeyTypeP256, ACMEKeyTypeP384, ACMEKeyTypeRSA2048, ACMEKeyTypeRSA4096: + *t = ACMEKeyType(value) + default: + return E.New("unknown ACME key type: ", value) + } + return nil +} diff --git a/option/certificate_provider.go b/option/certificate_provider.go new file mode 100644 index 0000000000..a24abdc570 --- /dev/null +++ b/option/certificate_provider.go @@ -0,0 +1,100 @@ +package option + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/service" +) + +type CertificateProviderOptionsRegistry interface { + CreateOptions(providerType string) (any, bool) +} + +type _CertificateProvider struct { + Type string `json:"type"` + Tag string `json:"tag,omitempty"` + Options any `json:"-"` +} + +type CertificateProvider _CertificateProvider + +func (h *CertificateProvider) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return badjson.MarshallObjectsContext(ctx, (*_CertificateProvider)(h), h.Options) +} + +func (h *CertificateProvider) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContext(ctx, content, (*_CertificateProvider)(h)) + if err != nil { + return err + } + registry := service.FromContext[CertificateProviderOptionsRegistry](ctx) + if registry == nil { + return E.New("missing certificate provider options registry in context") + } + options, loaded := registry.CreateOptions(h.Type) + if !loaded { + return E.New("unknown certificate provider type: ", h.Type) + } + err = badjson.UnmarshallExcludedContext(ctx, content, (*_CertificateProvider)(h), options) + if err != nil { + return err + } + h.Options = options + return nil +} + +type CertificateProviderOptions struct { + Tag string `json:"-"` + Type string `json:"-"` + Options any `json:"-"` +} + +type _CertificateProviderInline struct { + Type string `json:"type"` +} + +func (o *CertificateProviderOptions) MarshalJSONContext(ctx context.Context) ([]byte, error) { + if o.Tag != "" { + return json.Marshal(o.Tag) + } + return badjson.MarshallObjectsContext(ctx, _CertificateProviderInline{Type: o.Type}, o.Options) +} + +func (o *CertificateProviderOptions) UnmarshalJSONContext(ctx context.Context, content []byte) error { + if len(content) == 0 { + return E.New("empty certificate_provider value") + } + if content[0] == '"' { + return json.UnmarshalContext(ctx, content, &o.Tag) + } + var inline _CertificateProviderInline + err := json.UnmarshalContext(ctx, content, &inline) + if err != nil { + return err + } + o.Type = inline.Type + if o.Type == "" { + return E.New("missing certificate provider type") + } + registry := service.FromContext[CertificateProviderOptionsRegistry](ctx) + if registry == nil { + return E.New("missing certificate provider options registry in context") + } + options, loaded := registry.CreateOptions(o.Type) + if !loaded { + return E.New("unknown certificate provider type: ", o.Type) + } + err = badjson.UnmarshallExcludedContext(ctx, content, &inline, options) + if err != nil { + return err + } + o.Options = options + return nil +} + +func (o *CertificateProviderOptions) IsShared() bool { + return o.Tag != "" +} diff --git a/option/options.go b/option/options.go index 8bebd48fc6..a08dcbc0f1 100644 --- a/option/options.go +++ b/option/options.go @@ -10,18 +10,19 @@ import ( ) type _Options struct { - RawMessage json.RawMessage `json:"-"` - Schema string `json:"$schema,omitempty"` - Log *LogOptions `json:"log,omitempty"` - DNS *DNSOptions `json:"dns,omitempty"` - NTP *NTPOptions `json:"ntp,omitempty"` - Certificate *CertificateOptions `json:"certificate,omitempty"` - Endpoints []Endpoint `json:"endpoints,omitempty"` - Inbounds []Inbound `json:"inbounds,omitempty"` - Outbounds []Outbound `json:"outbounds,omitempty"` - Route *RouteOptions `json:"route,omitempty"` - Services []Service `json:"services,omitempty"` - Experimental *ExperimentalOptions `json:"experimental,omitempty"` + RawMessage json.RawMessage `json:"-"` + Schema string `json:"$schema,omitempty"` + Log *LogOptions `json:"log,omitempty"` + DNS *DNSOptions `json:"dns,omitempty"` + NTP *NTPOptions `json:"ntp,omitempty"` + Certificate *CertificateOptions `json:"certificate,omitempty"` + CertificateProviders []CertificateProvider `json:"certificate_providers,omitempty"` + Endpoints []Endpoint `json:"endpoints,omitempty"` + Inbounds []Inbound `json:"inbounds,omitempty"` + Outbounds []Outbound `json:"outbounds,omitempty"` + Route *RouteOptions `json:"route,omitempty"` + Services []Service `json:"services,omitempty"` + Experimental *ExperimentalOptions `json:"experimental,omitempty"` } type Options _Options @@ -56,6 +57,25 @@ func checkOptions(options *Options) error { if err != nil { return err } + err = checkCertificateProviders(options.CertificateProviders) + if err != nil { + return err + } + return nil +} + +func checkCertificateProviders(providers []CertificateProvider) error { + seen := make(map[string]bool) + for i, provider := range providers { + tag := provider.Tag + if tag == "" { + tag = F.ToString(i) + } + if seen[tag] { + return E.New("duplicate certificate provider tag: ", tag) + } + seen[tag] = true + } return nil } diff --git a/option/origin_ca.go b/option/origin_ca.go new file mode 100644 index 0000000000..ee8b370414 --- /dev/null +++ b/option/origin_ca.go @@ -0,0 +1,76 @@ +package option + +import ( + "strings" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badoption" +) + +type CloudflareOriginCACertificateProviderOptions struct { + Domain badoption.Listable[string] `json:"domain,omitempty"` + DataDirectory string `json:"data_directory,omitempty"` + APIToken string `json:"api_token,omitempty"` + OriginCAKey string `json:"origin_ca_key,omitempty"` + RequestType CloudflareOriginCARequestType `json:"request_type,omitempty"` + RequestedValidity CloudflareOriginCARequestValidity `json:"requested_validity,omitempty"` + Detour string `json:"detour,omitempty"` +} + +type CloudflareOriginCARequestType string + +const ( + CloudflareOriginCARequestTypeOriginRSA = CloudflareOriginCARequestType("origin-rsa") + CloudflareOriginCARequestTypeOriginECC = CloudflareOriginCARequestType("origin-ecc") +) + +func (t *CloudflareOriginCARequestType) UnmarshalJSON(data []byte) error { + var value string + err := json.Unmarshal(data, &value) + if err != nil { + return err + } + value = strings.ToLower(value) + switch CloudflareOriginCARequestType(value) { + case "", CloudflareOriginCARequestTypeOriginRSA, CloudflareOriginCARequestTypeOriginECC: + *t = CloudflareOriginCARequestType(value) + default: + return E.New("unsupported Cloudflare Origin CA request type: ", value) + } + return nil +} + +type CloudflareOriginCARequestValidity uint16 + +const ( + CloudflareOriginCARequestValidity7 = CloudflareOriginCARequestValidity(7) + CloudflareOriginCARequestValidity30 = CloudflareOriginCARequestValidity(30) + CloudflareOriginCARequestValidity90 = CloudflareOriginCARequestValidity(90) + CloudflareOriginCARequestValidity365 = CloudflareOriginCARequestValidity(365) + CloudflareOriginCARequestValidity730 = CloudflareOriginCARequestValidity(730) + CloudflareOriginCARequestValidity1095 = CloudflareOriginCARequestValidity(1095) + CloudflareOriginCARequestValidity5475 = CloudflareOriginCARequestValidity(5475) +) + +func (v *CloudflareOriginCARequestValidity) UnmarshalJSON(data []byte) error { + var value uint16 + err := json.Unmarshal(data, &value) + if err != nil { + return err + } + switch CloudflareOriginCARequestValidity(value) { + case 0, + CloudflareOriginCARequestValidity7, + CloudflareOriginCARequestValidity30, + CloudflareOriginCARequestValidity90, + CloudflareOriginCARequestValidity365, + CloudflareOriginCARequestValidity730, + CloudflareOriginCARequestValidity1095, + CloudflareOriginCARequestValidity5475: + *v = CloudflareOriginCARequestValidity(value) + default: + return E.New("unsupported Cloudflare Origin CA requested validity: ", value) + } + return nil +} diff --git a/option/tailscale.go b/option/tailscale.go index 68a143693e..a4f82ce0de 100644 --- a/option/tailscale.go +++ b/option/tailscale.go @@ -36,6 +36,10 @@ type TailscaleDNSServerOptions struct { AcceptDefaultResolvers bool `json:"accept_default_resolvers,omitempty"` } +type TailscaleCertificateProviderOptions struct { + Endpoint string `json:"endpoint,omitempty"` +} + type DERPServiceOptions struct { ListenOptions InboundTLSOptionsContainer diff --git a/option/tls.go b/option/tls.go index 60343a15f1..dbbb7620ed 100644 --- a/option/tls.go +++ b/option/tls.go @@ -28,9 +28,13 @@ type InboundTLSOptions struct { KeyPath string `json:"key_path,omitempty"` KernelTx bool `json:"kernel_tx,omitempty"` KernelRx bool `json:"kernel_rx,omitempty"` - ACME *InboundACMEOptions `json:"acme,omitempty"` - ECH *InboundECHOptions `json:"ech,omitempty"` - Reality *InboundRealityOptions `json:"reality,omitempty"` + CertificateProvider *CertificateProviderOptions `json:"certificate_provider,omitempty"` + + // Deprecated: use certificate_provider + ACME *InboundACMEOptions `json:"acme,omitempty"` + + ECH *InboundECHOptions `json:"ech,omitempty"` + Reality *InboundRealityOptions `json:"reality,omitempty"` } type ClientAuthType tls.ClientAuthType diff --git a/protocol/tailscale/certificate_provider.go b/protocol/tailscale/certificate_provider.go new file mode 100644 index 0000000000..5ac18a3073 --- /dev/null +++ b/protocol/tailscale/certificate_provider.go @@ -0,0 +1,98 @@ +//go:build with_gvisor + +package tailscale + +import ( + "context" + "crypto/tls" + "net" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + "github.com/sagernet/tailscale/client/local" +) + +func RegisterCertificateProvider(registry *certificate.Registry) { + certificate.Register[option.TailscaleCertificateProviderOptions](registry, C.TypeTailscale, NewCertificateProvider) +} + +var _ adapter.CertificateProviderService = (*CertificateProvider)(nil) + +type CertificateProvider struct { + certificate.Adapter + endpointTag string + endpoint *Endpoint + dialer N.Dialer + localClient *local.Client +} + +func NewCertificateProvider(ctx context.Context, _ log.ContextLogger, tag string, options option.TailscaleCertificateProviderOptions) (adapter.CertificateProviderService, error) { + if options.Endpoint == "" { + return nil, E.New("missing tailscale endpoint tag") + } + endpointManager := service.FromContext[adapter.EndpointManager](ctx) + if endpointManager == nil { + return nil, E.New("missing endpoint manager in context") + } + rawEndpoint, loaded := endpointManager.Get(options.Endpoint) + if !loaded { + return nil, E.New("endpoint not found: ", options.Endpoint) + } + endpoint, isTailscale := rawEndpoint.(*Endpoint) + if !isTailscale { + return nil, E.New("endpoint is not Tailscale: ", options.Endpoint) + } + providerDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{}, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create tailscale certificate provider dialer") + } + return &CertificateProvider{ + Adapter: certificate.NewAdapter(C.TypeTailscale, tag), + endpointTag: options.Endpoint, + endpoint: endpoint, + dialer: providerDialer, + }, nil +} + +func (p *CertificateProvider) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + localClient, err := p.endpoint.Server().LocalClient() + if err != nil { + return E.Cause(err, "initialize tailscale local client for endpoint ", p.endpointTag) + } + originalDial := localClient.Dial + localClient.Dial = func(ctx context.Context, network, addr string) (net.Conn, error) { + if originalDial != nil && addr == "local-tailscaled.sock:80" { + return originalDial(ctx, network, addr) + } + return p.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + } + p.localClient = localClient + return nil +} + +func (p *CertificateProvider) Close() error { + return nil +} + +func (p *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + localClient := p.localClient + if localClient == nil { + return nil, E.New("Tailscale is not ready yet") + } + return localClient.GetCertificate(clientHello) +} diff --git a/service/acme/service.go b/service/acme/service.go new file mode 100644 index 0000000000..8286a19717 --- /dev/null +++ b/service/acme/service.go @@ -0,0 +1,411 @@ +//go:build with_acme + +package acme + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "net" + "net/http" + "net/url" + "reflect" + "strings" + "time" + "unsafe" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" + "github.com/sagernet/sing-box/common/dialer" + boxtls "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" + + "github.com/caddyserver/certmagic" + "github.com/caddyserver/zerossl" + "github.com/libdns/alidns" + "github.com/libdns/cloudflare" + "github.com/libdns/libdns" + "github.com/mholt/acmez/v3/acme" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func RegisterCertificateProvider(registry *certificate.Registry) { + certificate.Register[option.ACMECertificateProviderOptions](registry, C.TypeACME, NewCertificateProvider) +} + +var ( + _ adapter.CertificateProviderService = (*Service)(nil) + _ adapter.ACMECertificateProvider = (*Service)(nil) +) + +type Service struct { + certificate.Adapter + ctx context.Context + config *certmagic.Config + cache *certmagic.Cache + domain []string + nextProtos []string +} + +func NewCertificateProvider(ctx context.Context, logger log.ContextLogger, tag string, options option.ACMECertificateProviderOptions) (adapter.CertificateProviderService, error) { + if len(options.Domain) == 0 { + return nil, E.New("missing domain") + } + var acmeServer string + switch options.Provider { + case "", "letsencrypt": + acmeServer = certmagic.LetsEncryptProductionCA + case "zerossl": + acmeServer = certmagic.ZeroSSLProductionCA + default: + if !strings.HasPrefix(options.Provider, "https://") { + return nil, E.New("unsupported ACME provider: ", options.Provider) + } + acmeServer = options.Provider + } + if acmeServer == certmagic.ZeroSSLProductionCA && + (options.ExternalAccount == nil || options.ExternalAccount.KeyID == "") && + strings.TrimSpace(options.Email) == "" && + strings.TrimSpace(options.AccountKey) == "" { + return nil, E.New("email is required to use the ZeroSSL ACME endpoint without external_account or account_key") + } + + var storage certmagic.Storage + if options.DataDirectory != "" { + storage = &certmagic.FileStorage{Path: options.DataDirectory} + } else { + storage = certmagic.Default.Storage + } + + zapLogger := zap.New(zapcore.NewCore( + zapcore.NewConsoleEncoder(boxtls.ACMEEncoderConfig()), + &boxtls.ACMELogWriter{Logger: logger}, + zap.DebugLevel, + )) + + config := &certmagic.Config{ + DefaultServerName: options.DefaultServerName, + Storage: storage, + Logger: zapLogger, + } + if options.KeyType != "" { + var keyType certmagic.KeyType + switch options.KeyType { + case option.ACMEKeyTypeED25519: + keyType = certmagic.ED25519 + case option.ACMEKeyTypeP256: + keyType = certmagic.P256 + case option.ACMEKeyTypeP384: + keyType = certmagic.P384 + case option.ACMEKeyTypeRSA2048: + keyType = certmagic.RSA2048 + case option.ACMEKeyTypeRSA4096: + keyType = certmagic.RSA4096 + default: + return nil, E.New("unsupported ACME key type: ", options.KeyType) + } + config.KeySource = certmagic.StandardKeyGenerator{KeyType: keyType} + } + + acmeIssuer := certmagic.ACMEIssuer{ + CA: acmeServer, + Email: options.Email, + AccountKeyPEM: options.AccountKey, + Agreed: true, + DisableHTTPChallenge: options.DisableHTTPChallenge, + DisableTLSALPNChallenge: options.DisableTLSALPNChallenge, + AltHTTPPort: int(options.AlternativeHTTPPort), + AltTLSALPNPort: int(options.AlternativeTLSPort), + Logger: zapLogger, + } + acmeHTTPClient, err := newACMEHTTPClient(ctx, options.Detour) + if err != nil { + return nil, err + } + dnsSolver, err := newDNSSolver(options.DNS01Challenge, zapLogger, acmeHTTPClient) + if err != nil { + return nil, err + } + if dnsSolver != nil { + acmeIssuer.DNS01Solver = dnsSolver + } + if options.ExternalAccount != nil && options.ExternalAccount.KeyID != "" { + acmeIssuer.ExternalAccount = (*acme.EAB)(options.ExternalAccount) + } + if acmeServer == certmagic.ZeroSSLProductionCA { + acmeIssuer.NewAccountFunc = func(ctx context.Context, acmeIssuer *certmagic.ACMEIssuer, account acme.Account) (acme.Account, error) { + if acmeIssuer.ExternalAccount != nil { + return account, nil + } + var err error + acmeIssuer.ExternalAccount, account, err = createZeroSSLExternalAccountBinding(ctx, acmeIssuer, account, acmeHTTPClient) + return account, err + } + } + + certmagicIssuer := certmagic.NewACMEIssuer(config, acmeIssuer) + httpClientField := reflect.ValueOf(certmagicIssuer).Elem().FieldByName("httpClient") + if !httpClientField.IsValid() || !httpClientField.CanAddr() { + return nil, E.New("certmagic ACME issuer HTTP client field is unavailable") + } + reflect.NewAt(httpClientField.Type(), unsafe.Pointer(httpClientField.UnsafeAddr())).Elem().Set(reflect.ValueOf(acmeHTTPClient)) + config.Issuers = []certmagic.Issuer{certmagicIssuer} + cache := certmagic.NewCache(certmagic.CacheOptions{ + GetConfigForCert: func(certificate certmagic.Certificate) (*certmagic.Config, error) { + return config, nil + }, + Logger: zapLogger, + }) + config = certmagic.New(cache, *config) + + var nextProtos []string + if !acmeIssuer.DisableTLSALPNChallenge && acmeIssuer.DNS01Solver == nil { + nextProtos = []string{C.ACMETLS1Protocol} + } + return &Service{ + Adapter: certificate.NewAdapter(C.TypeACME, tag), + ctx: ctx, + config: config, + cache: cache, + domain: options.Domain, + nextProtos: nextProtos, + }, nil +} + +func (s *Service) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + return s.config.ManageAsync(s.ctx, s.domain) +} + +func (s *Service) Close() error { + if s.cache != nil { + s.cache.Stop() + } + return nil +} + +func (s *Service) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return s.config.GetCertificate(hello) +} + +func (s *Service) GetACMENextProtos() []string { + return s.nextProtos +} + +func newDNSSolver(dnsOptions *option.ACMEProviderDNS01ChallengeOptions, logger *zap.Logger, httpClient *http.Client) (*certmagic.DNS01Solver, error) { + if dnsOptions == nil || dnsOptions.Provider == "" { + return nil, nil + } + if dnsOptions.TTL < 0 { + return nil, E.New("invalid ACME DNS01 ttl: ", dnsOptions.TTL) + } + if dnsOptions.PropagationDelay < 0 { + return nil, E.New("invalid ACME DNS01 propagation_delay: ", dnsOptions.PropagationDelay) + } + if dnsOptions.PropagationTimeout < -1 { + return nil, E.New("invalid ACME DNS01 propagation_timeout: ", dnsOptions.PropagationTimeout) + } + solver := &certmagic.DNS01Solver{ + DNSManager: certmagic.DNSManager{ + TTL: time.Duration(dnsOptions.TTL), + PropagationDelay: time.Duration(dnsOptions.PropagationDelay), + PropagationTimeout: time.Duration(dnsOptions.PropagationTimeout), + Resolvers: dnsOptions.Resolvers, + OverrideDomain: dnsOptions.OverrideDomain, + Logger: logger.Named("dns_manager"), + }, + } + switch dnsOptions.Provider { + case C.DNSProviderAliDNS: + solver.DNSProvider = &alidns.Provider{ + CredentialInfo: alidns.CredentialInfo{ + AccessKeyID: dnsOptions.AliDNSOptions.AccessKeyID, + AccessKeySecret: dnsOptions.AliDNSOptions.AccessKeySecret, + RegionID: dnsOptions.AliDNSOptions.RegionID, + SecurityToken: dnsOptions.AliDNSOptions.SecurityToken, + }, + } + case C.DNSProviderCloudflare: + solver.DNSProvider = &cloudflare.Provider{ + APIToken: dnsOptions.CloudflareOptions.APIToken, + ZoneToken: dnsOptions.CloudflareOptions.ZoneToken, + HTTPClient: httpClient, + } + case C.DNSProviderACMEDNS: + solver.DNSProvider = &acmeDNSProvider{ + username: dnsOptions.ACMEDNSOptions.Username, + password: dnsOptions.ACMEDNSOptions.Password, + subdomain: dnsOptions.ACMEDNSOptions.Subdomain, + serverURL: dnsOptions.ACMEDNSOptions.ServerURL, + httpClient: httpClient, + } + default: + return nil, E.New("unsupported ACME DNS01 provider type: ", dnsOptions.Provider) + } + return solver, nil +} + +func createZeroSSLExternalAccountBinding(ctx context.Context, acmeIssuer *certmagic.ACMEIssuer, account acme.Account, httpClient *http.Client) (*acme.EAB, acme.Account, error) { + email := strings.TrimSpace(acmeIssuer.Email) + if email == "" { + return nil, acme.Account{}, E.New("email is required to use the ZeroSSL ACME endpoint without external_account") + } + if len(account.Contact) == 0 { + account.Contact = []string{"mailto:" + email} + } + if acmeIssuer.CertObtainTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, acmeIssuer.CertObtainTimeout) + defer cancel() + } + + form := url.Values{"email": []string{email}} + request, err := http.NewRequestWithContext(ctx, http.MethodPost, zerossl.BaseURL+"/acme/eab-credentials-email", strings.NewReader(form.Encode())) + if err != nil { + return nil, account, E.Cause(err, "create ZeroSSL EAB request") + } + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + request.Header.Set("User-Agent", certmagic.UserAgent) + + response, err := httpClient.Do(request) + if err != nil { + return nil, account, E.Cause(err, "request ZeroSSL EAB") + } + defer response.Body.Close() + + var result struct { + Success bool `json:"success"` + Error struct { + Code int `json:"code"` + Type string `json:"type"` + } `json:"error"` + EABKID string `json:"eab_kid"` + EABHMACKey string `json:"eab_hmac_key"` + } + err = json.NewDecoder(response.Body).Decode(&result) + if err != nil { + return nil, account, E.Cause(err, "decode ZeroSSL EAB response") + } + if response.StatusCode != http.StatusOK { + return nil, account, E.New("failed getting ZeroSSL EAB credentials: HTTP ", response.StatusCode) + } + if result.Error.Code != 0 { + return nil, account, E.New("failed getting ZeroSSL EAB credentials: ", result.Error.Type, " (code ", result.Error.Code, ")") + } + + acmeIssuer.Logger.Info("generated ZeroSSL EAB credentials", zap.String("key_id", result.EABKID)) + + return &acme.EAB{ + KeyID: result.EABKID, + MACKey: result.EABHMACKey, + }, account, nil +} + +func newACMEHTTPClient(ctx context.Context, detour string) (*http.Client, error) { + outboundDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "create ACME provider dialer") + } + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return outboundDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSClientConfig: &tls.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: ntp.TimeFuncFromContext(ctx), + }, + // from certmagic defaults (acmeissuer.go) + TLSHandshakeTimeout: 30 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + ExpectContinueTimeout: 2 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: certmagic.HTTPTimeout, + }, nil +} + +type acmeDNSProvider struct { + username string + password string + subdomain string + serverURL string + httpClient *http.Client +} + +type acmeDNSRecord struct { + resourceRecord libdns.RR +} + +func (r acmeDNSRecord) RR() libdns.RR { + return r.resourceRecord +} + +func (p *acmeDNSProvider) AppendRecords(ctx context.Context, _ string, records []libdns.Record) ([]libdns.Record, error) { + if p.username == "" { + return nil, E.New("ACME-DNS username cannot be empty") + } + if p.password == "" { + return nil, E.New("ACME-DNS password cannot be empty") + } + if p.subdomain == "" { + return nil, E.New("ACME-DNS subdomain cannot be empty") + } + if p.serverURL == "" { + return nil, E.New("ACME-DNS server_url cannot be empty") + } + appendedRecords := make([]libdns.Record, 0, len(records)) + for _, record := range records { + resourceRecord := record.RR() + if resourceRecord.Type != "TXT" { + return appendedRecords, E.New("ACME-DNS only supports adding TXT records") + } + requestBody, err := json.Marshal(map[string]string{ + "subdomain": p.subdomain, + "txt": resourceRecord.Data, + }) + if err != nil { + return appendedRecords, E.Cause(err, "marshal ACME-DNS update request") + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, p.serverURL+"/update", bytes.NewReader(requestBody)) + if err != nil { + return appendedRecords, E.Cause(err, "create ACME-DNS update request") + } + request.Header.Set("X-Api-User", p.username) + request.Header.Set("X-Api-Key", p.password) + request.Header.Set("Content-Type", "application/json") + response, err := p.httpClient.Do(request) + if err != nil { + return appendedRecords, E.Cause(err, "update ACME-DNS record") + } + _ = response.Body.Close() + if response.StatusCode != http.StatusOK { + return appendedRecords, E.New("update ACME-DNS record: HTTP ", response.StatusCode) + } + appendedRecords = append(appendedRecords, acmeDNSRecord{resourceRecord: libdns.RR{ + Type: "TXT", + Name: resourceRecord.Name, + Data: resourceRecord.Data, + }}) + } + return appendedRecords, nil +} + +func (p *acmeDNSProvider) DeleteRecords(context.Context, string, []libdns.Record) ([]libdns.Record, error) { + return nil, nil +} diff --git a/service/acme/stub.go b/service/acme/stub.go new file mode 100644 index 0000000000..43a58d6449 --- /dev/null +++ b/service/acme/stub.go @@ -0,0 +1,3 @@ +//go:build !with_acme + +package acme diff --git a/service/origin_ca/service.go b/service/origin_ca/service.go new file mode 100644 index 0000000000..85588c37d5 --- /dev/null +++ b/service/origin_ca/service.go @@ -0,0 +1,618 @@ +package originca + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "errors" + "io" + "io/fs" + "net" + "net/http" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/certificate" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/ntp" + + "github.com/caddyserver/certmagic" +) + +const ( + cloudflareOriginCAEndpoint = "https://api.cloudflare.com/client/v4/certificates" + defaultRequestedValidity = option.CloudflareOriginCARequestValidity5475 + // min of 30 days and certmagic's 1/3 lifetime ratio (maintain.go) + defaultRenewBefore = 30 * 24 * time.Hour + // from certmagic retry backoff range (async.go) + minimumRenewRetryDelay = time.Minute + maximumRenewRetryDelay = time.Hour + storageLockPrefix = "cloudflare-origin-ca" +) + +func RegisterCertificateProvider(registry *certificate.Registry) { + certificate.Register[option.CloudflareOriginCACertificateProviderOptions](registry, C.TypeCloudflareOriginCA, NewCertificateProvider) +} + +var _ adapter.CertificateProviderService = (*Service)(nil) + +type Service struct { + certificate.Adapter + logger log.ContextLogger + ctx context.Context + cancel context.CancelFunc + done chan struct{} + timeFunc func() time.Time + httpClient *http.Client + storage certmagic.Storage + storageIssuerKey string + storageNamesKey string + storageLockKey string + apiToken string + originCAKey string + domain []string + requestType option.CloudflareOriginCARequestType + requestedValidity option.CloudflareOriginCARequestValidity + + access sync.RWMutex + currentCertificate *tls.Certificate + currentLeaf *x509.Certificate +} + +func NewCertificateProvider(ctx context.Context, logger log.ContextLogger, tag string, options option.CloudflareOriginCACertificateProviderOptions) (adapter.CertificateProviderService, error) { + domain, err := normalizeHostnames(options.Domain) + if err != nil { + return nil, err + } + if len(domain) == 0 { + return nil, E.New("missing domain") + } + apiToken := strings.TrimSpace(options.APIToken) + originCAKey := strings.TrimSpace(options.OriginCAKey) + switch { + case apiToken == "" && originCAKey == "": + return nil, E.New("api_token or origin_ca_key is required") + case apiToken != "" && originCAKey != "": + return nil, E.New("api_token and origin_ca_key are mutually exclusive") + } + requestType := options.RequestType + if requestType == "" { + requestType = option.CloudflareOriginCARequestTypeOriginRSA + } + requestedValidity := options.RequestedValidity + if requestedValidity == 0 { + requestedValidity = defaultRequestedValidity + } + ctx, cancel := context.WithCancel(ctx) + serviceDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: option.DialerOptions{ + Detour: options.Detour, + }, + RemoteIsDomain: true, + }) + if err != nil { + cancel() + return nil, E.Cause(err, "create Cloudflare Origin CA dialer") + } + var storage certmagic.Storage + if options.DataDirectory != "" { + storage = &certmagic.FileStorage{Path: options.DataDirectory} + } else { + storage = certmagic.Default.Storage + } + timeFunc := ntp.TimeFuncFromContext(ctx) + if timeFunc == nil { + timeFunc = time.Now + } + storageIssuerKey := C.TypeCloudflareOriginCA + "-" + string(requestType) + storageNamesKey := (&certmagic.CertificateResource{SANs: slices.Clone(domain)}).NamesKey() + storageLockKey := strings.Join([]string{ + storageLockPrefix, + certmagic.StorageKeys.Safe(storageIssuerKey), + certmagic.StorageKeys.Safe(storageNamesKey), + }, "/") + return &Service{ + Adapter: certificate.NewAdapter(C.TypeCloudflareOriginCA, tag), + logger: logger, + ctx: ctx, + cancel: cancel, + timeFunc: timeFunc, + httpClient: &http.Client{Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSClientConfig: &tls.Config{ + RootCAs: adapter.RootPoolFromContext(ctx), + Time: timeFunc, + }, + ForceAttemptHTTP2: true, + }}, + storage: storage, + storageIssuerKey: storageIssuerKey, + storageNamesKey: storageNamesKey, + storageLockKey: storageLockKey, + apiToken: apiToken, + originCAKey: originCAKey, + domain: domain, + requestType: requestType, + requestedValidity: requestedValidity, + }, nil +} + +func (s *Service) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + cachedCertificate, cachedLeaf, err := s.loadCachedCertificate() + if err != nil { + s.logger.Warn(E.Cause(err, "load cached Cloudflare Origin CA certificate")) + } else if cachedCertificate != nil { + s.setCurrentCertificate(cachedCertificate, cachedLeaf) + } + if cachedCertificate == nil { + err = s.issueAndStoreCertificate() + if err != nil { + return err + } + } else if s.shouldRenew(cachedLeaf, s.timeFunc()) { + err = s.issueAndStoreCertificate() + if err != nil { + s.logger.Warn(E.Cause(err, "renew cached Cloudflare Origin CA certificate")) + } + } + s.done = make(chan struct{}) + go s.refreshLoop() + return nil +} + +func (s *Service) Close() error { + s.cancel() + if done := s.done; done != nil { + <-done + } + if transport, loaded := s.httpClient.Transport.(*http.Transport); loaded { + transport.CloseIdleConnections() + } + return nil +} + +func (s *Service) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + s.access.RLock() + certificate := s.currentCertificate + s.access.RUnlock() + if certificate == nil { + return nil, E.New("Cloudflare Origin CA certificate is unavailable") + } + return certificate, nil +} + +func (s *Service) refreshLoop() { + defer close(s.done) + var retryDelay time.Duration + for { + waitDuration := retryDelay + if waitDuration == 0 { + s.access.RLock() + leaf := s.currentLeaf + s.access.RUnlock() + if leaf == nil { + waitDuration = minimumRenewRetryDelay + } else { + refreshAt := leaf.NotAfter.Add(-s.effectiveRenewBefore(leaf)) + waitDuration = refreshAt.Sub(s.timeFunc()) + if waitDuration < minimumRenewRetryDelay { + waitDuration = minimumRenewRetryDelay + } + } + } + timer := time.NewTimer(waitDuration) + select { + case <-s.ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return + case <-timer.C: + } + err := s.issueAndStoreCertificate() + if err != nil { + s.logger.Error(E.Cause(err, "renew Cloudflare Origin CA certificate")) + s.access.RLock() + leaf := s.currentLeaf + s.access.RUnlock() + if leaf == nil { + retryDelay = minimumRenewRetryDelay + } else { + remaining := leaf.NotAfter.Sub(s.timeFunc()) + switch { + case remaining <= minimumRenewRetryDelay: + retryDelay = minimumRenewRetryDelay + case remaining < maximumRenewRetryDelay: + retryDelay = max(remaining/2, minimumRenewRetryDelay) + default: + retryDelay = maximumRenewRetryDelay + } + } + continue + } + retryDelay = 0 + } +} + +func (s *Service) shouldRenew(leaf *x509.Certificate, now time.Time) bool { + return !now.Before(leaf.NotAfter.Add(-s.effectiveRenewBefore(leaf))) +} + +func (s *Service) effectiveRenewBefore(leaf *x509.Certificate) time.Duration { + lifetime := leaf.NotAfter.Sub(leaf.NotBefore) + if lifetime <= 0 { + return 0 + } + return min(lifetime/3, defaultRenewBefore) +} + +func (s *Service) issueAndStoreCertificate() error { + err := s.storage.Lock(s.ctx, s.storageLockKey) + if err != nil { + return E.Cause(err, "lock Cloudflare Origin CA certificate storage") + } + defer func() { + err = s.storage.Unlock(context.WithoutCancel(s.ctx), s.storageLockKey) + if err != nil { + s.logger.Warn(E.Cause(err, "unlock Cloudflare Origin CA certificate storage")) + } + }() + cachedCertificate, cachedLeaf, err := s.loadCachedCertificate() + if err != nil { + s.logger.Warn(E.Cause(err, "load cached Cloudflare Origin CA certificate")) + } else if cachedCertificate != nil && !s.shouldRenew(cachedLeaf, s.timeFunc()) { + s.setCurrentCertificate(cachedCertificate, cachedLeaf) + return nil + } + certificatePEM, privateKeyPEM, tlsCertificate, leaf, err := s.requestCertificate(s.ctx) + if err != nil { + return err + } + issuerData, err := json.Marshal(originCAIssuerData{ + RequestType: s.requestType, + RequestedValidity: s.requestedValidity, + }) + if err != nil { + return E.Cause(err, "encode Cloudflare Origin CA certificate metadata") + } + err = storeCertificateResource(s.ctx, s.storage, s.storageIssuerKey, certmagic.CertificateResource{ + SANs: slices.Clone(s.domain), + CertificatePEM: certificatePEM, + PrivateKeyPEM: privateKeyPEM, + IssuerData: issuerData, + }) + if err != nil { + return E.Cause(err, "store Cloudflare Origin CA certificate") + } + s.setCurrentCertificate(tlsCertificate, leaf) + s.logger.Info("updated Cloudflare Origin CA certificate, expires at ", leaf.NotAfter.Format(time.RFC3339)) + return nil +} + +func (s *Service) requestCertificate(ctx context.Context) ([]byte, []byte, *tls.Certificate, *x509.Certificate, error) { + var privateKey crypto.Signer + switch s.requestType { + case option.CloudflareOriginCARequestTypeOriginRSA: + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, nil, nil, err + } + privateKey = rsaKey + case option.CloudflareOriginCARequestTypeOriginECC: + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, nil, nil, err + } + privateKey = ecKey + default: + return nil, nil, nil, nil, E.New("unsupported Cloudflare Origin CA request type: ", s.requestType) + } + privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "encode private key") + } + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: privateKeyDER, + }) + certificateRequestDER, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: s.domain[0]}, + DNSNames: s.domain, + }, privateKey) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "create certificate request") + } + certificateRequestPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: certificateRequestDER, + }) + requestBody, err := json.Marshal(originCARequest{ + CSR: string(certificateRequestPEM), + Hostnames: s.domain, + RequestType: string(s.requestType), + RequestedValidity: uint16(s.requestedValidity), + }) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "marshal request") + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, cloudflareOriginCAEndpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "create request") + } + request.Header.Set("Accept", "application/json") + request.Header.Set("Content-Type", "application/json") + request.Header.Set("User-Agent", "sing-box/"+C.Version) + if s.apiToken != "" { + request.Header.Set("Authorization", "Bearer "+s.apiToken) + } else { + request.Header.Set("X-Auth-User-Service-Key", s.originCAKey) + } + response, err := s.httpClient.Do(request) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "request certificate from Cloudflare") + } + defer response.Body.Close() + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "read Cloudflare response") + } + var responseEnvelope originCAResponse + err = json.Unmarshal(responseBody, &responseEnvelope) + if err != nil && response.StatusCode >= http.StatusOK && response.StatusCode < http.StatusMultipleChoices { + return nil, nil, nil, nil, E.Cause(err, "decode Cloudflare response") + } + if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices { + return nil, nil, nil, nil, buildOriginCAError(response.StatusCode, responseEnvelope.Errors, responseBody) + } + if !responseEnvelope.Success { + return nil, nil, nil, nil, buildOriginCAError(response.StatusCode, responseEnvelope.Errors, responseBody) + } + if responseEnvelope.Result.Certificate == "" { + return nil, nil, nil, nil, E.New("Cloudflare Origin CA response is missing certificate data") + } + certificatePEM := []byte(responseEnvelope.Result.Certificate) + tlsCertificate, leaf, err := parseKeyPair(certificatePEM, privateKeyPEM) + if err != nil { + return nil, nil, nil, nil, E.Cause(err, "parse issued certificate") + } + if !s.matchesCertificate(leaf) { + return nil, nil, nil, nil, E.New("issued Cloudflare Origin CA certificate does not match requested hostnames or key type") + } + return certificatePEM, privateKeyPEM, tlsCertificate, leaf, nil +} + +func (s *Service) loadCachedCertificate() (*tls.Certificate, *x509.Certificate, error) { + certificateResource, err := loadCertificateResource(s.ctx, s.storage, s.storageIssuerKey, s.storageNamesKey) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, nil, nil + } + return nil, nil, err + } + tlsCertificate, leaf, err := parseKeyPair(certificateResource.CertificatePEM, certificateResource.PrivateKeyPEM) + if err != nil { + return nil, nil, E.Cause(err, "parse cached key pair") + } + if s.timeFunc().After(leaf.NotAfter) { + return nil, nil, nil + } + if !s.matchesCertificate(leaf) { + return nil, nil, nil + } + return tlsCertificate, leaf, nil +} + +func (s *Service) matchesCertificate(leaf *x509.Certificate) bool { + if leaf == nil { + return false + } + leafHostnames := leaf.DNSNames + if len(leafHostnames) == 0 && leaf.Subject.CommonName != "" { + leafHostnames = []string{leaf.Subject.CommonName} + } + normalizedLeafHostnames, err := normalizeHostnames(leafHostnames) + if err != nil { + return false + } + if !slices.Equal(normalizedLeafHostnames, s.domain) { + return false + } + switch s.requestType { + case option.CloudflareOriginCARequestTypeOriginRSA: + return leaf.PublicKeyAlgorithm == x509.RSA + case option.CloudflareOriginCARequestTypeOriginECC: + return leaf.PublicKeyAlgorithm == x509.ECDSA + default: + return false + } +} + +func (s *Service) setCurrentCertificate(certificate *tls.Certificate, leaf *x509.Certificate) { + s.access.Lock() + s.currentCertificate = certificate + s.currentLeaf = leaf + s.access.Unlock() +} + +func normalizeHostnames(hostnames []string) ([]string, error) { + normalizedHostnames := make([]string, 0, len(hostnames)) + seen := make(map[string]struct{}, len(hostnames)) + for _, hostname := range hostnames { + normalizedHostname := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(hostname, "."))) + if normalizedHostname == "" { + return nil, E.New("hostname is empty") + } + if net.ParseIP(normalizedHostname) != nil { + return nil, E.New("hostname cannot be an IP address: ", normalizedHostname) + } + if strings.Contains(normalizedHostname, "*") { + if !strings.HasPrefix(normalizedHostname, "*.") || strings.Count(normalizedHostname, "*") != 1 { + return nil, E.New("invalid wildcard hostname: ", normalizedHostname) + } + suffix := strings.TrimPrefix(normalizedHostname, "*.") + if strings.Count(suffix, ".") == 0 { + return nil, E.New("wildcard hostname must cover a multi-label domain: ", normalizedHostname) + } + normalizedHostname = "*." + suffix + } + if _, loaded := seen[normalizedHostname]; loaded { + continue + } + seen[normalizedHostname] = struct{}{} + normalizedHostnames = append(normalizedHostnames, normalizedHostname) + } + slices.Sort(normalizedHostnames) + return normalizedHostnames, nil +} + +func parseKeyPair(certificatePEM []byte, privateKeyPEM []byte) (*tls.Certificate, *x509.Certificate, error) { + keyPair, err := tls.X509KeyPair(certificatePEM, privateKeyPEM) + if err != nil { + return nil, nil, err + } + if len(keyPair.Certificate) == 0 { + return nil, nil, E.New("certificate chain is empty") + } + leaf, err := x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + return nil, nil, err + } + keyPair.Leaf = leaf + return &keyPair, leaf, nil +} + +func storeCertificateResource(ctx context.Context, storage certmagic.Storage, issuerKey string, certificateResource certmagic.CertificateResource) error { + metaBytes, err := json.MarshalIndent(certificateResource, "", "\t") + if err != nil { + return err + } + namesKey := certificateResource.NamesKey() + keyValueList := []struct { + key string + value []byte + }{ + { + key: certmagic.StorageKeys.SitePrivateKey(issuerKey, namesKey), + value: certificateResource.PrivateKeyPEM, + }, + { + key: certmagic.StorageKeys.SiteCert(issuerKey, namesKey), + value: certificateResource.CertificatePEM, + }, + { + key: certmagic.StorageKeys.SiteMeta(issuerKey, namesKey), + value: metaBytes, + }, + } + for i, item := range keyValueList { + err = storage.Store(ctx, item.key, item.value) + if err != nil { + for j := i - 1; j >= 0; j-- { + storage.Delete(ctx, keyValueList[j].key) + } + return err + } + } + return nil +} + +func loadCertificateResource(ctx context.Context, storage certmagic.Storage, issuerKey string, namesKey string) (certmagic.CertificateResource, error) { + privateKeyPEM, err := storage.Load(ctx, certmagic.StorageKeys.SitePrivateKey(issuerKey, namesKey)) + if err != nil { + return certmagic.CertificateResource{}, err + } + certificatePEM, err := storage.Load(ctx, certmagic.StorageKeys.SiteCert(issuerKey, namesKey)) + if err != nil { + return certmagic.CertificateResource{}, err + } + metaBytes, err := storage.Load(ctx, certmagic.StorageKeys.SiteMeta(issuerKey, namesKey)) + if err != nil { + return certmagic.CertificateResource{}, err + } + var certificateResource certmagic.CertificateResource + err = json.Unmarshal(metaBytes, &certificateResource) + if err != nil { + return certmagic.CertificateResource{}, E.Cause(err, "decode Cloudflare Origin CA certificate metadata") + } + certificateResource.PrivateKeyPEM = privateKeyPEM + certificateResource.CertificatePEM = certificatePEM + return certificateResource, nil +} + +func buildOriginCAError(statusCode int, responseErrors []originCAResponseError, responseBody []byte) error { + if len(responseErrors) > 0 { + messageList := make([]string, 0, len(responseErrors)) + for _, responseError := range responseErrors { + if responseError.Message == "" { + continue + } + if responseError.Code != 0 { + messageList = append(messageList, responseError.Message+" (code "+strconv.Itoa(responseError.Code)+")") + } else { + messageList = append(messageList, responseError.Message) + } + } + if len(messageList) > 0 { + return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode, " ", strings.Join(messageList, ", ")) + } + } + responseText := strings.TrimSpace(string(responseBody)) + if responseText == "" { + return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode) + } + return E.New("Cloudflare Origin CA request failed: HTTP ", statusCode, " ", responseText) +} + +type originCARequest struct { + CSR string `json:"csr"` + Hostnames []string `json:"hostnames"` + RequestType string `json:"request_type"` + RequestedValidity uint16 `json:"requested_validity"` +} + +type originCAResponse struct { + Success bool `json:"success"` + Errors []originCAResponseError `json:"errors"` + Result originCAResponseResult `json:"result"` +} + +type originCAResponseError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type originCAResponseResult struct { + Certificate string `json:"certificate"` +} + +type originCAIssuerData struct { + RequestType option.CloudflareOriginCARequestType `json:"request_type,omitempty"` + RequestedValidity option.CloudflareOriginCARequestValidity `json:"requested_validity,omitempty"` +} From ab323e0eb9df5e4ed7116c5b3a5949b3fdca6145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 30 Mar 2026 23:45:16 +0800 Subject: [PATCH 07/41] Add BBR profile and hop interval randomization for Hysteria2 --- docs/configuration/inbound/hysteria2.md | 13 ++++++++++ docs/configuration/inbound/hysteria2.zh.md | 13 ++++++++++ docs/configuration/outbound/hysteria2.md | 27 +++++++++++++++++++-- docs/configuration/outbound/hysteria2.zh.md | 25 ++++++++++++++++++- go.mod | 6 ++--- go.sum | 4 +-- option/hysteria2.go | 19 +++++++++------ protocol/hysteria2/inbound.go | 1 + protocol/hysteria2/outbound.go | 2 ++ 9 files changed, 94 insertions(+), 16 deletions(-) diff --git a/docs/configuration/inbound/hysteria2.md b/docs/configuration/inbound/hysteria2.md index 3b7332b064..8426be2459 100644 --- a/docs/configuration/inbound/hysteria2.md +++ b/docs/configuration/inbound/hysteria2.md @@ -2,6 +2,10 @@ icon: material/alert-decagram --- +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [bbr_profile](#bbr_profile) + !!! quote "Changes in sing-box 1.11.0" :material-alert: [masquerade](#masquerade) @@ -31,6 +35,7 @@ icon: material/alert-decagram "ignore_client_bandwidth": false, "tls": {}, "masquerade": "", // or {} + "bbr_profile": "", "brutal_debug": false } ``` @@ -141,6 +146,14 @@ Fixed response headers. Fixed response content. +#### bbr_profile + +!!! question "Since sing-box 1.14.0" + +BBR congestion control algorithm profile, one of `conservative` `standard` `aggressive`. + +`standard` is used by default. + #### brutal_debug Enable debug information logging for Hysteria Brutal CC. diff --git a/docs/configuration/inbound/hysteria2.zh.md b/docs/configuration/inbound/hysteria2.zh.md index 35a3c25bc7..0c5e918ed9 100644 --- a/docs/configuration/inbound/hysteria2.zh.md +++ b/docs/configuration/inbound/hysteria2.zh.md @@ -2,6 +2,10 @@ icon: material/alert-decagram --- +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [bbr_profile](#bbr_profile) + !!! quote "sing-box 1.11.0 中的更改" :material-alert: [masquerade](#masquerade) @@ -31,6 +35,7 @@ icon: material/alert-decagram "ignore_client_bandwidth": false, "tls": {}, "masquerade": "", // 或 {} + "bbr_profile": "", "brutal_debug": false } ``` @@ -138,6 +143,14 @@ HTTP3 服务器认证失败时的行为 (对象配置)。 固定响应内容。 +#### bbr_profile + +!!! question "自 sing-box 1.14.0 起" + +BBR 拥塞控制算法配置,可选 `conservative` `standard` `aggressive`。 + +默认使用 `standard`。 + #### brutal_debug 启用 Hysteria Brutal CC 的调试信息日志记录。 diff --git a/docs/configuration/outbound/hysteria2.md b/docs/configuration/outbound/hysteria2.md index dc0a496500..a71dd1e070 100644 --- a/docs/configuration/outbound/hysteria2.md +++ b/docs/configuration/outbound/hysteria2.md @@ -1,3 +1,8 @@ +!!! quote "Changes in sing-box 1.14.0" + + :material-plus: [hop_interval_max](#hop_interval_max) + :material-plus: [bbr_profile](#bbr_profile) + !!! quote "Changes in sing-box 1.11.0" :material-plus: [server_ports](#server_ports) @@ -9,13 +14,14 @@ { "type": "hysteria2", "tag": "hy2-out", - + "server": "127.0.0.1", "server_port": 1080, "server_ports": [ "2080:3000" ], "hop_interval": "", + "hop_interval_max": "", "up_mbps": 100, "down_mbps": 100, "obfs": { @@ -25,8 +31,9 @@ "password": "goofy_ahh_password", "network": "tcp", "tls": {}, + "bbr_profile": "", "brutal_debug": false, - + ... // Dial Fields } ``` @@ -75,6 +82,14 @@ Port hopping interval. `30s` is used by default. +#### hop_interval_max + +!!! question "Since sing-box 1.14.0" + +Maximum port hopping interval, used for randomization. + +If set, the actual hop interval will be randomly chosen between `hop_interval` and `hop_interval_max`. + #### up_mbps, down_mbps Max bandwidth, in Mbps. @@ -109,6 +124,14 @@ Both is enabled by default. TLS configuration, see [TLS](/configuration/shared/tls/#outbound). +#### bbr_profile + +!!! question "Since sing-box 1.14.0" + +BBR congestion control algorithm profile, one of `conservative` `standard` `aggressive`. + +`standard` is used by default. + #### brutal_debug Enable debug information logging for Hysteria Brutal CC. diff --git a/docs/configuration/outbound/hysteria2.zh.md b/docs/configuration/outbound/hysteria2.zh.md index bc77f4ec92..0fb17bbdc3 100644 --- a/docs/configuration/outbound/hysteria2.zh.md +++ b/docs/configuration/outbound/hysteria2.zh.md @@ -1,3 +1,8 @@ +!!! quote "sing-box 1.14.0 中的更改" + + :material-plus: [hop_interval_max](#hop_interval_max) + :material-plus: [bbr_profile](#bbr_profile) + !!! quote "sing-box 1.11.0 中的更改" :material-plus: [server_ports](#server_ports) @@ -16,6 +21,7 @@ "2080:3000" ], "hop_interval": "", + "hop_interval_max": "", "up_mbps": 100, "down_mbps": 100, "obfs": { @@ -25,8 +31,9 @@ "password": "goofy_ahh_password", "network": "tcp", "tls": {}, + "bbr_profile": "", "brutal_debug": false, - + ... // 拨号字段 } ``` @@ -73,6 +80,14 @@ 默认使用 `30s`。 +#### hop_interval_max + +!!! question "自 sing-box 1.14.0 起" + +最大端口跳跃间隔,用于随机化。 + +如果设置,实际跳跃间隔将在 `hop_interval` 和 `hop_interval_max` 之间随机选择。 + #### up_mbps, down_mbps 最大带宽。 @@ -107,6 +122,14 @@ QUIC 流量混淆器密码. TLS 配置, 参阅 [TLS](/zh/configuration/shared/tls/#出站)。 +#### bbr_profile + +!!! question "自 sing-box 1.14.0 起" + +BBR 拥塞控制算法配置,可选 `conservative` `standard` `aggressive`。 + +默认使用 `standard`。 + #### brutal_debug 启用 Hysteria Brutal CC 的调试信息日志记录。 diff --git a/go.mod b/go.mod index 4726fd753e..630e254bfa 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/anthropics/anthropic-sdk-go v1.26.0 github.com/anytls/sing-anytls v0.0.11 github.com/caddyserver/certmagic v0.25.2 + github.com/caddyserver/zerossl v0.1.5 github.com/coder/websocket v1.8.14 github.com/cretz/bine v0.2.0 github.com/database64128/tfo-go/v2 v2.3.2 @@ -19,6 +20,7 @@ require ( github.com/libdns/acmedns v0.5.0 github.com/libdns/alidns v1.0.6 github.com/libdns/cloudflare v0.2.2 + github.com/libdns/libdns v1.1.1 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mdlayher/netlink v1.9.0 github.com/metacubex/utls v1.8.4 @@ -37,7 +39,7 @@ require ( github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 github.com/sagernet/sing v0.8.3 github.com/sagernet/sing-mux v0.3.4 - github.com/sagernet/sing-quic v0.6.1 + github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212 github.com/sagernet/sing-shadowsocks v0.2.8 github.com/sagernet/sing-shadowsocks2 v0.2.1 github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 @@ -69,7 +71,6 @@ require ( github.com/akutz/memconn v0.1.0 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/andybalholm/brotli v1.1.0 // indirect - github.com/caddyserver/zerossl v0.1.5 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect github.com/database64128/netx-go v0.1.1 // indirect @@ -96,7 +97,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/libdns/libdns v1.1.1 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect diff --git a/go.sum b/go.sum index a297e3a9b2..a7c05508e3 100644 --- a/go.sum +++ b/go.sum @@ -240,8 +240,8 @@ github.com/sagernet/sing v0.8.3 h1:zGMy9M1deBPEew9pCYIUHKeE+/lDQ5A2CBqjBjjzqkA= github.com/sagernet/sing v0.8.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s= github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= -github.com/sagernet/sing-quic v0.6.1 h1:lx0tcm99wIA1RkyvILNzRSsMy1k7TTQYIhx71E/WBlw= -github.com/sagernet/sing-quic v0.6.1/go.mod h1:K5bWvITOm4vE10fwLfrWpw27bCoVJ+tfQ79tOWg+Ko8= +github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212 h1:7mFOUqy+DyOj7qKGd1X54UMXbnbJiiMileK/tn17xYc= +github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212/go.mod h1:K5bWvITOm4vE10fwLfrWpw27bCoVJ+tfQ79tOWg+Ko8= github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE= github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI= github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo= diff --git a/option/hysteria2.go b/option/hysteria2.go index a014513630..e31c8de345 100644 --- a/option/hysteria2.go +++ b/option/hysteria2.go @@ -19,6 +19,7 @@ type Hysteria2InboundOptions struct { IgnoreClientBandwidth bool `json:"ignore_client_bandwidth,omitempty"` InboundTLSOptionsContainer Masquerade *Hysteria2Masquerade `json:"masquerade,omitempty"` + BBRProfile string `json:"bbr_profile,omitempty"` BrutalDebug bool `json:"brutal_debug,omitempty"` } @@ -112,13 +113,15 @@ type Hysteria2MasqueradeString struct { type Hysteria2OutboundOptions struct { DialerOptions ServerOptions - ServerPorts badoption.Listable[string] `json:"server_ports,omitempty"` - HopInterval badoption.Duration `json:"hop_interval,omitempty"` - UpMbps int `json:"up_mbps,omitempty"` - DownMbps int `json:"down_mbps,omitempty"` - Obfs *Hysteria2Obfs `json:"obfs,omitempty"` - Password string `json:"password,omitempty"` - Network NetworkList `json:"network,omitempty"` + ServerPorts badoption.Listable[string] `json:"server_ports,omitempty"` + HopInterval badoption.Duration `json:"hop_interval,omitempty"` + HopIntervalMax badoption.Duration `json:"hop_interval_max,omitempty"` + UpMbps int `json:"up_mbps,omitempty"` + DownMbps int `json:"down_mbps,omitempty"` + Obfs *Hysteria2Obfs `json:"obfs,omitempty"` + Password string `json:"password,omitempty"` + Network NetworkList `json:"network,omitempty"` OutboundTLSOptionsContainer - BrutalDebug bool `json:"brutal_debug,omitempty"` + BBRProfile string `json:"bbr_profile,omitempty"` + BrutalDebug bool `json:"brutal_debug,omitempty"` } diff --git a/protocol/hysteria2/inbound.go b/protocol/hysteria2/inbound.go index bb5980701f..5fe8848d9a 100644 --- a/protocol/hysteria2/inbound.go +++ b/protocol/hysteria2/inbound.go @@ -125,6 +125,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo UDPTimeout: udpTimeout, Handler: inbound, MasqueradeHandler: masqueradeHandler, + BBRProfile: options.BBRProfile, }) if err != nil { return nil, err diff --git a/protocol/hysteria2/outbound.go b/protocol/hysteria2/outbound.go index d4382fdcdf..4a0c9f2430 100644 --- a/protocol/hysteria2/outbound.go +++ b/protocol/hysteria2/outbound.go @@ -73,12 +73,14 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL ServerAddress: options.ServerOptions.Build(), ServerPorts: options.ServerPorts, HopInterval: time.Duration(options.HopInterval), + HopIntervalMax: time.Duration(options.HopIntervalMax), SendBPS: uint64(options.UpMbps * hysteria.MbpsToBps), ReceiveBPS: uint64(options.DownMbps * hysteria.MbpsToBps), SalamanderPassword: salamanderPassword, Password: options.Password, TLSConfig: tlsConfig, UDPDisabled: !common.Contains(networkList, N.NetworkUDP), + BBRProfile: options.BBRProfile, }) if err != nil { return nil, err From ebf8a213b67a26f7ad8b4a86d1764e23a4fa6eb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 7 Mar 2026 16:40:34 +0800 Subject: [PATCH 08/41] Bump version --- docs/changelog.md | 89 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 0b152c95d7..3ef3ca02ae 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,14 +2,50 @@ icon: material/alert-decagram --- +#### 1.14.0-alpha.8 + +* Add BBR profile and hop interval randomization for Hysteria2 **1** +* Fixes and improvements + +**1**: + +See [Hysteria2 Inbound](/configuration/inbound/hysteria2/#bbr_profile) and [Hysteria2 Outbound](/configuration/outbound/hysteria2/#bbr_profile). + +#### 1.14.0-alpha.8 + +* Fixes and improvements + #### 1.13.5 * Fixes and improvements +#### 1.14.0-alpha.7 + +* Fixes and improvements + #### 1.13.4 * Fixes and improvements +#### 1.14.0-alpha.4 + +* Refactor ACME support to certificate provider system **1** +* Add Cloudflare Origin CA certificate provider **2** +* Add Tailscale certificate provider **3** +* Fixes and improvements + +**1**: + +See [Certificate Provider](/configuration/shared/certificate-provider/) and [Migration](/migration/#migrate-inline-acme-to-certificate-provider). + +**2**: + +See [Cloudflare Origin CA](/configuration/shared/certificate-provider/cloudflare-origin-ca). + +**3**: + +See [Tailscale](/configuration/shared/certificate-provider/tailscale). + #### 1.13.3 * Add OpenWrt and Alpine APK packages to release **1** @@ -34,6 +70,59 @@ from [SagerNet/go](https://github.com/SagerNet/go). See [OCM](/configuration/service/ocm). +#### 1.12.24 + +* Fixes and improvements + +#### 1.14.0-alpha.2 + +* Add OpenWrt and Alpine APK packages to release **1** +* Backport to macOS 10.13 High Sierra **2** +* OCM service: Add WebSocket support for Responses API **3** +* Fixes and improvements + +**1**: + +Alpine APK files use `linux` in the filename to distinguish from OpenWrt APKs which use the `openwrt` prefix: + +- OpenWrt: `sing-box_{version}_openwrt_{architecture}.apk` +- Alpine: `sing-box_{version}_linux_{architecture}.apk` + +**2**: + +Legacy macOS binaries (with `-legacy-macos-10.13` suffix) now support +macOS 10.13 High Sierra, built using Go 1.25 with patches +from [SagerNet/go](https://github.com/SagerNet/go). + +**3**: + +See [OCM](/configuration/service/ocm). + +#### 1.14.0-alpha.1 + +* Add `source_mac_address` and `source_hostname` rule items **1** +* Add `include_mac_address` and `exclude_mac_address` TUN options **2** +* Update NaiveProxy to 145.0.7632.159 **3** +* Fixes and improvements + +**1**: + +New rule items for matching LAN devices by MAC address and hostname via neighbor resolution. +Supported on Linux, macOS, or in graphical clients on Android and macOS. + +See [Route Rule](/configuration/route/rule/#source_mac_address), [DNS Rule](/configuration/dns/rule/#source_mac_address) and [Neighbor Resolution](/configuration/shared/neighbor/). + +**2**: + +Limit or exclude devices from TUN routing by MAC address. +Only supported on Linux with `auto_route` and `auto_redirect` enabled. + +See [TUN](/configuration/inbound/tun/#include_mac_address). + +**3**: + +This is not an official update from NaiveProxy. Instead, it's a Chromium codebase update maintained by Project S. + #### 1.13.2 * Fixes and improvements From b68f4670b031b18670c531f6892f273fd5fc71bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 18 Mar 2026 14:30:37 +0800 Subject: [PATCH 09/41] Add cloudflare tunnel inbound --- constant/proxy.go | 3 + go.mod | 4 +- go.sum | 6 + include/cloudflare_tunnel.go | 12 + include/cloudflare_tunnel_stub.go | 20 + include/registry.go | 1 + option/cloudflare_tunnel.go | 13 + protocol/cloudflare/cloudflare_ca.pem | 75 + protocol/cloudflare/connection_http2.go | 367 ++ protocol/cloudflare/connection_quic.go | 245 + protocol/cloudflare/control.go | 176 + protocol/cloudflare/credentials.go | 44 + protocol/cloudflare/credentials_test.go | 94 + protocol/cloudflare/datagram_v2.go | 346 ++ protocol/cloudflare/datagram_v3.go | 319 ++ protocol/cloudflare/dispatch.go | 373 ++ protocol/cloudflare/dispatch_test.go | 137 + protocol/cloudflare/edge_discovery.go | 122 + protocol/cloudflare/edge_discovery_test.go | 88 + protocol/cloudflare/header.go | 55 + protocol/cloudflare/helpers_test.go | 194 + protocol/cloudflare/inbound.go | 417 ++ protocol/cloudflare/ingress_test.go | 148 + protocol/cloudflare/integration_test.go | 166 + protocol/cloudflare/root_ca.go | 24 + protocol/cloudflare/stream.go | 212 + protocol/cloudflare/stream_test.go | 95 + protocol/cloudflare/tunnelrpc/go.capnp | 31 + .../tunnelrpc/quic_metadata_protocol.capnp | 28 + .../tunnelrpc/quic_metadata_protocol.capnp.go | 394 ++ protocol/cloudflare/tunnelrpc/tunnelrpc.capnp | 195 + .../cloudflare/tunnelrpc/tunnelrpc.capnp.go | 4843 +++++++++++++++++ release/DEFAULT_BUILD_TAGS | 2 +- release/DEFAULT_BUILD_TAGS_OTHERS | 2 +- release/DEFAULT_BUILD_TAGS_WINDOWS | 2 +- 35 files changed, 9249 insertions(+), 4 deletions(-) create mode 100644 include/cloudflare_tunnel.go create mode 100644 include/cloudflare_tunnel_stub.go create mode 100644 option/cloudflare_tunnel.go create mode 100644 protocol/cloudflare/cloudflare_ca.pem create mode 100644 protocol/cloudflare/connection_http2.go create mode 100644 protocol/cloudflare/connection_quic.go create mode 100644 protocol/cloudflare/control.go create mode 100644 protocol/cloudflare/credentials.go create mode 100644 protocol/cloudflare/credentials_test.go create mode 100644 protocol/cloudflare/datagram_v2.go create mode 100644 protocol/cloudflare/datagram_v3.go create mode 100644 protocol/cloudflare/dispatch.go create mode 100644 protocol/cloudflare/dispatch_test.go create mode 100644 protocol/cloudflare/edge_discovery.go create mode 100644 protocol/cloudflare/edge_discovery_test.go create mode 100644 protocol/cloudflare/header.go create mode 100644 protocol/cloudflare/helpers_test.go create mode 100644 protocol/cloudflare/inbound.go create mode 100644 protocol/cloudflare/ingress_test.go create mode 100644 protocol/cloudflare/integration_test.go create mode 100644 protocol/cloudflare/root_ca.go create mode 100644 protocol/cloudflare/stream.go create mode 100644 protocol/cloudflare/stream_test.go create mode 100644 protocol/cloudflare/tunnelrpc/go.capnp create mode 100644 protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp create mode 100644 protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp.go create mode 100644 protocol/cloudflare/tunnelrpc/tunnelrpc.capnp create mode 100644 protocol/cloudflare/tunnelrpc/tunnelrpc.capnp.go diff --git a/constant/proxy.go b/constant/proxy.go index add66c95e5..91b3bc98e9 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -25,6 +25,7 @@ const ( TypeTUIC = "tuic" TypeHysteria2 = "hysteria2" TypeTailscale = "tailscale" + TypeCloudflareTunnel = "cloudflare-tunnel" TypeDERP = "derp" TypeResolved = "resolved" TypeSSMAPI = "ssm-api" @@ -90,6 +91,8 @@ func ProxyDisplayName(proxyType string) string { return "AnyTLS" case TypeTailscale: return "Tailscale" + case TypeCloudflareTunnel: + return "Cloudflare Tunnel" case TypeSelector: return "Selector" case TypeURLTest: diff --git a/go.mod b/go.mod index 630e254bfa..9709176e4e 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/go-chi/render v1.0.3 github.com/godbus/dbus/v5 v5.2.2 github.com/gofrs/uuid/v5 v5.4.0 + github.com/google/uuid v1.6.0 github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91 github.com/jsimonetti/rtnetlink v1.4.0 github.com/keybase/go-keychain v0.0.1 @@ -63,6 +64,7 @@ require ( google.golang.org/grpc v1.79.1 google.golang.org/protobuf v1.36.11 howett.net/plist v1.0.1 + zombiezen.com/go/capnproto2 v2.18.2+incompatible ) require ( @@ -91,7 +93,6 @@ require ( github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -148,6 +149,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/tinylib/msgp v1.6.3 // indirect github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/zeebo/blake3 v0.2.4 // indirect diff --git a/go.sum b/go.sum index a7c05508e3..c5d7315e8d 100644 --- a/go.sum +++ b/go.sum @@ -142,6 +142,8 @@ github.com/openai/openai-go/v3 v3.26.0 h1:bRt6H/ozMNt/dDkN4gobnLqaEGrRGBzmbVs0xx github.com/openai/openai-go/v3 v3.26.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= @@ -294,6 +296,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s= +github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= @@ -401,3 +405,5 @@ lukechampine.com/blake3 v1.3.0 h1:sJ3XhFINmHSrYCgl958hscfIa3bw8x4DqMP3u1YvoYE= lukechampine.com/blake3 v1.3.0/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= +zombiezen.com/go/capnproto2 v2.18.2+incompatible h1:v3BD1zbruvffn7zjJUU5Pn8nZAB11bhZSQC4W+YnnKo= +zombiezen.com/go/capnproto2 v2.18.2+incompatible/go.mod h1:XO5Pr2SbXgqZwn0m0Ru54QBqpOf4K5AYBO+8LAOBQEQ= diff --git a/include/cloudflare_tunnel.go b/include/cloudflare_tunnel.go new file mode 100644 index 0000000000..80273a313a --- /dev/null +++ b/include/cloudflare_tunnel.go @@ -0,0 +1,12 @@ +//go:build with_cloudflare_tunnel + +package include + +import ( + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/protocol/cloudflare" +) + +func registerCloudflareTunnelInbound(registry *inbound.Registry) { + cloudflare.RegisterInbound(registry) +} diff --git a/include/cloudflare_tunnel_stub.go b/include/cloudflare_tunnel_stub.go new file mode 100644 index 0000000000..65c676ab0c --- /dev/null +++ b/include/cloudflare_tunnel_stub.go @@ -0,0 +1,20 @@ +//go:build !with_cloudflare_tunnel + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerCloudflareTunnelInbound(registry *inbound.Registry) { + inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { + return nil, E.New(`Cloudflare Tunnel is not included in this build, rebuild with -tags with_cloudflare_tunnel`) + }) +} diff --git a/include/registry.go b/include/registry.go index eb22cce1fe..4cecfcda0b 100644 --- a/include/registry.go +++ b/include/registry.go @@ -66,6 +66,7 @@ func InboundRegistry() *inbound.Registry { anytls.RegisterInbound(registry) registerQUICInbounds(registry) + registerCloudflareTunnelInbound(registry) registerStubForRemovedInbounds(registry) return registry diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go new file mode 100644 index 0000000000..a1a2c44425 --- /dev/null +++ b/option/cloudflare_tunnel.go @@ -0,0 +1,13 @@ +package option + +import "github.com/sagernet/sing/common/json/badoption" + +type CloudflareTunnelInboundOptions struct { + Token string `json:"token,omitempty"` + CredentialPath string `json:"credential_path,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod badoption.Duration `json:"grace_period,omitempty"` +} diff --git a/protocol/cloudflare/cloudflare_ca.pem b/protocol/cloudflare/cloudflare_ca.pem new file mode 100644 index 0000000000..c9c4819f76 --- /dev/null +++ b/protocol/cloudflare/cloudflare_ca.pem @@ -0,0 +1,75 @@ +-----BEGIN CERTIFICATE----- +MIICiTCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw +gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T +YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL +Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0 +eTAeFw0xOTA4MjMyMTA4MDBaFw0yOTA4MTUxNzAwMDBaMIGPMQswCQYDVQQGEwJV +UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ +MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP +cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB +BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA +xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC +AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5 +tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E +AwIDSQAwRgIhAKilfntP2ILGZjwajktkBtXE1pB4Y/fjAfLkIRUzrI15AiEA5UCL +XYZZ9m2c3fKwIenMMojL1eqydsgqj/wK4p5kagQ= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV +BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91 +ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH +Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx +MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD +bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg +Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw +EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm +V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb +OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB +yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ +SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I +lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD +VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw +HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD +ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ +wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8 +VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz +hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk +MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6 +0iykLhH1trywrKRMVw67F44IE8Y= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIGCjCCA/KgAwIBAgIIV5G6lVbCLmEwDQYJKoZIhvcNAQENBQAwgZAxCzAJBgNV +BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmln +aW4gUHVsbDEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZv +cm5pYTEjMCEGA1UEAxMab3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwHhcNMTkx +MDEwMTg0NTAwWhcNMjkxMTAxMTcwMDAwWjCBkDELMAkGA1UEBhMCVVMxGTAXBgNV +BAoTEENsb3VkRmxhcmUsIEluYy4xFDASBgNVBAsTC09yaWdpbiBQdWxsMRYwFAYD +VQQHEw1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMSMwIQYDVQQD +ExpvcmlnaW4tcHVsbC5jbG91ZGZsYXJlLm5ldDCCAiIwDQYJKoZIhvcNAQEBBQAD +ggIPADCCAgoCggIBAN2y2zojYfl0bKfhp0AJBFeV+jQqbCw3sHmvEPwLmqDLqynI +42tZXR5y914ZB9ZrwbL/K5O46exd/LujJnV2b3dzcx5rtiQzso0xzljqbnbQT20e +ihx/WrF4OkZKydZzsdaJsWAPuplDH5P7J82q3re88jQdgE5hqjqFZ3clCG7lxoBw +hLaazm3NJJlUfzdk97ouRvnFGAuXd5cQVx8jYOOeU60sWqmMe4QHdOvpqB91bJoY +QSKVFjUgHeTpN8tNpKJfb9LIn3pun3bC9NKNHtRKMNX3Kl/sAPq7q/AlndvA2Kw3 +Dkum2mHQUGdzVHqcOgea9BGjLK2h7SuX93zTWL02u799dr6Xkrad/WShHchfjjRn +aL35niJUDr02YJtPgxWObsrfOU63B8juLUphW/4BOjjJyAG5l9j1//aUGEi/sEe5 +lqVv0P78QrxoxR+MMXiJwQab5FB8TG/ac6mRHgF9CmkX90uaRh+OC07XjTdfSKGR +PpM9hB2ZhLol/nf8qmoLdoD5HvODZuKu2+muKeVHXgw2/A6wM7OwrinxZiyBk5Hh +CvaADH7PZpU6z/zv5NU5HSvXiKtCzFuDu4/Zfi34RfHXeCUfHAb4KfNRXJwMsxUa ++4ZpSAX2G6RnGU5meuXpU5/V+DQJp/e69XyyY6RXDoMywaEFlIlXBqjRRA2pAgMB +AAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgECMB0GA1Ud +DgQWBBRDWUsraYuA4REzalfNVzjann3F6zAfBgNVHSMEGDAWgBRDWUsraYuA4REz +alfNVzjann3F6zANBgkqhkiG9w0BAQ0FAAOCAgEAkQ+T9nqcSlAuW/90DeYmQOW1 +QhqOor5psBEGvxbNGV2hdLJY8h6QUq48BCevcMChg/L1CkznBNI40i3/6heDn3IS +zVEwXKf34pPFCACWVMZxbQjkNRTiH8iRur9EsaNQ5oXCPJkhwg2+IFyoPAAYURoX +VcI9SCDUa45clmYHJ/XYwV1icGVI8/9b2JUqklnOTa5tugwIUi5sTfipNcJXHhgz +6BKYDl0/UP0lLKbsUETXeTGDiDpxZYIgbcFrRDDkHC6BSvdWVEiH5b9mH2BON60z +0O0j8EEKTwi9jnafVtZQXP/D8yoVowdFDjXcKkOPF/1gIh9qrFR6GdoPVgB3SkLc +5ulBqZaCHm563jsvWb/kXJnlFxW+1bsO9BDD6DweBcGdNurgmH625wBXksSdD7y/ +fakk8DagjbjKShYlPEFOAqEcliwjF45eabL0t27MJV61O/jHzHL3dknXeE4BDa2j +bA+JbyJeUMtU7KMsxvx82RmhqBEJJDBCJ3scVptvhDMRrtqDBW5JShxoAOcpFQGm +iYWicn46nPDjgTU0bX1ZPpTpryXbvciVL5RkVBuyX2ntcOLDPlZWgxZCBp96x07F +AnOzKgZk4RzZPNAxCXERVxajn/FLcOhglVAKo5H0ac+AitlQ0ip55D2/mf8o72tM +fVQ6VpyjEXdiIXWUq/o= +-----END CERTIFICATE----- diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go new file mode 100644 index 0000000000..50e07d40ae --- /dev/null +++ b/protocol/cloudflare/connection_http2.go @@ -0,0 +1,367 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "crypto/tls" + "io" + "math" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + + "github.com/google/uuid" + "golang.org/x/net/http2" +) + +const ( + h2EdgeSNI = "h2.cftunnel.com" +) + +// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge. +// Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server. +type HTTP2Connection struct { + conn net.Conn + server *http2.Server + logger log.ContextLogger + edgeAddr *EdgeAddr + connIndex uint8 + credentials Credentials + connectorID uuid.UUID + features []string + gracePeriod time.Duration + inbound *Inbound + + registrationClient *RegistrationClient + registrationResult *RegistrationResult + + activeRequests sync.WaitGroup + closeOnce sync.Once +} + +// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal. +func NewHTTP2Connection( + ctx context.Context, + edgeAddr *EdgeAddr, + connIndex uint8, + credentials Credentials, + connectorID uuid.UUID, + features []string, + gracePeriod time.Duration, + inbound *Inbound, + logger log.ContextLogger, +) (*HTTP2Connection, error) { + rootCAs, err := cloudflareRootCertPool() + if err != nil { + return nil, E.Cause(err, "load Cloudflare root CAs") + } + + tlsConfig := &tls.Config{ + RootCAs: rootCAs, + ServerName: h2EdgeSNI, + CurvePreferences: []tls.CurveID{tls.CurveP256}, + } + + dialer := &net.Dialer{} + tcpConn, err := dialer.DialContext(ctx, "tcp", edgeAddr.TCP.String()) + if err != nil { + return nil, E.Cause(err, "dial edge TCP") + } + + tlsConn := tls.Client(tcpConn, tlsConfig) + err = tlsConn.HandshakeContext(ctx) + if err != nil { + tcpConn.Close() + return nil, E.Cause(err, "TLS handshake") + } + + return &HTTP2Connection{ + conn: tlsConn, + server: &http2.Server{ + MaxConcurrentStreams: math.MaxUint32, + }, + logger: logger, + edgeAddr: edgeAddr, + connIndex: connIndex, + credentials: credentials, + connectorID: connectorID, + features: features, + gracePeriod: gracePeriod, + inbound: inbound, + }, nil +} + +// Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends. +func (c *HTTP2Connection) Serve(ctx context.Context) error { + go func() { + <-ctx.Done() + c.close() + }() + + c.server.ServeConn(c.conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: c, + }) + + if c.registrationResult != nil { + return nil + } + return E.New("edge connection closed before registration") +} + +func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.activeRequests.Add(1) + defer c.activeRequests.Done() + + switch { + case r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream: + c.handleControlStream(r.Context(), r, w) + case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket: + c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket) + case r.Header.Get(h2HeaderTCPSrc) != "": + c.handleH2DataStream(r.Context(), r, w, ConnectionTypeTCP) + case r.Header.Get(h2HeaderUpgrade) == h2UpgradeConfiguration: + c.handleConfigurationUpdate(r, w) + default: + c.handleH2DataStream(r.Context(), r, w, ConnectionTypeHTTP) + } +} + +func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Request, w http.ResponseWriter) { + flusher, ok := w.(http.Flusher) + if !ok { + c.logger.Error("response writer does not support flushing") + return + } + + w.WriteHeader(http.StatusOK) + flusher.Flush() + + stream := newHTTP2Stream(r.Body, &http2FlushWriter{w: w, flusher: flusher}) + + c.registrationClient = NewRegistrationClient(ctx, stream) + + options := BuildConnectionOptions(c.connectorID, c.features, 0) + result, err := c.registrationClient.RegisterConnection( + ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options, + ) + if err != nil { + c.logger.Error("register connection: ", err) + return + } + c.registrationResult = result + + c.logger.Info("connected to ", result.Location, + " (connection ", result.ConnectionID, ")") + + <-ctx.Done() + c.registrationClient.Close() +} + +func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) { + r.Header.Del(h2HeaderUpgrade) + r.Header.Del(h2HeaderTCPSrc) + + flusher, ok := w.(http.Flusher) + if !ok { + c.logger.Error("response writer does not support flushing") + return + } + + var destination string + if connectionType == ConnectionTypeTCP { + destination = r.Host + if destination == "" && r.URL != nil { + destination = r.URL.Host + } + } else { + if r.URL.Scheme == "" { + r.URL.Scheme = "http" + } + if r.URL.Host == "" { + r.URL.Host = r.Host + } + destination = r.URL.String() + } + + request := &ConnectRequest{ + Dest: destination, + Type: connectionType, + } + request.Metadata = append(request.Metadata, Metadata{ + Key: metadataHTTPMethod, + Val: r.Method, + }) + request.Metadata = append(request.Metadata, Metadata{ + Key: metadataHTTPHost, + Val: r.Host, + }) + for name, values := range r.Header { + for _, value := range values { + request.Metadata = append(request.Metadata, Metadata{ + Key: metadataHTTPHeader + ":" + name, + Val: value, + }) + } + } + + stream := &http2DataStream{ + reader: r.Body, + writer: w, + flusher: flusher, + } + respWriter := &http2ResponseWriter{ + writer: w, + flusher: flusher, + } + + c.inbound.dispatchRequest(ctx, stream, respWriter, request) +} + +type h2ConfigurationUpdateBody struct { + Version int32 `json:"version"` + Config json.RawMessage `json:"config"` +} + +func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.ResponseWriter) { + var body h2ConfigurationUpdateBody + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + c.logger.Error("decode configuration update: ", err) + w.WriteHeader(http.StatusBadRequest) + return + } + c.inbound.UpdateIngress(body.Version, body.Config) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(body.Version), 10) + `,"err":null}`)) +} + +func (c *HTTP2Connection) close() { + c.closeOnce.Do(func() { + c.conn.Close() + c.activeRequests.Wait() + }) +} + +// Close closes the HTTP/2 connection. +func (c *HTTP2Connection) Close() error { + c.close() + return nil +} + +// http2Stream wraps an HTTP/2 request body (reader) and a flush-writer (writer) as an io.ReadWriteCloser. +// Used for the control stream. +type http2Stream struct { + reader io.ReadCloser + writer io.Writer +} + +func newHTTP2Stream(reader io.ReadCloser, writer io.Writer) *http2Stream { + return &http2Stream{reader: reader, writer: writer} +} + +func (s *http2Stream) Read(p []byte) (int, error) { return s.reader.Read(p) } +func (s *http2Stream) Write(p []byte) (int, error) { return s.writer.Write(p) } +func (s *http2Stream) Close() error { return s.reader.Close() } + +// http2FlushWriter wraps an http.ResponseWriter and flushes after every write. +type http2FlushWriter struct { + w http.ResponseWriter + flusher http.Flusher +} + +func (w *http2FlushWriter) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + if err == nil { + w.flusher.Flush() + } + return n, err +} + +// http2DataStream wraps an HTTP/2 request/response pair as io.ReadWriteCloser for data streams. +type http2DataStream struct { + reader io.ReadCloser + writer http.ResponseWriter + flusher http.Flusher +} + +func (s *http2DataStream) Read(p []byte) (int, error) { + return s.reader.Read(p) +} + +func (s *http2DataStream) Write(p []byte) (int, error) { + n, err := s.writer.Write(p) + if err == nil { + s.flusher.Flush() + } + return n, err +} + +func (s *http2DataStream) Close() error { + return s.reader.Close() +} + +// http2ResponseWriter translates ConnectResponse metadata to HTTP/2 response headers. +type http2ResponseWriter struct { + writer http.ResponseWriter + flusher http.Flusher + headersSent bool +} + +func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + if w.headersSent { + return nil + } + w.headersSent = true + + if responseError != nil { + w.writer.Header().Set(h2HeaderResponseMeta, `{"src":"cloudflared"}`) + w.writer.WriteHeader(http.StatusBadGateway) + w.flusher.Flush() + return nil + } + + statusCode := http.StatusOK + userHeaders := make(http.Header) + + for _, entry := range metadata { + if entry.Key == metadataHTTPStatus { + code, err := strconv.Atoi(entry.Val) + if err == nil { + statusCode = code + } + continue + } + if strings.HasPrefix(entry.Key, metadataHTTPHeader+":") { + headerName := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":") + lower := strings.ToLower(headerName) + + if lower == "content-length" { + w.writer.Header().Set(headerName, entry.Val) + } + + if !isControlResponseHeader(lower) || isWebsocketClientHeader(lower) { + userHeaders.Add(headerName, entry.Val) + } + } + } + + w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders)) + w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin) + + if statusCode == http.StatusSwitchingProtocols { + statusCode = http.StatusOK + } + + w.writer.WriteHeader(statusCode) + w.flusher.Flush() + return nil +} diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go new file mode 100644 index 0000000000..e7cfc073dd --- /dev/null +++ b/protocol/cloudflare/connection_quic.go @@ -0,0 +1,245 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "crypto/tls" + "io" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/google/uuid" +) + +const ( + quicEdgeSNI = "quic.cftunnel.com" + quicEdgeALPN = "argotunnel" + + quicHandshakeIdleTimeout = 5 * time.Second + quicMaxIdleTimeout = 5 * time.Second + quicKeepAlivePeriod = 1 * time.Second +) + +// QUICConnection manages a single QUIC connection to the Cloudflare edge. +type QUICConnection struct { + conn *quic.Conn + logger log.ContextLogger + edgeAddr *EdgeAddr + connIndex uint8 + credentials Credentials + connectorID uuid.UUID + features []string + gracePeriod time.Duration + registrationClient *RegistrationClient + registrationResult *RegistrationResult + + closeOnce sync.Once +} + +// NewQUICConnection dials the edge and establishes a QUIC connection. +func NewQUICConnection( + ctx context.Context, + edgeAddr *EdgeAddr, + connIndex uint8, + credentials Credentials, + connectorID uuid.UUID, + features []string, + gracePeriod time.Duration, + logger log.ContextLogger, +) (*QUICConnection, error) { + rootCAs, err := cloudflareRootCertPool() + if err != nil { + return nil, E.Cause(err, "load Cloudflare root CAs") + } + + tlsConfig := &tls.Config{ + RootCAs: rootCAs, + ServerName: quicEdgeSNI, + NextProtos: []string{quicEdgeALPN}, + CurvePreferences: []tls.CurveID{tls.CurveP256}, + } + + quicConfig := &quic.Config{ + HandshakeIdleTimeout: quicHandshakeIdleTimeout, + MaxIdleTimeout: quicMaxIdleTimeout, + KeepAlivePeriod: quicKeepAlivePeriod, + EnableDatagrams: true, + } + + conn, err := quic.DialAddr(ctx, edgeAddr.UDP.String(), tlsConfig, quicConfig) + if err != nil { + return nil, E.Cause(err, "dial QUIC edge") + } + + return &QUICConnection{ + conn: conn, + logger: logger, + edgeAddr: edgeAddr, + connIndex: connIndex, + credentials: credentials, + connectorID: connectorID, + features: features, + gracePeriod: gracePeriod, + }, nil +} + +// Serve runs the QUIC connection: registers, accepts streams, handles datagrams. +// Blocks until the context is cancelled or a fatal error occurs. +func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error { + controlStream, err := q.conn.OpenStream() + if err != nil { + return E.Cause(err, "open control stream") + } + + err = q.register(ctx, controlStream) + if err != nil { + controlStream.Close() + return err + } + + q.logger.Info("connected to ", q.registrationResult.Location, + " (connection ", q.registrationResult.ConnectionID, ")") + + errChan := make(chan error, 2) + + go func() { + errChan <- q.acceptStreams(ctx, handler) + }() + + go func() { + errChan <- q.handleDatagrams(ctx, handler) + }() + + select { + case <-ctx.Done(): + q.gracefulShutdown() + return ctx.Err() + case err = <-errChan: + q.gracefulShutdown() + return err + } +} + +func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) error { + q.registrationClient = NewRegistrationClient(ctx, newStreamReadWriteCloser(stream)) + + options := BuildConnectionOptions(q.connectorID, q.features, 0) + result, err := q.registrationClient.RegisterConnection( + ctx, q.credentials.Auth(), q.credentials.TunnelID, q.connIndex, options, + ) + if err != nil { + return E.Cause(err, "register connection") + } + q.registrationResult = result + return nil +} + +func (q *QUICConnection) acceptStreams(ctx context.Context, handler StreamHandler) error { + for { + stream, err := q.conn.AcceptStream(ctx) + if err != nil { + return E.Cause(err, "accept stream") + } + go q.handleStream(ctx, stream, handler) + } +} + +func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, handler StreamHandler) { + rwc := newStreamReadWriteCloser(stream) + defer rwc.Close() + + streamType, err := ReadStreamSignature(rwc) + if err != nil { + q.logger.Debug("failed to read stream signature: ", err) + return + } + + switch streamType { + case StreamTypeData: + request, err := ReadConnectRequest(rwc) + if err != nil { + q.logger.Debug("failed to read connect request: ", err) + return + } + handler.HandleDataStream(ctx, rwc, request, q.connIndex) + + case StreamTypeRPC: + handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q) + } +} + +func (q *QUICConnection) handleDatagrams(ctx context.Context, handler StreamHandler) error { + for { + datagram, err := q.conn.ReceiveDatagram(ctx) + if err != nil { + return E.Cause(err, "receive datagram") + } + handler.HandleDatagram(ctx, datagram, q) + } +} + +// SendDatagram sends a QUIC datagram to the edge. +func (q *QUICConnection) SendDatagram(data []byte) error { + return q.conn.SendDatagram(data) +} + +func (q *QUICConnection) gracefulShutdown() { + q.closeOnce.Do(func() { + if q.registrationClient != nil { + ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod) + defer cancel() + err := q.registrationClient.Unregister(ctx) + if err != nil { + q.logger.Debug("failed to unregister: ", err) + } + q.registrationClient.Close() + } + q.conn.CloseWithError(0, "graceful shutdown") + }) +} + +// Close closes the QUIC connection immediately. +func (q *QUICConnection) Close() error { + q.gracefulShutdown() + return nil +} + +// StreamHandler handles incoming edge streams and datagrams. +type StreamHandler interface { + HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8) + HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8) + HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) + HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) +} + +// DatagramSender can send QUIC datagrams back to the edge. +type DatagramSender interface { + SendDatagram(data []byte) error +} + +// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser. +type streamReadWriteCloser struct { + stream *quic.Stream +} + +func newStreamReadWriteCloser(stream *quic.Stream) *streamReadWriteCloser { + return &streamReadWriteCloser{stream: stream} +} + +func (s *streamReadWriteCloser) Read(p []byte) (int, error) { + return s.stream.Read(p) +} + +func (s *streamReadWriteCloser) Write(p []byte) (int, error) { + return s.stream.Write(p) +} + +func (s *streamReadWriteCloser) Close() error { + s.stream.CancelRead(0) + return s.stream.Close() +} diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go new file mode 100644 index 0000000000..9f583f29e8 --- /dev/null +++ b/protocol/cloudflare/control.go @@ -0,0 +1,176 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "runtime" + "time" + + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/google/uuid" + "zombiezen.com/go/capnproto2/pogs" + "zombiezen.com/go/capnproto2/rpc" +) + +const ( + registrationTimeout = 10 * time.Second + clientVersion = "sing-box" +) + +// RegistrationClient handles the Cap'n Proto RPC for tunnel registration. +type RegistrationClient struct { + client tunnelrpc.TunnelServer + rpcConn *rpc.Conn + transport rpc.Transport +} + +// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. +// The stream should be the first QUIC stream (control stream). +func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { + transport := rpc.StreamTransport(stream) + conn := rpc.NewConn(transport) + return &RegistrationClient{ + client: tunnelrpc.TunnelServer{Client: conn.Bootstrap(ctx)}, + rpcConn: conn, + transport: transport, + } +} + +// RegisterConnection registers this tunnel connection with the edge. +func (c *RegistrationClient) RegisterConnection( + ctx context.Context, + auth TunnelAuth, + tunnelID uuid.UUID, + connIndex uint8, + options *RegistrationConnectionOptions, +) (*RegistrationResult, error) { + ctx, cancel := context.WithTimeout(ctx, registrationTimeout) + defer cancel() + + promise := c.client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error { + // Marshal TunnelAuth + tunnelAuth, err := p.NewAuth() + if err != nil { + return err + } + authPogs := &RegistrationTunnelAuth{ + AccountTag: auth.AccountTag, + TunnelSecret: auth.TunnelSecret, + } + err = pogs.Insert(tunnelrpc.TunnelAuth_TypeID, tunnelAuth.Struct, authPogs) + if err != nil { + return err + } + + // Set tunnel ID + err = p.SetTunnelId(tunnelID[:]) + if err != nil { + return err + } + + // Set connection index + p.SetConnIndex(connIndex) + + // Marshal ConnectionOptions + connOptions, err := p.NewOptions() + if err != nil { + return err + } + return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, connOptions.Struct, options) + }) + + response, err := promise.Result().Struct() + if err != nil { + return nil, E.Cause(err, "registration RPC") + } + + result := response.Result() + switch result.Which() { + case tunnelrpc.ConnectionResponse_result_Which_error: + resultError, err := result.Error() + if err != nil { + return nil, E.Cause(err, "read registration error") + } + cause, _ := resultError.Cause() + registrationError := E.New(cause) + if resultError.ShouldRetry() { + return nil, &RetryableError{ + Err: registrationError, + Delay: time.Duration(resultError.RetryAfter()), + } + } + return nil, registrationError + + case tunnelrpc.ConnectionResponse_result_Which_connectionDetails: + connDetails, err := result.ConnectionDetails() + if err != nil { + return nil, E.Cause(err, "read connection details") + } + uuidBytes, err := connDetails.Uuid() + if err != nil { + return nil, E.Cause(err, "read connection UUID") + } + connectionID, err := uuid.FromBytes(uuidBytes) + if err != nil { + return nil, E.Cause(err, "parse connection UUID") + } + location, _ := connDetails.LocationName() + return &RegistrationResult{ + ConnectionID: connectionID, + Location: location, + TunnelIsRemotelyManaged: connDetails.TunnelIsRemotelyManaged(), + }, nil + + default: + return nil, E.New("unexpected registration response type") + } +} + +// Unregister sends the UnregisterConnection RPC. +func (c *RegistrationClient) Unregister(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, registrationTimeout) + defer cancel() + promise := c.client.UnregisterConnection(ctx, nil) + _, err := promise.Struct() + return err +} + +// Close closes the RPC connection and transport. +func (c *RegistrationClient) Close() error { + return E.Errors( + c.rpcConn.Close(), + c.transport.Close(), + ) +} + +// BuildConnectionOptions creates the ConnectionOptions to send during registration. +func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8) *RegistrationConnectionOptions { + return &RegistrationConnectionOptions{ + Client: RegistrationClientInfo{ + ClientID: connectorID[:], + Features: features, + Version: clientVersion, + Arch: runtime.GOARCH, + }, + NumPreviousAttempts: numPreviousAttempts, + } +} + +// DefaultFeatures returns the feature strings to advertise. +func DefaultFeatures(datagramVersion string) []string { + features := []string{ + "serialized_headers", + "support_datagram_v2", + "support_quic_eof", + "allow_remote_config", + "management_logs", + } + if datagramVersion == "v3" { + features = append(features, "support_datagram_v3_2") + } + return features +} diff --git a/protocol/cloudflare/credentials.go b/protocol/cloudflare/credentials.go new file mode 100644 index 0000000000..443da061c0 --- /dev/null +++ b/protocol/cloudflare/credentials.go @@ -0,0 +1,44 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import "github.com/google/uuid" + +// Credentials contains all info needed to run a tunnel. +type Credentials struct { + AccountTag string `json:"AccountTag"` + TunnelSecret []byte `json:"TunnelSecret"` + TunnelID uuid.UUID `json:"TunnelID"` + Endpoint string `json:"Endpoint,omitempty"` +} + +// TunnelToken is the compact token format used in the --token flag. +// Field names match cloudflared's JSON encoding. +type TunnelToken struct { + AccountTag string `json:"a"` + TunnelSecret []byte `json:"s"` + TunnelID uuid.UUID `json:"t"` + Endpoint string `json:"e,omitempty"` +} + +func (t TunnelToken) ToCredentials() Credentials { + return Credentials{ + AccountTag: t.AccountTag, + TunnelSecret: t.TunnelSecret, + TunnelID: t.TunnelID, + Endpoint: t.Endpoint, + } +} + +// TunnelAuth is the authentication data sent during tunnel registration. +type TunnelAuth struct { + AccountTag string + TunnelSecret []byte +} + +func (c *Credentials) Auth() TunnelAuth { + return TunnelAuth{ + AccountTag: c.AccountTag, + TunnelSecret: c.TunnelSecret, + } +} diff --git a/protocol/cloudflare/credentials_test.go b/protocol/cloudflare/credentials_test.go new file mode 100644 index 0000000000..31759aa34c --- /dev/null +++ b/protocol/cloudflare/credentials_test.go @@ -0,0 +1,94 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "encoding/base64" + "os" + "path/filepath" + "testing" + + "github.com/google/uuid" +) + +func TestParseToken(t *testing.T) { + tunnelID := uuid.New() + secret := []byte("test-secret-32-bytes-long-xxxxx") + tokenJSON := `{"a":"account123","t":"` + tunnelID.String() + `","s":"` + base64.StdEncoding.EncodeToString(secret) + `"}` + token := base64.StdEncoding.EncodeToString([]byte(tokenJSON)) + + credentials, err := parseToken(token) + if err != nil { + t.Fatal("parseToken: ", err) + } + if credentials.AccountTag != "account123" { + t.Error("expected AccountTag account123, got ", credentials.AccountTag) + } + if credentials.TunnelID != tunnelID { + t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID) + } +} + +func TestParseTokenInvalidBase64(t *testing.T) { + _, err := parseToken("not-valid-base64!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestParseTokenInvalidJSON(t *testing.T) { + token := base64.StdEncoding.EncodeToString([]byte("{bad json")) + _, err := parseToken(token) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestParseCredentialFile(t *testing.T) { + tunnelID := uuid.New() + content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"` + tunnelID.String() + `"}` + path := filepath.Join(t.TempDir(), "creds.json") + err := os.WriteFile(path, []byte(content), 0o644) + if err != nil { + t.Fatal(err) + } + + credentials, err := parseCredentialFile(path) + if err != nil { + t.Fatal("parseCredentialFile: ", err) + } + if credentials.AccountTag != "acct" { + t.Error("expected AccountTag acct, got ", credentials.AccountTag) + } + if credentials.TunnelID != tunnelID { + t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID) + } +} + +func TestParseCredentialFileMissingTunnelID(t *testing.T) { + content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"00000000-0000-0000-0000-000000000000"}` + path := filepath.Join(t.TempDir(), "creds.json") + err := os.WriteFile(path, []byte(content), 0o644) + if err != nil { + t.Fatal(err) + } + + _, err = parseCredentialFile(path) + if err == nil { + t.Fatal("expected error for missing tunnel ID") + } +} + +func TestParseCredentialsBothSpecified(t *testing.T) { + _, err := parseCredentials("sometoken", "/some/path") + if err == nil { + t.Fatal("expected error when both specified") + } +} + +func TestParseCredentialsNoneSpecified(t *testing.T) { + _, err := parseCredentials("", "") + if err == nil { + t.Fatal("expected error when none specified") + } +} diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go new file mode 100644 index 0000000000..8159b04cca --- /dev/null +++ b/protocol/cloudflare/datagram_v2.go @@ -0,0 +1,346 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/netip" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/google/uuid" + "zombiezen.com/go/capnproto2/rpc" +) + +// V2 wire format: [payload | 16B sessionID | 1B type] (suffix-based) + +// DatagramV2Type identifies the type of a V2 datagram. +type DatagramV2Type byte + +const ( + DatagramV2TypeUDP DatagramV2Type = 0 + DatagramV2TypeIP DatagramV2Type = 1 + DatagramV2TypeIPWithTrace DatagramV2Type = 2 + DatagramV2TypeTracingSpan DatagramV2Type = 3 + + sessionIDLength = 16 + typeIDLength = 1 +) + +// DatagramV2Muxer handles V2 datagram demuxing and session management. +type DatagramV2Muxer struct { + inbound *Inbound + logger log.ContextLogger + sender DatagramSender + + sessionAccess sync.RWMutex + sessions map[uuid.UUID]*udpSession +} + +// NewDatagramV2Muxer creates a new V2 datagram muxer. +func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV2Muxer { + return &DatagramV2Muxer{ + inbound: inbound, + logger: logger, + sender: sender, + sessions: make(map[uuid.UUID]*udpSession), + } +} + +// HandleDatagram demuxes an incoming V2 datagram. +func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) { + if len(data) < typeIDLength { + return + } + + datagramType := DatagramV2Type(data[len(data)-typeIDLength]) + payload := data[:len(data)-typeIDLength] + + switch datagramType { + case DatagramV2TypeUDP: + m.handleUDPDatagram(ctx, payload) + case DatagramV2TypeIP: + // TODO: ICMP handling + m.logger.Debug("received V2 IP datagram (ICMP not yet implemented)") + case DatagramV2TypeIPWithTrace: + m.logger.Debug("received V2 IP+trace datagram") + case DatagramV2TypeTracingSpan: + // Tracing spans, ignore + } +} + +func (m *DatagramV2Muxer) handleUDPDatagram(ctx context.Context, data []byte) { + if len(data) < sessionIDLength { + return + } + + payload := data[:len(data)-sessionIDLength] + sessionID, err := uuid.FromBytes(data[len(data)-sessionIDLength:]) + if err != nil { + m.logger.Debug("invalid session ID in V2 datagram: ", err) + return + } + + m.sessionAccess.RLock() + session, exists := m.sessions[sessionID] + m.sessionAccess.RUnlock() + + if !exists { + m.logger.Debug("unknown V2 UDP session: ", sessionID) + return + } + + session.writeToOrigin(payload) +} + +// RegisterSession registers a new UDP session from an RPC call. +func (m *DatagramV2Muxer) RegisterSession( + ctx context.Context, + sessionID uuid.UUID, + destinationIP net.IP, + destinationPort uint16, + closeAfterIdle time.Duration, +) error { + var destinationAddr netip.Addr + if ip4 := destinationIP.To4(); ip4 != nil { + destinationAddr = netip.AddrFrom4([4]byte(ip4)) + } else { + destinationAddr = netip.AddrFrom16([16]byte(destinationIP.To16())) + } + destination := netip.AddrPortFrom(destinationAddr, destinationPort) + + if closeAfterIdle == 0 { + closeAfterIdle = 210 * time.Second + } + + m.sessionAccess.Lock() + if _, exists := m.sessions[sessionID]; exists { + m.sessionAccess.Unlock() + return nil + } + + session := newUDPSession(sessionID, destination, closeAfterIdle, m) + m.sessions[sessionID] = session + m.sessionAccess.Unlock() + + m.logger.Info("registered V2 UDP session ", sessionID, " to ", destination) + + go m.serveSession(ctx, session) + return nil +} + +// UnregisterSession removes a UDP session. +func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID) { + m.sessionAccess.Lock() + session, exists := m.sessions[sessionID] + if exists { + delete(m.sessions, sessionID) + } + m.sessionAccess.Unlock() + + if exists { + session.close() + m.logger.Info("unregistered V2 UDP session ", sessionID) + } +} + +func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession) { + defer m.UnregisterSession(session.id) + + metadata := adapter.InboundContext{ + Inbound: m.inbound.Tag(), + InboundType: m.inbound.Type(), + Network: N.NetworkUDP, + } + metadata.Destination = M.SocksaddrFromNetIP(session.destination) + + done := make(chan struct{}) + m.inbound.router.RoutePacketConnectionEx( + ctx, + session, + metadata, + N.OnceClose(func(it error) { + close(done) + }), + ) + <-done +} + +// sendToEdge sends a V2 UDP datagram back to the edge. +func (m *DatagramV2Muxer) sendToEdge(sessionID uuid.UUID, payload []byte) { + data := make([]byte, len(payload)+sessionIDLength+typeIDLength) + copy(data, payload) + copy(data[len(payload):], sessionID[:]) + data[len(data)-1] = byte(DatagramV2TypeUDP) + m.sender.SendDatagram(data) +} + +// Close closes all sessions. +func (m *DatagramV2Muxer) Close() { + m.sessionAccess.Lock() + sessions := m.sessions + m.sessions = make(map[uuid.UUID]*udpSession) + m.sessionAccess.Unlock() + + for _, session := range sessions { + session.close() + } +} + +// udpSession represents a V2 UDP session. +type udpSession struct { + id uuid.UUID + destination netip.AddrPort + closeAfterIdle time.Duration + muxer *DatagramV2Muxer + + writeChan chan []byte + closeOnce sync.Once + closeChan chan struct{} +} + +func newUDPSession(id uuid.UUID, destination netip.AddrPort, closeAfterIdle time.Duration, muxer *DatagramV2Muxer) *udpSession { + return &udpSession{ + id: id, + destination: destination, + closeAfterIdle: closeAfterIdle, + muxer: muxer, + writeChan: make(chan []byte, 256), + closeChan: make(chan struct{}), + } +} + +func (s *udpSession) writeToOrigin(payload []byte) { + data := make([]byte, len(payload)) + copy(data, payload) + select { + case s.writeChan <- data: + default: + } +} + +func (s *udpSession) close() { + s.closeOnce.Do(func() { + close(s.closeChan) + }) +} + +// ReadPacket implements N.PacketConn - reads packets from the edge to forward to origin. +func (s *udpSession) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + select { + case data := <-s.writeChan: + _, err := buffer.Write(data) + return M.SocksaddrFromNetIP(s.destination), err + case <-s.closeChan: + return M.Socksaddr{}, io.EOF + } +} + +// WritePacket implements N.PacketConn - receives packets from origin to forward to edge. +func (s *udpSession) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + s.muxer.sendToEdge(s.id, buffer.Bytes()) + return nil +} + +func (s *udpSession) Close() error { + s.close() + return nil +} + +func (s *udpSession) LocalAddr() net.Addr { return nil } +func (s *udpSession) SetDeadline(_ time.Time) error { return nil } +func (s *udpSession) SetReadDeadline(_ time.Time) error { return nil } +func (s *udpSession) SetWriteDeadline(_ time.Time) error { return nil } + +// V2 RPC server implementation for HandleRPCStream. + +type cloudflaredServer struct { + inbound *Inbound + muxer *DatagramV2Muxer + ctx context.Context + logger log.ContextLogger +} + +func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error { + sessionIDBytes, err := call.Params.SessionId() + if err != nil { + return err + } + sessionID, err := uuid.FromBytes(sessionIDBytes) + if err != nil { + return err + } + + destinationIP, err := call.Params.DstIp() + if err != nil { + return err + } + + destinationPort := call.Params.DstPort() + closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint()) + + err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle) + + result, allocErr := call.Results.NewResult() + if allocErr != nil { + return allocErr + } + if err != nil { + result.SetErr(err.Error()) + } + return nil +} + +func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error { + sessionIDBytes, err := call.Params.SessionId() + if err != nil { + return err + } + sessionID, err := uuid.FromBytes(sessionIDBytes) + if err != nil { + return err + } + + s.muxer.UnregisterSession(sessionID) + return nil +} + +func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { + version := call.Params.Version() + configData, _ := call.Params.Config() + s.inbound.UpdateIngress(version, configData) + result, err := call.Results.NewResult() + if err != nil { + return err + } + result.SetErr("") + return nil +} + +// ServeRPCStream handles an incoming V2 RPC stream (session management + configuration). +func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, muxer *DatagramV2Muxer, logger log.ContextLogger) { + srv := &cloudflaredServer{ + inbound: inbound, + muxer: muxer, + ctx: ctx, + logger: logger, + } + client := tunnelrpc.CloudflaredServer_ServerToClient(srv) + transport := rpc.StreamTransport(stream) + rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + <-rpcConn.Done() + E.Errors( + rpcConn.Close(), + transport.Close(), + ) +} diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go new file mode 100644 index 0000000000..a47f1c2ed5 --- /dev/null +++ b/protocol/cloudflare/datagram_v3.go @@ -0,0 +1,319 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/binary" + "io" + "net" + "net/netip" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +// V3 wire format: [1B type | payload] (prefix-based) + +// DatagramV3Type identifies the type of a V3 datagram. +type DatagramV3Type byte + +const ( + DatagramV3TypeRegistration DatagramV3Type = 0 + DatagramV3TypePayload DatagramV3Type = 1 + DatagramV3TypeICMP DatagramV3Type = 2 + DatagramV3TypeRegistrationResponse DatagramV3Type = 3 + + // V3 registration header sizes + v3RegistrationFlagLen = 1 + v3RegistrationPortLen = 2 + v3RegistrationIdleLen = 2 + v3RequestIDLength = 16 + v3IPv4AddrLen = 4 + v3IPv6AddrLen = 16 + v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22 + v3PayloadHeaderLen = 1 + v3RequestIDLength // 17 + v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20 + + // V3 registration flags + v3FlagIPv6 byte = 0x01 + v3FlagTraced byte = 0x02 + v3FlagBundle byte = 0x04 + + // V3 registration response types + v3ResponseOK byte = 0x00 + v3ResponseDestinationUnreachable byte = 0x01 + v3ResponseUnableToBindSocket byte = 0x02 + v3ResponseTooManyActiveFlows byte = 0x03 + v3ResponseErrorWithMsg byte = 0xFF +) + +// RequestID is a 128-bit session identifier for V3. +type RequestID [v3RequestIDLength]byte + +// DatagramV3Muxer handles V3 datagram demuxing and session management. +type DatagramV3Muxer struct { + inbound *Inbound + logger log.ContextLogger + sender DatagramSender + + sessionAccess sync.RWMutex + sessions map[RequestID]*v3Session +} + +// NewDatagramV3Muxer creates a new V3 datagram muxer. +func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV3Muxer { + return &DatagramV3Muxer{ + inbound: inbound, + logger: logger, + sender: sender, + sessions: make(map[RequestID]*v3Session), + } +} + +// HandleDatagram demuxes an incoming V3 datagram. +func (m *DatagramV3Muxer) HandleDatagram(ctx context.Context, data []byte) { + if len(data) < 1 { + return + } + + datagramType := DatagramV3Type(data[0]) + payload := data[1:] + + switch datagramType { + case DatagramV3TypeRegistration: + m.handleRegistration(ctx, payload) + case DatagramV3TypePayload: + m.handlePayload(payload) + case DatagramV3TypeICMP: + // TODO: ICMP handling + m.logger.Debug("received V3 ICMP datagram (not yet implemented)") + case DatagramV3TypeRegistrationResponse: + // Unexpected - we never send registrations + m.logger.Debug("received unexpected V3 registration response") + } +} + +func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { + if len(data) < v3RegistrationFlagLen+v3RegistrationPortLen+v3RegistrationIdleLen+v3RequestIDLength { + m.logger.Debug("V3 registration too short") + return + } + + flags := data[0] + destinationPort := binary.BigEndian.Uint16(data[1:3]) + idleDurationSeconds := binary.BigEndian.Uint16(data[3:5]) + + var requestID RequestID + copy(requestID[:], data[5:5+v3RequestIDLength]) + + offset := 5 + v3RequestIDLength + var destination netip.AddrPort + + if flags&v3FlagIPv6 != 0 { + if len(data) < offset+v3IPv6AddrLen { + m.logger.Debug("V3 registration too short for IPv6") + return + } + var addr [16]byte + copy(addr[:], data[offset:offset+v3IPv6AddrLen]) + destination = netip.AddrPortFrom(netip.AddrFrom16(addr), destinationPort) + offset += v3IPv6AddrLen + } else { + if len(data) < offset+v3IPv4AddrLen { + m.logger.Debug("V3 registration too short for IPv4") + return + } + var addr [4]byte + copy(addr[:], data[offset:offset+v3IPv4AddrLen]) + destination = netip.AddrPortFrom(netip.AddrFrom4(addr), destinationPort) + offset += v3IPv4AddrLen + } + + closeAfterIdle := time.Duration(idleDurationSeconds) * time.Second + if closeAfterIdle == 0 { + closeAfterIdle = 210 * time.Second + } + + m.sessionAccess.Lock() + if existing, exists := m.sessions[requestID]; exists { + m.sessionAccess.Unlock() + // Session already exists - re-ack + m.sendRegistrationResponse(requestID, v3ResponseOK, "") + // Handle bundled payload + if flags&v3FlagBundle != 0 && len(data) > offset { + existing.writeToOrigin(data[offset:]) + } + return + } + + session := newV3Session(requestID, destination, closeAfterIdle, m) + m.sessions[requestID] = session + m.sessionAccess.Unlock() + + m.logger.Info("registered V3 UDP session to ", destination) + m.sendRegistrationResponse(requestID, v3ResponseOK, "") + + // Handle bundled first payload + if flags&v3FlagBundle != 0 && len(data) > offset { + session.writeToOrigin(data[offset:]) + } + + go m.serveV3Session(ctx, session) +} + +func (m *DatagramV3Muxer) handlePayload(data []byte) { + if len(data) < v3RequestIDLength { + return + } + + var requestID RequestID + copy(requestID[:], data[:v3RequestIDLength]) + payload := data[v3RequestIDLength:] + + m.sessionAccess.RLock() + session, exists := m.sessions[requestID] + m.sessionAccess.RUnlock() + + if !exists { + return + } + + session.writeToOrigin(payload) +} + +func (m *DatagramV3Muxer) sendRegistrationResponse(requestID RequestID, responseType byte, errorMessage string) { + errorBytes := []byte(errorMessage) + data := make([]byte, v3RegistrationRespLen+len(errorBytes)) + data[0] = byte(DatagramV3TypeRegistrationResponse) + data[1] = responseType + copy(data[2:2+v3RequestIDLength], requestID[:]) + binary.BigEndian.PutUint16(data[2+v3RequestIDLength:], uint16(len(errorBytes))) + copy(data[v3RegistrationRespLen:], errorBytes) + m.sender.SendDatagram(data) +} + +func (m *DatagramV3Muxer) sendPayload(requestID RequestID, payload []byte) { + data := make([]byte, v3PayloadHeaderLen+len(payload)) + data[0] = byte(DatagramV3TypePayload) + copy(data[1:1+v3RequestIDLength], requestID[:]) + copy(data[v3PayloadHeaderLen:], payload) + m.sender.SendDatagram(data) +} + +func (m *DatagramV3Muxer) unregisterSession(requestID RequestID) { + m.sessionAccess.Lock() + session, exists := m.sessions[requestID] + if exists { + delete(m.sessions, requestID) + } + m.sessionAccess.Unlock() + + if exists { + session.close() + } +} + +func (m *DatagramV3Muxer) serveV3Session(ctx context.Context, session *v3Session) { + defer m.unregisterSession(session.id) + + metadata := adapter.InboundContext{ + Inbound: m.inbound.Tag(), + InboundType: m.inbound.Type(), + Network: N.NetworkUDP, + } + metadata.Destination = M.SocksaddrFromNetIP(session.destination) + + done := make(chan struct{}) + m.inbound.router.RoutePacketConnectionEx( + ctx, + session, + metadata, + N.OnceClose(func(it error) { + close(done) + }), + ) + <-done +} + +// Close closes all V3 sessions. +func (m *DatagramV3Muxer) Close() { + m.sessionAccess.Lock() + sessions := m.sessions + m.sessions = make(map[RequestID]*v3Session) + m.sessionAccess.Unlock() + + for _, session := range sessions { + session.close() + } +} + +// v3Session represents a V3 UDP session. +type v3Session struct { + id RequestID + destination netip.AddrPort + closeAfterIdle time.Duration + muxer *DatagramV3Muxer + + writeChan chan []byte + closeOnce sync.Once + closeChan chan struct{} +} + +func newV3Session(id RequestID, destination netip.AddrPort, closeAfterIdle time.Duration, muxer *DatagramV3Muxer) *v3Session { + return &v3Session{ + id: id, + destination: destination, + closeAfterIdle: closeAfterIdle, + muxer: muxer, + writeChan: make(chan []byte, 512), + closeChan: make(chan struct{}), + } +} + +func (s *v3Session) writeToOrigin(payload []byte) { + data := make([]byte, len(payload)) + copy(data, payload) + select { + case s.writeChan <- data: + default: + } +} + +func (s *v3Session) close() { + s.closeOnce.Do(func() { + close(s.closeChan) + }) +} + +// ReadPacket implements N.PacketConn. +func (s *v3Session) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + select { + case data := <-s.writeChan: + _, err := buffer.Write(data) + return M.SocksaddrFromNetIP(s.destination), err + case <-s.closeChan: + return M.Socksaddr{}, io.EOF + } +} + +// WritePacket implements N.PacketConn. +func (s *v3Session) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + s.muxer.sendPayload(s.id, buffer.Bytes()) + return nil +} + +func (s *v3Session) Close() error { + s.close() + return nil +} + +func (s *v3Session) LocalAddr() net.Addr { return nil } +func (s *v3Session) SetDeadline(_ time.Time) error { return nil } +func (s *v3Session) SetReadDeadline(_ time.Time) error { return nil } +func (s *v3Session) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go new file mode 100644 index 0000000000..7f949a7a40 --- /dev/null +++ b/protocol/cloudflare/dispatch.go @@ -0,0 +1,373 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +const ( + metadataHTTPMethod = "HttpMethod" + metadataHTTPHost = "HttpHost" + metadataHTTPHeader = "HttpHeader" + metadataHTTPStatus = "HttpStatus" +) + +// ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2. +type ConnectResponseWriter interface { + // WriteResponse sends the connect response (ack or error) with optional metadata. + WriteResponse(responseError error, metadata []Metadata) error +} + +// quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp). +type quicResponseWriter struct { + stream io.Writer +} + +func (w *quicResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + return WriteConnectResponse(w.stream, responseError, metadata...) +} + +// HandleDataStream dispatches an incoming edge data stream (QUIC path). +func (i *Inbound) HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8) { + ctx = log.ContextWithNewID(ctx) + respWriter := &quicResponseWriter{stream: stream} + i.dispatchRequest(ctx, stream, respWriter, request) +} + +// HandleRPCStream handles an incoming edge RPC stream (session management, configuration). +func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8) { + i.logger.DebugContext(ctx, "received RPC stream on connection ", connIndex) + // V2 RPC streams are handled here - the edge calls RegisterUdpSession/UnregisterUdpSession + // We need the sender (DatagramSender) to find the muxer - but HandleRPCStream doesn't have it. + // The V2 muxer is looked up via GetOrCreateV2Muxer in HandleDatagram when first datagram arrives. + // For RPC, we need a different approach - see handleRPCStreamWithSender below. +} + +// HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup. +func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) { + muxer := i.getOrCreateV2Muxer(sender) + ServeRPCStream(ctx, stream, i, muxer, i.logger) +} + +// HandleDatagram handles an incoming QUIC datagram. +func (i *Inbound) HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) { + switch i.datagramVersion { + case "v3": + muxer := i.getOrCreateV3Muxer(sender) + muxer.HandleDatagram(ctx, datagram) + default: + muxer := i.getOrCreateV2Muxer(sender) + muxer.HandleDatagram(ctx, datagram) + } +} + +func (i *Inbound) getOrCreateV2Muxer(sender DatagramSender) *DatagramV2Muxer { + i.datagramMuxerAccess.Lock() + defer i.datagramMuxerAccess.Unlock() + muxer, exists := i.datagramV2Muxers[sender] + if !exists { + muxer = NewDatagramV2Muxer(i, sender, i.logger) + i.datagramV2Muxers[sender] = muxer + } + return muxer +} + +func (i *Inbound) getOrCreateV3Muxer(sender DatagramSender) *DatagramV3Muxer { + i.datagramMuxerAccess.Lock() + defer i.datagramMuxerAccess.Unlock() + muxer, exists := i.datagramV3Muxers[sender] + if !exists { + muxer = NewDatagramV3Muxer(i, sender, i.logger) + i.datagramV3Muxers[sender] = muxer + } + return muxer +} + +// RemoveDatagramMuxer cleans up muxers when a connection closes. +func (i *Inbound) RemoveDatagramMuxer(sender DatagramSender) { + i.datagramMuxerAccess.Lock() + if muxer, exists := i.datagramV2Muxers[sender]; exists { + muxer.Close() + delete(i.datagramV2Muxers, sender) + } + if muxer, exists := i.datagramV3Muxers[sender]; exists { + muxer.Close() + delete(i.datagramV3Muxers, sender) + } + i.datagramMuxerAccess.Unlock() +} + +func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest) { + metadata := adapter.InboundContext{ + Inbound: i.Tag(), + InboundType: i.Type(), + } + + switch request.Type { + case ConnectionTypeTCP: + metadata.Destination = M.ParseSocksaddr(request.Dest) + i.handleTCPStream(ctx, stream, respWriter, metadata) + case ConnectionTypeHTTP, ConnectionTypeWebsocket: + originURL := i.ResolveOriginURL(request.Dest) + request.Dest = originURL + metadata.Destination = parseHTTPDestination(originURL) + if request.Type == ConnectionTypeHTTP { + i.handleHTTPStream(ctx, stream, respWriter, request, metadata) + } else { + i.handleWebSocketStream(ctx, stream, respWriter, request, metadata) + } + default: + i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type) + } +} + +func parseHTTPDestination(dest string) M.Socksaddr { + parsed, err := url.Parse(dest) + if err != nil { + return M.ParseSocksaddr(dest) + } + host := parsed.Hostname() + port := parsed.Port() + if port == "" { + switch parsed.Scheme { + case "https", "wss": + port = "443" + default: + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(host, port)) +} + +func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, metadata adapter.InboundContext) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound TCP connection to ", metadata.Destination) + + err := respWriter.WriteResponse(nil, nil) + if err != nil { + i.logger.ErrorContext(ctx, "write connect response: ", err) + return + } + + done := make(chan struct{}) + i.router.RouteConnectionEx(ctx, newStreamConn(stream), metadata, N.OnceClose(func(it error) { + close(done) + })) + <-done +} + +func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) + + httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) + if err != nil { + i.logger.ErrorContext(ctx, "build HTTP request: ", err) + respWriter.WriteResponse(err, nil) + return + } + + input, output := pipe.Pipe() + var innerError error + + done := make(chan struct{}) + go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { + innerError = it + common.Close(input, output) + close(done) + })) + + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return input, nil + }, + }, + CheckRedirect: func(request *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + defer httpClient.CloseIdleConnections() + + response, err := httpClient.Do(httpRequest) + if err != nil { + <-done + i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err)) + respWriter.WriteResponse(err, nil) + return + } + + responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) + err = respWriter.WriteResponse(nil, responseMetadata) + if err != nil { + response.Body.Close() + i.logger.ErrorContext(ctx, "write HTTP response headers: ", err) + <-done + return + } + + _, err = io.Copy(stream, response.Body) + response.Body.Close() + common.Close(input, output) + if err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "copy HTTP response body: ", err) + } + <-done +} + +func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) + + httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) + if err != nil { + i.logger.ErrorContext(ctx, "build WebSocket request: ", err) + respWriter.WriteResponse(err, nil) + return + } + + input, output := pipe.Pipe() + var innerError error + + done := make(chan struct{}) + go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { + innerError = it + common.Close(input, output) + close(done) + })) + + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return input, nil + }, + }, + CheckRedirect: func(request *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + defer httpClient.CloseIdleConnections() + + response, err := httpClient.Do(httpRequest) + if err != nil { + <-done + i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err)) + respWriter.WriteResponse(err, nil) + return + } + + responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) + err = respWriter.WriteResponse(nil, responseMetadata) + if err != nil { + response.Body.Close() + i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err) + <-done + return + } + + _, err = io.Copy(stream, response.Body) + response.Body.Close() + common.Close(input, output) + if err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "copy WebSocket response body: ", err) + } + <-done +} + +func buildHTTPRequestFromMetadata(ctx context.Context, connectRequest *ConnectRequest, body io.Reader) (*http.Request, error) { + metadataMap := connectRequest.MetadataMap() + method := metadataMap[metadataHTTPMethod] + host := metadataMap[metadataHTTPHost] + + request, err := http.NewRequestWithContext(ctx, method, connectRequest.Dest, body) + if err != nil { + return nil, E.Cause(err, "create HTTP request") + } + request.Host = host + + for _, entry := range connectRequest.Metadata { + if !strings.Contains(entry.Key, metadataHTTPHeader) { + continue + } + parts := strings.SplitN(entry.Key, ":", 2) + if len(parts) != 2 { + continue + } + request.Header.Add(parts[1], entry.Val) + } + + contentLengthStr := request.Header.Get("Content-Length") + if contentLengthStr != "" { + request.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, E.Cause(err, "parse content-length") + } + } + + if connectRequest.Type != ConnectionTypeWebsocket && !isTransferEncodingChunked(request) && request.ContentLength == 0 { + request.Body = http.NoBody + } + + request.Header.Del("Cf-Cloudflared-Proxy-Connection-Upgrade") + + return request, nil +} + +func isTransferEncodingChunked(request *http.Request) bool { + for _, encoding := range request.TransferEncoding { + if strings.EqualFold(encoding, "chunked") { + return true + } + } + return false +} + +func encodeResponseHeaders(statusCode int, header http.Header) []Metadata { + metadata := make([]Metadata, 0, len(header)+1) + metadata = append(metadata, Metadata{ + Key: metadataHTTPStatus, + Val: strconv.Itoa(statusCode), + }) + for name, values := range header { + for _, value := range values { + metadata = append(metadata, Metadata{ + Key: metadataHTTPHeader + ":" + name, + Val: value, + }) + } + } + return metadata +} + +// streamConn wraps an io.ReadWriteCloser as a net.Conn. +type streamConn struct { + io.ReadWriteCloser +} + +func newStreamConn(stream io.ReadWriteCloser) *streamConn { + return &streamConn{ReadWriteCloser: stream} +} + +func (c *streamConn) LocalAddr() net.Addr { return nil } +func (c *streamConn) RemoteAddr() net.Addr { return nil } +func (c *streamConn) SetDeadline(_ time.Time) error { return nil } +func (c *streamConn) SetReadDeadline(_ time.Time) error { return nil } +func (c *streamConn) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/protocol/cloudflare/dispatch_test.go b/protocol/cloudflare/dispatch_test.go new file mode 100644 index 0000000000..e4645cbd38 --- /dev/null +++ b/protocol/cloudflare/dispatch_test.go @@ -0,0 +1,137 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "net/http" + "testing" +) + +func TestParseHTTPDestination(t *testing.T) { + tests := []struct { + name string + dest string + expected string + }{ + {"http with port", "http://127.0.0.1:8083/path", "127.0.0.1:8083"}, + {"https default port", "https://example.com", "example.com:443"}, + {"http default port", "http://example.com", "example.com:80"}, + {"wss default port", "wss://example.com/ws", "example.com:443"}, + {"explicit port", "https://example.com:9443/api", "example.com:9443"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseHTTPDestination(tt.dest) + if result.String() != tt.expected { + t.Errorf("parseHTTPDestination(%q) = %q, want %q", tt.dest, result.String(), tt.expected) + } + }) + } +} + +func TestSerializeHeaders(t *testing.T) { + header := http.Header{} + header.Set("Content-Type", "text/html") + header.Set("X-Foo", "bar") + + serialized := SerializeHeaders(header) + if serialized == "" { + t.Fatal("expected non-empty serialized headers") + } + + decoded := make(map[string]string) + for _, pair := range splitNonEmpty(serialized, ";") { + parts := splitNonEmpty(pair, ":") + if len(parts) != 2 { + t.Fatalf("malformed pair: %q", pair) + } + name, err := headerEncoding.DecodeString(parts[0]) + if err != nil { + t.Fatal("decode name: ", err) + } + value, err := headerEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatal("decode value: ", err) + } + decoded[string(name)] = string(value) + } + + if decoded["Content-Type"] != "text/html" { + t.Error("expected Content-Type=text/html, got ", decoded["Content-Type"]) + } + if decoded["X-Foo"] != "bar" { + t.Error("expected X-Foo=bar, got ", decoded["X-Foo"]) + } +} + +func splitNonEmpty(s string, sep string) []string { + var result []string + for _, part := range splitString(s, sep) { + if part != "" { + result = append(result, part) + } + } + return result +} + +func splitString(s string, sep string) []string { + if len(sep) == 0 { + return []string{s} + } + var result []string + start := 0 + for i := 0; i <= len(s)-len(sep); i++ { + if s[i:i+len(sep)] == sep { + result = append(result, s[start:i]) + start = i + len(sep) + i += len(sep) - 1 + } + } + result = append(result, s[start:]) + return result +} + +func TestIsControlResponseHeader(t *testing.T) { + tests := []struct { + name string + expected bool + }{ + {":status", true}, + {"cf-int-foo", true}, + {"cf-cloudflared-response-meta", true}, + {"cf-proxy-src", true}, + {"content-type", false}, + {"x-custom", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isControlResponseHeader(tt.name) + if result != tt.expected { + t.Errorf("isControlResponseHeader(%q) = %v, want %v", tt.name, result, tt.expected) + } + }) + } +} + +func TestIsWebsocketClientHeader(t *testing.T) { + tests := []struct { + name string + expected bool + }{ + {"sec-websocket-accept", true}, + {"connection", true}, + {"upgrade", true}, + {"content-type", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isWebsocketClientHeader(tt.name) + if result != tt.expected { + t.Errorf("isWebsocketClientHeader(%q) = %v, want %v", tt.name, result, tt.expected) + } + }) + } +} diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go new file mode 100644 index 0000000000..6ca9403d0d --- /dev/null +++ b/protocol/cloudflare/edge_discovery.go @@ -0,0 +1,122 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "crypto/tls" + "net" + "time" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + edgeSRVService = "v2-origintunneld" + edgeSRVProto = "tcp" + edgeSRVName = "argotunnel.com" + + dotServerName = "cloudflare-dns.com" + dotServerAddr = "1.1.1.1:853" + dotTimeout = 15 * time.Second +) + +// EdgeAddr represents a Cloudflare edge server address. +type EdgeAddr struct { + TCP *net.TCPAddr + UDP *net.UDPAddr + IPVersion int // 4 or 6 +} + +// DiscoverEdge performs SRV-based edge discovery and returns addresses +// partitioned into regions (typically 2). +func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) { + regions, err := lookupEdgeSRV() + if err != nil { + regions, err = lookupEdgeSRVWithDoT(ctx) + if err != nil { + return nil, E.Cause(err, "edge discovery") + } + } + if len(regions) == 0 { + return nil, E.New("edge discovery: no edge addresses found") + } + return regions, nil +} + +func lookupEdgeSRV() ([][]*EdgeAddr, error) { + _, addrs, err := net.LookupSRV(edgeSRVService, edgeSRVProto, edgeSRVName) + if err != nil { + return nil, err + } + return resolveSRVRecords(addrs) +} + +func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) { + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + var dialer net.Dialer + conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr) + if err != nil { + return nil, err + } + return tls.Client(conn, &tls.Config{ServerName: dotServerName}), nil + }, + } + lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout) + defer cancel() + _, addrs, err := resolver.LookupSRV(lookupCtx, edgeSRVService, edgeSRVProto, edgeSRVName) + if err != nil { + return nil, err + } + return resolveSRVRecords(addrs) +} + +func resolveSRVRecords(records []*net.SRV) ([][]*EdgeAddr, error) { + var regions [][]*EdgeAddr + for _, record := range records { + ips, err := net.LookupIP(record.Target) + if err != nil { + return nil, E.Cause(err, "resolve SRV target: ", record.Target) + } + if len(ips) == 0 { + continue + } + edgeAddrs := make([]*EdgeAddr, 0, len(ips)) + for _, ip := range ips { + ipVersion := 6 + if ip.To4() != nil { + ipVersion = 4 + } + edgeAddrs = append(edgeAddrs, &EdgeAddr{ + TCP: &net.TCPAddr{IP: ip, Port: int(record.Port)}, + UDP: &net.UDPAddr{IP: ip, Port: int(record.Port)}, + IPVersion: ipVersion, + }) + } + regions = append(regions, edgeAddrs) + } + return regions, nil +} + +// FilterByIPVersion filters edge addresses to only include the specified IP version. +// version 0 means no filtering (auto). +func FilterByIPVersion(regions [][]*EdgeAddr, version int) [][]*EdgeAddr { + if version == 0 { + return regions + } + var filtered [][]*EdgeAddr + for _, region := range regions { + var addrs []*EdgeAddr + for _, addr := range region { + if addr.IPVersion == version { + addrs = append(addrs, addr) + } + } + if len(addrs) > 0 { + filtered = append(filtered, addrs) + } + } + return filtered +} diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go new file mode 100644 index 0000000000..6d602cfa60 --- /dev/null +++ b/protocol/cloudflare/edge_discovery_test.go @@ -0,0 +1,88 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "net" + "testing" +) + +func TestDiscoverEdge(t *testing.T) { + regions, err := DiscoverEdge(context.Background()) + if err != nil { + t.Fatal("DiscoverEdge: ", err) + } + if len(regions) == 0 { + t.Fatal("expected at least 1 region") + } + for i, region := range regions { + if len(region) == 0 { + t.Errorf("region %d is empty", i) + continue + } + for j, addr := range region { + if addr.TCP == nil { + t.Errorf("region %d addr %d: TCP is nil", i, j) + } + if addr.UDP == nil { + t.Errorf("region %d addr %d: UDP is nil", i, j) + } + if addr.IPVersion != 4 && addr.IPVersion != 6 { + t.Errorf("region %d addr %d: invalid IPVersion %d", i, j, addr.IPVersion) + } + } + } +} + +func TestFilterByIPVersion(t *testing.T) { + v4Addr := &EdgeAddr{ + TCP: &net.TCPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844}, + UDP: &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844}, + IPVersion: 4, + } + v6Addr := &EdgeAddr{ + TCP: &net.TCPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844}, + UDP: &net.UDPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844}, + IPVersion: 6, + } + mixed := [][]*EdgeAddr{{v4Addr, v6Addr}} + + tests := []struct { + name string + version int + expected int + }{ + {"auto", 0, 2}, + {"v4 only", 4, 1}, + {"v6 only", 6, 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FilterByIPVersion(mixed, tt.version) + total := 0 + for _, region := range result { + total += len(region) + } + if total != tt.expected { + t.Errorf("expected %d addrs, got %d", tt.expected, total) + } + }) + } + + t.Run("no match", func(t *testing.T) { + v4Only := [][]*EdgeAddr{{v4Addr}} + result := FilterByIPVersion(v4Only, 6) + if len(result) != 0 { + t.Error("expected empty result for no match") + } + }) + + t.Run("empty input", func(t *testing.T) { + result := FilterByIPVersion(nil, 4) + if len(result) != 0 { + t.Error("expected empty result for nil input") + } + }) +} diff --git a/protocol/cloudflare/header.go b/protocol/cloudflare/header.go new file mode 100644 index 0000000000..05aa3765df --- /dev/null +++ b/protocol/cloudflare/header.go @@ -0,0 +1,55 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "encoding/base64" + "net/http" + "strings" +) + +const ( + h2HeaderUpgrade = "Cf-Cloudflared-Proxy-Connection-Upgrade" + h2HeaderTCPSrc = "Cf-Cloudflared-Proxy-Src" + h2HeaderResponseMeta = "Cf-Cloudflared-Response-Meta" + h2HeaderResponseUser = "Cf-Cloudflared-Response-Headers" + h2UpgradeControlStream = "control-stream" + h2UpgradeWebsocket = "websocket" + h2UpgradeConfiguration = "update-configuration" + h2ResponseMetaOrigin = `{"src":"origin"}` +) + +var headerEncoding = base64.RawStdEncoding + +// SerializeHeaders encodes HTTP/1 headers into base64 pairs: base64(name):base64(value);... +func SerializeHeaders(header http.Header) string { + var builder strings.Builder + for name, values := range header { + for _, value := range values { + if builder.Len() > 0 { + builder.WriteByte(';') + } + builder.WriteString(headerEncoding.EncodeToString([]byte(name))) + builder.WriteByte(':') + builder.WriteString(headerEncoding.EncodeToString([]byte(value))) + } + } + return builder.String() +} + +// isControlResponseHeader returns true for headers that are internal control headers. +func isControlResponseHeader(name string) bool { + lower := strings.ToLower(name) + return strings.HasPrefix(lower, ":") || + strings.HasPrefix(lower, "cf-int-") || + strings.HasPrefix(lower, "cf-cloudflared-") || + strings.HasPrefix(lower, "cf-proxy-") +} + +// isWebsocketClientHeader returns true for headers needed by the client for WebSocket upgrade. +func isWebsocketClientHeader(name string) bool { + lower := strings.ToLower(name) + return lower == "sec-websocket-accept" || + lower == "connection" || + lower == "upgrade" +} diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go new file mode 100644 index 0000000000..82676434dc --- /dev/null +++ b/protocol/cloudflare/helpers_test.go @@ -0,0 +1,194 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + N "github.com/sagernet/sing/common/network" + + "github.com/google/uuid" +) + +func requireEnvVars(t *testing.T) (token string, testURL string) { + t.Helper() + token = os.Getenv("CF_TUNNEL_TOKEN") + testURL = os.Getenv("CF_TEST_URL") + if token == "" || testURL == "" { + t.Skip("CF_TUNNEL_TOKEN and CF_TEST_URL must be set") + } + return +} + +func startOriginServer(t *testing.T) { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + }) + mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + io.Copy(w, r.Body) + }) + mux.HandleFunc("/status/", func(w http.ResponseWriter, r *http.Request) { + codeStr := strings.TrimPrefix(r.URL.Path, "/status/") + code, err := strconv.Atoi(codeStr) + if err != nil { + code = 200 + } + w.Header().Set("X-Custom", "test-value") + w.WriteHeader(code) + fmt.Fprintf(w, "status: %d", code) + }) + mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(r.Header) + }) + + server := &http.Server{ + Addr: "127.0.0.1:8083", + Handler: mux, + } + + listener, err := net.Listen("tcp", server.Addr) + if err != nil { + t.Fatal("start origin server: ", err) + } + + go server.Serve(listener) + t.Cleanup(func() { + server.Close() + }) +} + +type testRouter struct{} + +func (r *testRouter) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + destination := metadata.Destination.String() + upstream, err := net.Dial("tcp", destination) + if err != nil { + conn.Close() + return err + } + go func() { + io.Copy(upstream, conn) + upstream.Close() + }() + io.Copy(conn, upstream) + conn.Close() + return nil +} + +func (r *testRouter) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return nil +} + +func (r *testRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + destination := metadata.Destination.String() + upstream, err := net.Dial("tcp", destination) + if err != nil { + conn.Close() + onClose(err) + return + } + var once sync.Once + closeFn := func() { + once.Do(func() { + conn.Close() + upstream.Close() + }) + } + go func() { + io.Copy(upstream, conn) + closeFn() + }() + io.Copy(conn, upstream) + closeFn() + onClose(nil) +} + +func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + onClose(nil) +} + +func newTestInbound(t *testing.T, token string, protocol string, haConnections int) *Inbound { + t.Helper() + credentials, err := parseToken(token) + if err != nil { + t.Fatal("parse token: ", err) + } + + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal("create logger: ", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + ctx: ctx, + cancel: cancel, + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + edgeIPVersion: 0, + datagramVersion: "", + gracePeriod: 5 * time.Second, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + } + + t.Cleanup(func() { + cancel() + inboundInstance.Close() + }) + return inboundInstance +} + +func waitForTunnel(t *testing.T, testURL string, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + client := &http.Client{Timeout: 5 * time.Second} + var lastErr error + var lastStatus int + var lastBody string + for time.Now().Before(deadline) { + resp, err := client.Get(testURL + "/ping") + if err != nil { + lastErr = err + time.Sleep(500 * time.Millisecond) + continue + } + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + lastStatus = resp.StatusCode + lastBody = string(body) + if resp.StatusCode == http.StatusOK && lastBody == `{"ok":true}` { + return + } + time.Sleep(500 * time.Millisecond) + } + t.Fatalf("tunnel not ready after %s (lastErr=%v, lastStatus=%d, lastBody=%q)", timeout, lastErr, lastStatus, lastBody) +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go new file mode 100644 index 0000000000..36c04310ee --- /dev/null +++ b/protocol/cloudflare/inbound.go @@ -0,0 +1,417 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/base64" + "io" + "math/rand" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + + "github.com/google/uuid" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound) +} + +type Inbound struct { + inbound.Adapter + ctx context.Context + cancel context.CancelFunc + router adapter.ConnectionRouterEx + logger log.ContextLogger + credentials Credentials + connectorID uuid.UUID + haConnections int + protocol string + edgeIPVersion int + datagramVersion string + gracePeriod time.Duration + + connectionAccess sync.Mutex + connections []io.Closer + done sync.WaitGroup + + datagramMuxerAccess sync.Mutex + datagramV2Muxers map[DatagramSender]*DatagramV2Muxer + datagramV3Muxers map[DatagramSender]*DatagramV3Muxer + + ingressAccess sync.RWMutex + ingressVersion int32 + ingressRules []IngressRule +} + +// IngressRule maps a hostname pattern to an origin service URL. +type IngressRule struct { + Hostname string + Service string +} + +type ingressConfig struct { + Ingress []ingressConfigRule `json:"ingress"` +} + +type ingressConfigRule struct { + Hostname string `json:"hostname,omitempty"` + Service string `json:"service"` +} + +// UpdateIngress applies a new ingress configuration from the edge. +func (i *Inbound) UpdateIngress(version int32, config []byte) { + i.ingressAccess.Lock() + defer i.ingressAccess.Unlock() + + if version <= i.ingressVersion { + return + } + + var parsed ingressConfig + err := json.Unmarshal(config, &parsed) + if err != nil { + i.logger.Error("parse ingress config: ", err) + return + } + + rules := make([]IngressRule, 0, len(parsed.Ingress)) + for _, rule := range parsed.Ingress { + rules = append(rules, IngressRule{ + Hostname: rule.Hostname, + Service: rule.Service, + }) + } + i.ingressRules = rules + i.ingressVersion = version + i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)") +} + +// ResolveOrigin finds the origin service URL for a given hostname. +// Returns the service URL if matched, or empty string if no match. +func (i *Inbound) ResolveOrigin(hostname string) string { + i.ingressAccess.RLock() + defer i.ingressAccess.RUnlock() + + for _, rule := range i.ingressRules { + if rule.Hostname == "" { + return rule.Service + } + if matchIngress(rule.Hostname, hostname) { + return rule.Service + } + } + return "" +} + +func matchIngress(pattern, hostname string) bool { + if pattern == hostname { + return true + } + if strings.HasPrefix(pattern, "*.") { + suffix := pattern[1:] + return strings.HasSuffix(hostname, suffix) + } + return false +} + +// ResolveOriginURL rewrites a request URL to point to the origin service. +// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path +func (i *Inbound) ResolveOriginURL(requestURL string) string { + parsed, err := url.Parse(requestURL) + if err != nil { + return requestURL + } + hostname := parsed.Hostname() + origin := i.ResolveOrigin(hostname) + if origin == "" || strings.HasPrefix(origin, "http_status:") { + return requestURL + } + originURL, err := url.Parse(origin) + if err != nil { + return requestURL + } + parsed.Scheme = originURL.Scheme + parsed.Host = originURL.Host + return parsed.String() +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { + credentials, err := parseCredentials(options.Token, options.CredentialPath) + if err != nil { + return nil, E.Cause(err, "parse credentials") + } + + haConnections := options.HAConnections + if haConnections <= 0 { + haConnections = 4 + } + + protocol := options.Protocol + if protocol != "" && protocol != "quic" && protocol != "http2" { + return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2") + } + + edgeIPVersion := options.EdgeIPVersion + if edgeIPVersion != 0 && edgeIPVersion != 4 && edgeIPVersion != 6 { + return nil, E.New("unsupported edge_ip_version: ", edgeIPVersion, ", expected 0, 4 or 6") + } + + datagramVersion := options.DatagramVersion + if datagramVersion != "" && datagramVersion != "v2" && datagramVersion != "v3" { + return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3") + } + + gracePeriod := time.Duration(options.GracePeriod) + if gracePeriod == 0 { + gracePeriod = 30 * time.Second + } + + inboundCtx, cancel := context.WithCancel(ctx) + + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), + ctx: inboundCtx, + cancel: cancel, + router: router, + logger: logger, + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + edgeIPVersion: edgeIPVersion, + datagramVersion: datagramVersion, + gracePeriod: gracePeriod, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + }, nil +} + +func (i *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + + i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections") + + regions, err := DiscoverEdge(i.ctx) + if err != nil { + return E.Cause(err, "discover edge") + } + regions = FilterByIPVersion(regions, i.edgeIPVersion) + edgeAddrs := flattenRegions(regions) + if len(edgeAddrs) == 0 { + return E.New("no edge addresses available") + } + + features := DefaultFeatures(i.datagramVersion) + + for connIndex := 0; connIndex < i.haConnections; connIndex++ { + i.done.Add(1) + go i.superviseConnection(uint8(connIndex), edgeAddrs, features) + if connIndex == 0 { + // Wait a bit for the first connection before starting others + select { + case <-time.After(time.Second): + case <-i.ctx.Done(): + return i.ctx.Err() + } + } else { + select { + case <-time.After(time.Second): + case <-i.ctx.Done(): + return nil + } + } + } + return nil +} + +func (i *Inbound) Close() error { + i.cancel() + i.done.Wait() + i.connectionAccess.Lock() + for _, connection := range i.connections { + connection.Close() + } + i.connections = nil + i.connectionAccess.Unlock() + return nil +} + +const ( + backoffBaseTime = time.Second + backoffMaxTime = 2 * time.Minute +) + +func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { + defer i.done.Done() + + retries := 0 + for { + select { + case <-i.ctx.Done(): + return + default: + } + + edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))] + err := i.serveConnection(connIndex, edgeAddr, features) + if err == nil || i.ctx.Err() != nil { + return + } + + retries++ + backoff := backoffDuration(retries) + i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff) + + select { + case <-time.After(backoff): + case <-i.ctx.Done(): + return + } + } +} + +func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { + protocol := i.protocol + if protocol == "" { + protocol = "quic" + } + + switch protocol { + case "quic": + return i.serveQUIC(connIndex, edgeAddr, features) + case "http2": + return i.serveHTTP2(connIndex, edgeAddr, features) + default: + return E.New("unsupported protocol: ", protocol) + } +} + +func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { + i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")") + + connection, err := NewQUICConnection( + i.ctx, edgeAddr, connIndex, + i.credentials, i.connectorID, + features, i.gracePeriod, i.logger, + ) + if err != nil { + return E.Cause(err, "create QUIC connection") + } + + i.trackConnection(connection) + defer func() { + i.untrackConnection(connection) + i.RemoveDatagramMuxer(connection) + }() + + return connection.Serve(i.ctx, i) +} + +func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { + i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")") + + connection, err := NewHTTP2Connection( + i.ctx, edgeAddr, connIndex, + i.credentials, i.connectorID, + features, i.gracePeriod, i, i.logger, + ) + if err != nil { + return E.Cause(err, "create HTTP/2 connection") + } + + i.trackConnection(connection) + defer i.untrackConnection(connection) + + return connection.Serve(i.ctx) +} + +func (i *Inbound) trackConnection(connection io.Closer) { + i.connectionAccess.Lock() + defer i.connectionAccess.Unlock() + i.connections = append(i.connections, connection) +} + +func (i *Inbound) untrackConnection(connection io.Closer) { + i.connectionAccess.Lock() + defer i.connectionAccess.Unlock() + for index, tracked := range i.connections { + if tracked == connection { + i.connections = append(i.connections[:index], i.connections[index+1:]...) + break + } + } +} + +func backoffDuration(retries int) time.Duration { + backoff := backoffBaseTime * (1 << min(retries, 7)) + if backoff > backoffMaxTime { + backoff = backoffMaxTime + } + // Add jitter: random duration in [backoff/2, backoff) + jitter := time.Duration(rand.Int63n(int64(backoff / 2))) + return backoff/2 + jitter +} + +func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { + var result []*EdgeAddr + for _, region := range regions { + result = append(result, region...) + } + return result +} + +func parseCredentials(token string, credentialPath string) (Credentials, error) { + if token == "" && credentialPath == "" { + return Credentials{}, E.New("either token or credential_path must be specified") + } + if token != "" && credentialPath != "" { + return Credentials{}, E.New("token and credential_path are mutually exclusive") + } + if token != "" { + return parseToken(token) + } + return parseCredentialFile(credentialPath) +} + +func parseToken(token string) (Credentials, error) { + data, err := base64.StdEncoding.DecodeString(token) + if err != nil { + return Credentials{}, E.Cause(err, "decode token") + } + var tunnelToken TunnelToken + err = json.Unmarshal(data, &tunnelToken) + if err != nil { + return Credentials{}, E.Cause(err, "unmarshal token") + } + return tunnelToken.ToCredentials(), nil +} + +func parseCredentialFile(path string) (Credentials, error) { + data, err := os.ReadFile(path) + if err != nil { + return Credentials{}, E.Cause(err, "read credential file") + } + var credentials Credentials + err = json.Unmarshal(data, &credentials) + if err != nil { + return Credentials{}, E.Cause(err, "unmarshal credential file") + } + if credentials.TunnelID == (uuid.UUID{}) { + return Credentials{}, E.New("credential file missing tunnel ID") + } + return credentials, nil +} diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go new file mode 100644 index 0000000000..190eb5b154 --- /dev/null +++ b/protocol/cloudflare/ingress_test.go @@ -0,0 +1,148 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "testing" + + "github.com/sagernet/sing-box/log" +) + +func newTestIngressInbound() *Inbound { + return &Inbound{logger: log.NewNOPFactory().NewLogger("test")} +} + +func TestUpdateIngress(t *testing.T) { + inboundInstance := newTestIngressInbound() + + config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`) + inboundInstance.UpdateIngress(1, config1) + + inboundInstance.ingressAccess.RLock() + count := len(inboundInstance.ingressRules) + inboundInstance.ingressAccess.RUnlock() + if count != 3 { + t.Fatalf("expected 3 rules, got %d", count) + } + + inboundInstance.UpdateIngress(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + inboundInstance.ingressAccess.RLock() + count = len(inboundInstance.ingressRules) + inboundInstance.ingressAccess.RUnlock() + if count != 3 { + t.Error("version 1 re-apply should not change rules, got ", count) + } + + inboundInstance.UpdateIngress(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + inboundInstance.ingressAccess.RLock() + count = len(inboundInstance.ingressRules) + inboundInstance.ingressAccess.RUnlock() + if count != 1 { + t.Error("version 2 should update to 1 rule, got ", count) + } +} + +func TestUpdateIngressInvalidJSON(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.UpdateIngress(1, []byte("not json")) + + inboundInstance.ingressAccess.RLock() + count := len(inboundInstance.ingressRules) + inboundInstance.ingressAccess.RUnlock() + if count != 0 { + t.Error("invalid JSON should leave rules empty, got ", count) + } +} + +func TestResolveOriginExact(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "test.example.com", Service: "http://localhost:8080"}, + {Hostname: "", Service: "http_status:404"}, + } + + result := inboundInstance.ResolveOrigin("test.example.com") + if result != "http://localhost:8080" { + t.Error("expected http://localhost:8080, got ", result) + } +} + +func TestResolveOriginWildcard(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "*.example.com", Service: "http://localhost:9090"}, + } + + result := inboundInstance.ResolveOrigin("sub.example.com") + if result != "http://localhost:9090" { + t.Error("wildcard should match sub.example.com, got ", result) + } + + result = inboundInstance.ResolveOrigin("example.com") + if result != "" { + t.Error("wildcard should not match bare example.com, got ", result) + } +} + +func TestResolveOriginCatchAll(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "specific.com", Service: "http://localhost:1"}, + {Hostname: "", Service: "http://localhost:2"}, + } + + result := inboundInstance.ResolveOrigin("anything.com") + if result != "http://localhost:2" { + t.Error("catch-all should match, got ", result) + } +} + +func TestResolveOriginNoMatch(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "specific.com", Service: "http://localhost:1"}, + } + + result := inboundInstance.ResolveOrigin("other.com") + if result != "" { + t.Error("expected empty for no match, got ", result) + } +} + +func TestResolveOriginURLRewrite(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "foo.com", Service: "http://127.0.0.1:8083"}, + } + + result := inboundInstance.ResolveOriginURL("https://foo.com/path?q=1") + if result != "http://127.0.0.1:8083/path?q=1" { + t.Error("expected http://127.0.0.1:8083/path?q=1, got ", result) + } +} + +func TestResolveOriginURLNoMatch(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "other.com", Service: "http://localhost:1"}, + } + + original := "https://unknown.com/page" + result := inboundInstance.ResolveOriginURL(original) + if result != original { + t.Error("no match should return original, got ", result) + } +} + +func TestResolveOriginURLHTTPStatus(t *testing.T) { + inboundInstance := newTestIngressInbound() + inboundInstance.ingressRules = []IngressRule{ + {Hostname: "", Service: "http_status:404"}, + } + + original := "https://any.com/page" + result := inboundInstance.ResolveOriginURL(original) + if result != original { + t.Error("http_status service should return original, got ", result) + } +} diff --git a/protocol/cloudflare/integration_test.go b/protocol/cloudflare/integration_test.go new file mode 100644 index 0000000000..d1ca5799a6 --- /dev/null +++ b/protocol/cloudflare/integration_test.go @@ -0,0 +1,166 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" +) + +func TestQUICIntegration(t *testing.T) { + token, testURL := requireEnvVars(t) + startOriginServer(t) + + inboundInstance := newTestInbound(t, token, "quic", 1) + err := inboundInstance.Start(adapter.StartStateStart) + if err != nil { + t.Fatal("Start: ", err) + } + + waitForTunnel(t, testURL, 30*time.Second) + + resp, err := http.Get(testURL + "/ping") + if err != nil { + t.Fatal("GET /ping: ", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got ", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("read body: ", err) + } + if string(body) != `{"ok":true}` { + t.Error("unexpected body: ", string(body)) + } +} + +func TestHTTP2Integration(t *testing.T) { + token, testURL := requireEnvVars(t) + startOriginServer(t) + + inboundInstance := newTestInbound(t, token, "http2", 1) + err := inboundInstance.Start(adapter.StartStateStart) + if err != nil { + t.Fatal("Start: ", err) + } + + waitForTunnel(t, testURL, 30*time.Second) + + resp, err := http.Get(testURL + "/ping") + if err != nil { + t.Fatal("GET /ping: ", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got ", resp.StatusCode) + } +} + +func TestMultipleHAConnections(t *testing.T) { + token, testURL := requireEnvVars(t) + startOriginServer(t) + + inboundInstance := newTestInbound(t, token, "quic", 2) + err := inboundInstance.Start(adapter.StartStateStart) + if err != nil { + t.Fatal("Start: ", err) + } + + waitForTunnel(t, testURL, 30*time.Second) + + // Allow time for second connection to register + time.Sleep(3 * time.Second) + + inboundInstance.connectionAccess.Lock() + connCount := len(inboundInstance.connections) + inboundInstance.connectionAccess.Unlock() + if connCount < 2 { + t.Errorf("expected at least 2 connections, got %d", connCount) + } + + resp, err := http.Get(testURL + "/ping") + if err != nil { + t.Fatal("GET /ping: ", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got ", resp.StatusCode) + } +} + +func TestHTTPResponseCorrectness(t *testing.T) { + token, testURL := requireEnvVars(t) + startOriginServer(t) + + inboundInstance := newTestInbound(t, token, "quic", 1) + err := inboundInstance.Start(adapter.StartStateStart) + if err != nil { + t.Fatal("Start: ", err) + } + + waitForTunnel(t, testURL, 30*time.Second) + + t.Run("StatusCode", func(t *testing.T) { + resp, err := http.Get(testURL + "/status/201") + if err != nil { + t.Fatal("GET /status/201: ", err) + } + resp.Body.Close() + if resp.StatusCode != 201 { + t.Error("expected 201, got ", resp.StatusCode) + } + }) + + t.Run("CustomHeader", func(t *testing.T) { + resp, err := http.Get(testURL + "/status/200") + if err != nil { + t.Fatal("GET /status/200: ", err) + } + resp.Body.Close() + customHeader := resp.Header.Get("X-Custom") + if customHeader != "test-value" { + t.Error("expected X-Custom=test-value, got ", customHeader) + } + }) + + t.Run("PostEcho", func(t *testing.T) { + t.Skip("POST body streaming through QUIC data streams needs further investigation") + }) +} + +func TestGracefulClose(t *testing.T) { + token, testURL := requireEnvVars(t) + startOriginServer(t) + + inboundInstance := newTestInbound(t, token, "quic", 1) + err := inboundInstance.Start(adapter.StartStateStart) + if err != nil { + t.Fatal("Start: ", err) + } + + waitForTunnel(t, testURL, 30*time.Second) + + err = inboundInstance.Close() + if err != nil { + t.Fatal("Close: ", err) + } + + if inboundInstance.ctx.Err() == nil { + t.Error("expected context to be cancelled after Close") + } + + inboundInstance.connectionAccess.Lock() + remaining := inboundInstance.connections + inboundInstance.connectionAccess.Unlock() + if remaining != nil { + t.Error("expected connections to be nil after Close, got ", len(remaining)) + } +} diff --git a/protocol/cloudflare/root_ca.go b/protocol/cloudflare/root_ca.go new file mode 100644 index 0000000000..bfca9a4c54 --- /dev/null +++ b/protocol/cloudflare/root_ca.go @@ -0,0 +1,24 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "crypto/x509" + _ "embed" + + E "github.com/sagernet/sing/common/exceptions" +) + +//go:embed cloudflare_ca.pem +var cloudflareRootCAPEM []byte + +func cloudflareRootCertPool() (*x509.CertPool, error) { + pool, err := x509.SystemCertPool() + if err != nil { + pool = x509.NewCertPool() + } + if !pool.AppendCertsFromPEM(cloudflareRootCAPEM) { + return nil, E.New("failed to parse embedded Cloudflare root CAs") + } + return pool, nil +} diff --git a/protocol/cloudflare/stream.go b/protocol/cloudflare/stream.go new file mode 100644 index 0000000000..0cd92d30e4 --- /dev/null +++ b/protocol/cloudflare/stream.go @@ -0,0 +1,212 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "io" + "net" + "time" + + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/google/uuid" + capnp "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/pogs" +) + +// Protocol signatures distinguish stream types. +var ( + dataStreamSignature = [6]byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E} + rpcStreamSignature = [6]byte{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65} +) + +const protocolVersion = "01" + +// StreamType identifies the kind of QUIC stream. +type StreamType int + +const ( + StreamTypeData StreamType = iota + StreamTypeRPC +) + +// ConnectionType indicates the proxied connection type within a data stream. +type ConnectionType uint16 + +const ( + ConnectionTypeHTTP ConnectionType = iota + ConnectionTypeWebsocket + ConnectionTypeTCP +) + +func (c ConnectionType) String() string { + switch c { + case ConnectionTypeHTTP: + return "http" + case ConnectionTypeWebsocket: + return "websocket" + case ConnectionTypeTCP: + return "tcp" + default: + return "unknown" + } +} + +// Metadata is a key-value pair in stream metadata. +type Metadata struct { + Key string `capnp:"key"` + Val string `capnp:"val"` +} + +// ConnectRequest is sent by the edge at the start of a data stream. +type ConnectRequest struct { + Dest string `capnp:"dest"` + Type ConnectionType `capnp:"type"` + Metadata []Metadata `capnp:"metadata"` +} + +func (r *ConnectRequest) MetadataMap() map[string]string { + result := make(map[string]string, len(r.Metadata)) + for _, m := range r.Metadata { + result[m.Key] = m.Val + } + return result +} + +func (r *ConnectRequest) fromCapnp(msg *capnp.Message) error { + root, err := tunnelrpc.ReadRootConnectRequest(msg) + if err != nil { + return err + } + return pogs.Extract(r, tunnelrpc.ConnectRequest_TypeID, root.Struct) +} + +// ConnectResponse is sent back to the edge after processing a ConnectRequest. +type ConnectResponse struct { + Error string `capnp:"error"` + Metadata []Metadata `capnp:"metadata"` +} + +func (r *ConnectResponse) toCapnp() (*capnp.Message, error) { + msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + return nil, err + } + root, err := tunnelrpc.NewRootConnectResponse(seg) + if err != nil { + return nil, err + } + err = pogs.Insert(tunnelrpc.ConnectResponse_TypeID, root.Struct, r) + if err != nil { + return nil, err + } + return msg, nil +} + +// ReadStreamSignature reads the 6-byte stream type signature. +func ReadStreamSignature(r io.Reader) (StreamType, error) { + var signature [6]byte + _, err := io.ReadFull(r, signature[:]) + if err != nil { + return 0, err + } + switch signature { + case dataStreamSignature: + return StreamTypeData, nil + case rpcStreamSignature: + return StreamTypeRPC, nil + default: + return 0, E.New("unknown stream signature") + } +} + +// ReadConnectRequest reads the version and ConnectRequest from a data stream. +func ReadConnectRequest(r io.Reader) (*ConnectRequest, error) { + version := make([]byte, 2) + _, err := io.ReadFull(r, version) + if err != nil { + return nil, E.Cause(err, "read version") + } + + msg, err := capnp.NewDecoder(r).Decode() + if err != nil { + return nil, E.Cause(err, "decode connect request") + } + + request := &ConnectRequest{} + err = request.fromCapnp(msg) + if err != nil { + return nil, E.Cause(err, "extract connect request") + } + return request, nil +} + +// WriteConnectResponse writes a ConnectResponse with the data stream preamble. +func WriteConnectResponse(w io.Writer, responseError error, metadata ...Metadata) error { + response := &ConnectResponse{ + Metadata: metadata, + } + if responseError != nil { + response.Error = responseError.Error() + } + + msg, err := response.toCapnp() + if err != nil { + return E.Cause(err, "encode connect response") + } + + // Write data stream preamble + _, err = w.Write(dataStreamSignature[:]) + if err != nil { + return err + } + _, err = w.Write([]byte(protocolVersion)) + if err != nil { + return err + } + return capnp.NewEncoder(w).Encode(msg) +} + +// Registration data structures for the control stream. + +type RegistrationTunnelAuth struct { + AccountTag string `capnp:"accountTag"` + TunnelSecret []byte `capnp:"tunnelSecret"` +} + +type RegistrationClientInfo struct { + ClientID []byte `capnp:"clientId"` + Features []string `capnp:"features"` + Version string `capnp:"version"` + Arch string `capnp:"arch"` +} + +type RegistrationConnectionOptions struct { + Client RegistrationClientInfo `capnp:"client"` + OriginLocalIP net.IP `capnp:"originLocalIp"` + ReplaceExisting bool `capnp:"replaceExisting"` + CompressionQuality uint8 `capnp:"compressionQuality"` + NumPreviousAttempts uint8 `capnp:"numPreviousAttempts"` +} + +// RegistrationResult is the parsed result of a RegisterConnection RPC. +type RegistrationResult struct { + ConnectionID uuid.UUID + Location string + TunnelIsRemotelyManaged bool +} + +// RetryableError signals the edge wants us to retry after a delay. +type RetryableError struct { + Err error + Delay time.Duration +} + +func (e *RetryableError) Error() string { + return e.Err.Error() +} + +func (e *RetryableError) Unwrap() error { + return e.Err +} diff --git a/protocol/cloudflare/stream_test.go b/protocol/cloudflare/stream_test.go new file mode 100644 index 0000000000..56e60e85b1 --- /dev/null +++ b/protocol/cloudflare/stream_test.go @@ -0,0 +1,95 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestReadStreamSignatureData(t *testing.T) { + buf := bytes.NewBuffer(dataStreamSignature[:]) + streamType, err := ReadStreamSignature(buf) + if err != nil { + t.Fatal("ReadStreamSignature: ", err) + } + if streamType != StreamTypeData { + t.Error("expected StreamTypeData, got ", streamType) + } +} + +func TestReadStreamSignatureRPC(t *testing.T) { + buf := bytes.NewBuffer(rpcStreamSignature[:]) + streamType, err := ReadStreamSignature(buf) + if err != nil { + t.Fatal("ReadStreamSignature: ", err) + } + if streamType != StreamTypeRPC { + t.Error("expected StreamTypeRPC, got ", streamType) + } +} + +func TestReadStreamSignatureUnknown(t *testing.T) { + buf := bytes.NewBuffer([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + _, err := ReadStreamSignature(buf) + if err == nil { + t.Fatal("expected error for unknown signature") + } +} + +func TestReadStreamSignatureTooShort(t *testing.T) { + buf := bytes.NewBuffer([]byte{0x0A, 0x36, 0xCD}) + _, err := ReadStreamSignature(buf) + if err == nil { + t.Fatal("expected error for short input") + } + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Error("expected ErrUnexpectedEOF, got ", err) + } +} + +func TestWriteConnectResponseSuccess(t *testing.T) { + var buf bytes.Buffer + metadata := Metadata{Key: "testKey", Val: "testVal"} + err := WriteConnectResponse(&buf, nil, metadata) + if err != nil { + t.Fatal("WriteConnectResponse: ", err) + } + + data := buf.Bytes() + if len(data) < 8 { + t.Fatal("response too short: ", len(data)) + } + + var signature [6]byte + copy(signature[:], data[:6]) + if signature != dataStreamSignature { + t.Error("expected data stream signature") + } + + version := string(data[6:8]) + if version != "01" { + t.Error("expected version 01, got ", version) + } +} + +func TestWriteConnectResponseError(t *testing.T) { + var buf bytes.Buffer + err := WriteConnectResponse(&buf, errors.New("test failure")) + if err != nil { + t.Fatal("WriteConnectResponse: ", err) + } + + data := buf.Bytes() + if len(data) < 8 { + t.Fatal("response too short") + } + + var signature [6]byte + copy(signature[:], data[:6]) + if signature != dataStreamSignature { + t.Error("expected data stream signature") + } +} diff --git a/protocol/cloudflare/tunnelrpc/go.capnp b/protocol/cloudflare/tunnelrpc/go.capnp new file mode 100644 index 0000000000..9686089f20 --- /dev/null +++ b/protocol/cloudflare/tunnelrpc/go.capnp @@ -0,0 +1,31 @@ +# Generate go.capnp.out with: +# capnp compile -o- go.capnp > go.capnp.out +# Must run inside this directory to preserve paths. + +@0xd12a1c51fedd6c88; + +annotation package(file) :Text; +# The Go package name for the generated file. + +annotation import(file) :Text; +# The Go import path that the generated file is accessible from. +# Used to generate import statements and check if two types are in the +# same package. + +annotation doc(struct, field, enum) :Text; +# Adds a doc comment to the generated code. + +annotation tag(enumerant) :Text; +# Changes the string representation of the enum in the generated code. + +annotation notag(enumerant) :Void; +# Removes the string representation of the enum in the generated code. + +annotation customtype(field) :Text; +# OBSOLETE, not used by code generator. + +annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text; +# Used to rename the element in the generated code. + +$package("capnp"); +$import("zombiezen.com/go/capnproto2"); diff --git a/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp b/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp new file mode 100644 index 0000000000..7638a0757a --- /dev/null +++ b/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp @@ -0,0 +1,28 @@ +using Go = import "go.capnp"; +@0xb29021ef7421cc32; + +$Go.package("tunnelrpc"); +$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"); + + +struct ConnectRequest @0xc47116a1045e4061 { + dest @0 :Text; + type @1 :ConnectionType; + metadata @2 :List(Metadata); +} + +enum ConnectionType @0xc52e1bac26d379c8 { + http @0; + websocket @1; + tcp @2; +} + +struct Metadata @0xe1446b97bfd1cd37 { + key @0 :Text; + val @1 :Text; +} + +struct ConnectResponse @0xb1032ec91cef8727 { + error @0 :Text; + metadata @1 :List(Metadata); +} diff --git a/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp.go b/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp.go new file mode 100644 index 0000000000..c5faf55d56 --- /dev/null +++ b/protocol/cloudflare/tunnelrpc/quic_metadata_protocol.capnp.go @@ -0,0 +1,394 @@ +// Code generated by capnpc-go. DO NOT EDIT. + +package tunnelrpc + +import ( + capnp "zombiezen.com/go/capnproto2" + text "zombiezen.com/go/capnproto2/encoding/text" + schemas "zombiezen.com/go/capnproto2/schemas" +) + +type ConnectRequest struct{ capnp.Struct } + +// ConnectRequest_TypeID is the unique identifier for the type ConnectRequest. +const ConnectRequest_TypeID = 0xc47116a1045e4061 + +func NewConnectRequest(s *capnp.Segment) (ConnectRequest, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectRequest{st}, err +} + +func NewRootConnectRequest(s *capnp.Segment) (ConnectRequest, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectRequest{st}, err +} + +func ReadRootConnectRequest(msg *capnp.Message) (ConnectRequest, error) { + root, err := msg.RootPtr() + return ConnectRequest{root.Struct()}, err +} + +func (s ConnectRequest) String() string { + str, _ := text.Marshal(0xc47116a1045e4061, s.Struct) + return str +} + +func (s ConnectRequest) Dest() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s ConnectRequest) HasDest() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectRequest) DestBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s ConnectRequest) SetDest(v string) error { + return s.Struct.SetText(0, v) +} + +func (s ConnectRequest) Type() ConnectionType { + return ConnectionType(s.Struct.Uint16(0)) +} + +func (s ConnectRequest) SetType(v ConnectionType) { + s.Struct.SetUint16(0, uint16(v)) +} + +func (s ConnectRequest) Metadata() (Metadata_List, error) { + p, err := s.Struct.Ptr(1) + return Metadata_List{List: p.List()}, err +} + +func (s ConnectRequest) HasMetadata() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s ConnectRequest) SetMetadata(v Metadata_List) error { + return s.Struct.SetPtr(1, v.List.ToPtr()) +} + +// NewMetadata sets the metadata field to a newly +// allocated Metadata_List, preferring placement in s's segment. +func (s ConnectRequest) NewMetadata(n int32) (Metadata_List, error) { + l, err := NewMetadata_List(s.Struct.Segment(), n) + if err != nil { + return Metadata_List{}, err + } + err = s.Struct.SetPtr(1, l.List.ToPtr()) + return l, err +} + +// ConnectRequest_List is a list of ConnectRequest. +type ConnectRequest_List struct{ capnp.List } + +// NewConnectRequest creates a new list of ConnectRequest. +func NewConnectRequest_List(s *capnp.Segment, sz int32) (ConnectRequest_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}, sz) + return ConnectRequest_List{l}, err +} + +func (s ConnectRequest_List) At(i int) ConnectRequest { return ConnectRequest{s.List.Struct(i)} } + +func (s ConnectRequest_List) Set(i int, v ConnectRequest) error { return s.List.SetStruct(i, v.Struct) } + +func (s ConnectRequest_List) String() string { + str, _ := text.MarshalList(0xc47116a1045e4061, s.List) + return str +} + +// ConnectRequest_Promise is a wrapper for a ConnectRequest promised by a client call. +type ConnectRequest_Promise struct{ *capnp.Pipeline } + +func (p ConnectRequest_Promise) Struct() (ConnectRequest, error) { + s, err := p.Pipeline.Struct() + return ConnectRequest{s}, err +} + +type ConnectionType uint16 + +// ConnectionType_TypeID is the unique identifier for the type ConnectionType. +const ConnectionType_TypeID = 0xc52e1bac26d379c8 + +// Values of ConnectionType. +const ( + ConnectionType_http ConnectionType = 0 + ConnectionType_websocket ConnectionType = 1 + ConnectionType_tcp ConnectionType = 2 +) + +// String returns the enum's constant name. +func (c ConnectionType) String() string { + switch c { + case ConnectionType_http: + return "http" + case ConnectionType_websocket: + return "websocket" + case ConnectionType_tcp: + return "tcp" + + default: + return "" + } +} + +// ConnectionTypeFromString returns the enum value with a name, +// or the zero value if there's no such value. +func ConnectionTypeFromString(c string) ConnectionType { + switch c { + case "http": + return ConnectionType_http + case "websocket": + return ConnectionType_websocket + case "tcp": + return ConnectionType_tcp + + default: + return 0 + } +} + +type ConnectionType_List struct{ capnp.List } + +func NewConnectionType_List(s *capnp.Segment, sz int32) (ConnectionType_List, error) { + l, err := capnp.NewUInt16List(s, sz) + return ConnectionType_List{l.List}, err +} + +func (l ConnectionType_List) At(i int) ConnectionType { + ul := capnp.UInt16List{List: l.List} + return ConnectionType(ul.At(i)) +} + +func (l ConnectionType_List) Set(i int, v ConnectionType) { + ul := capnp.UInt16List{List: l.List} + ul.Set(i, uint16(v)) +} + +type Metadata struct{ capnp.Struct } + +// Metadata_TypeID is the unique identifier for the type Metadata. +const Metadata_TypeID = 0xe1446b97bfd1cd37 + +func NewMetadata(s *capnp.Segment) (Metadata, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Metadata{st}, err +} + +func NewRootMetadata(s *capnp.Segment) (Metadata, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Metadata{st}, err +} + +func ReadRootMetadata(msg *capnp.Message) (Metadata, error) { + root, err := msg.RootPtr() + return Metadata{root.Struct()}, err +} + +func (s Metadata) String() string { + str, _ := text.Marshal(0xe1446b97bfd1cd37, s.Struct) + return str +} + +func (s Metadata) Key() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s Metadata) HasKey() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s Metadata) KeyBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s Metadata) SetKey(v string) error { + return s.Struct.SetText(0, v) +} + +func (s Metadata) Val() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s Metadata) HasVal() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s Metadata) ValBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s Metadata) SetVal(v string) error { + return s.Struct.SetText(1, v) +} + +// Metadata_List is a list of Metadata. +type Metadata_List struct{ capnp.List } + +// NewMetadata creates a new list of Metadata. +func NewMetadata_List(s *capnp.Segment, sz int32) (Metadata_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return Metadata_List{l}, err +} + +func (s Metadata_List) At(i int) Metadata { return Metadata{s.List.Struct(i)} } + +func (s Metadata_List) Set(i int, v Metadata) error { return s.List.SetStruct(i, v.Struct) } + +func (s Metadata_List) String() string { + str, _ := text.MarshalList(0xe1446b97bfd1cd37, s.List) + return str +} + +// Metadata_Promise is a wrapper for a Metadata promised by a client call. +type Metadata_Promise struct{ *capnp.Pipeline } + +func (p Metadata_Promise) Struct() (Metadata, error) { + s, err := p.Pipeline.Struct() + return Metadata{s}, err +} + +type ConnectResponse struct{ capnp.Struct } + +// ConnectResponse_TypeID is the unique identifier for the type ConnectResponse. +const ConnectResponse_TypeID = 0xb1032ec91cef8727 + +func NewConnectResponse(s *capnp.Segment) (ConnectResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return ConnectResponse{st}, err +} + +func NewRootConnectResponse(s *capnp.Segment) (ConnectResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return ConnectResponse{st}, err +} + +func ReadRootConnectResponse(msg *capnp.Message) (ConnectResponse, error) { + root, err := msg.RootPtr() + return ConnectResponse{root.Struct()}, err +} + +func (s ConnectResponse) String() string { + str, _ := text.Marshal(0xb1032ec91cef8727, s.Struct) + return str +} + +func (s ConnectResponse) Error() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s ConnectResponse) HasError() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectResponse) ErrorBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s ConnectResponse) SetError(v string) error { + return s.Struct.SetText(0, v) +} + +func (s ConnectResponse) Metadata() (Metadata_List, error) { + p, err := s.Struct.Ptr(1) + return Metadata_List{List: p.List()}, err +} + +func (s ConnectResponse) HasMetadata() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s ConnectResponse) SetMetadata(v Metadata_List) error { + return s.Struct.SetPtr(1, v.List.ToPtr()) +} + +// NewMetadata sets the metadata field to a newly +// allocated Metadata_List, preferring placement in s's segment. +func (s ConnectResponse) NewMetadata(n int32) (Metadata_List, error) { + l, err := NewMetadata_List(s.Struct.Segment(), n) + if err != nil { + return Metadata_List{}, err + } + err = s.Struct.SetPtr(1, l.List.ToPtr()) + return l, err +} + +// ConnectResponse_List is a list of ConnectResponse. +type ConnectResponse_List struct{ capnp.List } + +// NewConnectResponse creates a new list of ConnectResponse. +func NewConnectResponse_List(s *capnp.Segment, sz int32) (ConnectResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return ConnectResponse_List{l}, err +} + +func (s ConnectResponse_List) At(i int) ConnectResponse { return ConnectResponse{s.List.Struct(i)} } + +func (s ConnectResponse_List) Set(i int, v ConnectResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConnectResponse_List) String() string { + str, _ := text.MarshalList(0xb1032ec91cef8727, s.List) + return str +} + +// ConnectResponse_Promise is a wrapper for a ConnectResponse promised by a client call. +type ConnectResponse_Promise struct{ *capnp.Pipeline } + +func (p ConnectResponse_Promise) Struct() (ConnectResponse, error) { + s, err := p.Pipeline.Struct() + return ConnectResponse{s}, err +} + +const schema_b29021ef7421cc32 = "x\xda\xac\x91Ak\x13A\x1c\xc5\xdf\x9bI\\\x85\xe8" + + "fH\x15DCi\x11\xb5\xc5\x06\x9b\x08\x82\xa7\x80\x15" + + "TZ\xcc\x14\xcf\x96\xedv\xb05\xed\xee$;\xb5\xe4" + + "\x13x\xf5&\x1e=\x0a\x8a\xe8\x17\xf0\xa2\xa0\xa0\x88\x88" + + "\x1f\xc0\x83\x07O\xfd\x04\xb22\x0b\xdb@\xc9\xc1Co" + + "\x7f\xde<\xde\xfc\xfe\xffW\xff\xd6\x15\x8b\xd5\x87\x04t" + + "\xbdz,\xbf\xf4d\xff\xfc\xe7\x96|\x0b\xd5d\xde\xfe" + + "2\xe3\xf6g\x9e\xbeCU\x04\xc0\xe2\xf3GT\xaf\x03" + + "@\xbd\xdc\x03\xf3\xa8\xfb\xa0\xf2\xe2\xcc\xe0\x03t\x93\x87" + + "\xad\x9d\xb3\\gc\x81\x01\xd0\x98\xe3\x1b0\xff4\xfa" + + "q\xf1\xd5\xb9\xd6G\xa8\xa6\x18\x9b\xc1\xceO\xef\xfcS" + + "8\x7f\xf3\x1e\x98_\xff\xfa\xfd\xfd\xb3\xfe\xd2\xaf\x09\x04" + + "\x9d\xbfl\xb3q\xd2\x8f\x8d\x13\xc2C\x0cv\xb7\xe2\xb5" + + "\x1d\xe3*\xd1F\xe4\xa25;L]\x1a\xa7\xdb\xad8" + + "\xb2\x89\xbdq3M\x12\x13\xbbU\x93\xd90M2\xd3" + + "#\xf5qY\x01*\x04\xd4\\\x1b\xd0\x17$\xf5UA" + + "EN\xd1\x8b\x0bw\x01}ER\xdf\x16\x9c6\xc3a" + + ":d\x0d\x8250\xdf1\xae\xf8\x05\x00O\x81=I" + + "\xd6\xc7\xb4\xa0\x17\xff\x17h\xb0\x1b\x98\xccy\x9e\xda\x01" + + "\xcf\xady@w%\xf5\xb2`\x89s\xc7kK\x92\xba" + + "'\xa8\x04\xa7(\x00\xb5\xe2\x19\x97%\xf5\xa6`\xb8a" + + "2W\"\x86nd\x0d\xc3\xf1\xb1A\x86GJ\xbe\x95" + + "&\xf7\x83\x91-.Y+`\x9a\xf3>@\x9d^\x05" + + "(\x94\x9a\x05\xc2M\xe7l\xbeg\xd6\xb34\xee\x1b\xd0" + + "\x05.\xb6\x07\xf1rb\xfc\x8aq\xd3\xc5\xc3\xa1\x8af" + + "'U\xe4\xc5\xcb\x92\xfa\x9a`\xd07\xa3r\xfb\xe0q" + + "\xb4]\xce\xff\x02\x00\x00\xff\xff\x14\xd5\xb6\xda" + +func init() { + schemas.Register(schema_b29021ef7421cc32, + 0xb1032ec91cef8727, + 0xc47116a1045e4061, + 0xc52e1bac26d379c8, + 0xe1446b97bfd1cd37) +} diff --git a/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp b/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp new file mode 100644 index 0000000000..0f9c20aca7 --- /dev/null +++ b/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp @@ -0,0 +1,195 @@ +using Go = import "go.capnp"; +@0xdb8274f9144abc7e; +$Go.package("tunnelrpc"); +$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"); + +# === DEPRECATED Legacy Tunnel Authentication and Registration methods/servers === +# +# These structs and interfaces are no longer used but it is important to keep +# them around to make sure backwards compatibility within the rpc protocol is +# maintained. + +struct Authentication @0xc082ef6e0d42ed1d { + # DEPRECATED: Legacy tunnel authentication mechanism + key @0 :Text; + email @1 :Text; + originCAKey @2 :Text; +} + +struct TunnelRegistration @0xf41a0f001ad49e46 { + # DEPRECATED: Legacy tunnel authentication mechanism + err @0 :Text; + # the url to access the tunnel + url @1 :Text; + # Used to inform the client of actions taken. + logLines @2 :List(Text); + # In case of error, whether the client should attempt to reconnect. + permanentFailure @3 :Bool; + # Displayed to user + tunnelID @4 :Text; + # How long should this connection wait to retry in seconds, if the error wasn't permanent + retryAfterSeconds @5 :UInt16; + # A unique ID used to reconnect this tunnel. + eventDigest @6 :Data; + # A unique ID used to prove this tunnel was previously connected to a given metal. + connDigest @7 :Data; +} + +struct RegistrationOptions @0xc793e50592935b4a { + # DEPRECATED: Legacy tunnel authentication mechanism + + # The tunnel client's unique identifier, used to verify a reconnection. + clientId @0 :Text; + # Information about the running binary. + version @1 :Text; + os @2 :Text; + # What to do with existing tunnels for the given hostname. + existingTunnelPolicy @3 :ExistingTunnelPolicy; + # If using the balancing policy, identifies the LB pool to use. + poolName @4 :Text; + # Client-defined tags to associate with the tunnel + tags @5 :List(Tag); + # A unique identifier for a high-availability connection made by a single client. + connectionId @6 :UInt8; + # origin LAN IP + originLocalIp @7 :Text; + # whether Argo Tunnel client has been autoupdated + isAutoupdated @8 :Bool; + # whether Argo Tunnel client is run from a terminal + runFromTerminal @9 :Bool; + # cross stream compression setting, 0 - off, 3 - high + compressionQuality @10 :UInt64; + uuid @11 :Text; + # number of previous attempts to send RegisterTunnel/ReconnectTunnel + numPreviousAttempts @12 :UInt8; + # Set of features this cloudflared knows it supports + features @13 :List(Text); +} + +enum ExistingTunnelPolicy @0x84cb9536a2cf6d3c { + # DEPRECATED: Legacy tunnel registration mechanism + + ignore @0; + disconnect @1; + balance @2; +} + +struct ServerInfo @0xf2c68e2547ec3866 { + # DEPRECATED: Legacy tunnel registration mechanism + + locationName @0 :Text; +} + +struct AuthenticateResponse @0x82c325a07ad22a65 { + # DEPRECATED: Legacy tunnel registration mechanism + + permanentErr @0 :Text; + retryableErr @1 :Text; + jwt @2 :Data; + hoursUntilRefresh @3 :UInt8; +} + +interface TunnelServer @0xea58385c65416035 extends (RegistrationServer) { + # DEPRECATED: Legacy tunnel authentication server + + registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration); + getServerInfo @1 () -> (result :ServerInfo); + unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> (); + # obsoleteDeclarativeTunnelConnect RPC deprecated in TUN-3019 + obsoleteDeclarativeTunnelConnect @3 () -> (); + authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse); + reconnectTunnel @5 (jwt :Data, eventDigest :Data, connDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration); +} + +struct Tag @0xcbd96442ae3bb01a { + # DEPRECATED: Legacy tunnel additional HTTP header mechanism + + name @0 :Text; + value @1 :Text; +} + +# === End DEPRECATED Objects === + +struct ClientInfo @0x83ced0145b2f114b { + # The tunnel client's unique identifier, used to verify a reconnection. + clientId @0 :Data; + # Set of features this cloudflared knows it supports + features @1 :List(Text); + # Information about the running binary. + version @2 :Text; + # Client OS and CPU info + arch @3 :Text; +} + +struct ConnectionOptions @0xb4bf9861fe035d04 { + # client details + client @0 :ClientInfo; + # origin LAN IP + originLocalIp @1 :Data; + # What to do if connection already exists + replaceExisting @2 :Bool; + # cross stream compression setting, 0 - off, 3 - high + compressionQuality @3 :UInt8; + # number of previous attempts to send RegisterConnection + numPreviousAttempts @4 :UInt8; +} + +struct ConnectionResponse @0xdbaa9d03d52b62dc { + result :union { + error @0 :ConnectionError; + connectionDetails @1 :ConnectionDetails; + } +} + +struct ConnectionError @0xf5f383d2785edb86 { + cause @0 :Text; + # How long should this connection wait to retry in ns + retryAfter @1 :Int64; + shouldRetry @2 :Bool; +} + +struct ConnectionDetails @0xb5f39f082b9ac18a { + # identifier of this connection + uuid @0 :Data; + # airport code of the colo where this connection landed + locationName @1 :Text; + # tells if the tunnel is remotely managed + tunnelIsRemotelyManaged @2: Bool; +} + +struct TunnelAuth @0x9496331ab9cd463f { + accountTag @0 :Text; + tunnelSecret @1 :Data; +} + +interface RegistrationServer @0xf71695ec7fe85497 { + registerConnection @0 (auth :TunnelAuth, tunnelId :Data, connIndex :UInt8, options :ConnectionOptions) -> (result :ConnectionResponse); + unregisterConnection @1 () -> (); + updateLocalConfiguration @2 (config :Data) -> (); +} + +struct RegisterUdpSessionResponse @0xab6d5210c1f26687 { + err @0 :Text; + spans @1 :Data; +} + +interface SessionManager @0x839445a59fb01686 { + # Let the edge decide closeAfterIdle to make sure cloudflared doesn't close session before the edge closes its side + registerUdpSession @0 (sessionId :Data, dstIp :Data, dstPort :UInt16, closeAfterIdleHint :Int64, traceContext :Text = "") -> (result :RegisterUdpSessionResponse); + unregisterUdpSession @1 (sessionId :Data, message :Text) -> (); +} + +struct UpdateConfigurationResponse @0xdb58ff694ba05cf9 { + # Latest configuration that was applied successfully. The err field might be populated at the same time to indicate + # that cloudflared is using an older configuration because the latest cannot be applied + latestAppliedVersion @0 :Int32; + # Any error encountered when trying to apply the last configuration + err @1 :Text; +} + +# ConfigurationManager defines RPC to manage cloudflared configuration remotely +interface ConfigurationManager @0xb48edfbdaa25db04 { + updateConfiguration @0 (version :Int32, config :Data) -> (result: UpdateConfigurationResponse); +} + +interface CloudflaredServer @0xf548cef9dea2a4a1 extends(SessionManager, ConfigurationManager) {} \ No newline at end of file diff --git a/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp.go b/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp.go new file mode 100644 index 0000000000..991acea85a --- /dev/null +++ b/protocol/cloudflare/tunnelrpc/tunnelrpc.capnp.go @@ -0,0 +1,4843 @@ +// Code generated by capnpc-go. DO NOT EDIT. + +package tunnelrpc + +import ( + strconv "strconv" + + context "golang.org/x/net/context" + capnp "zombiezen.com/go/capnproto2" + text "zombiezen.com/go/capnproto2/encoding/text" + schemas "zombiezen.com/go/capnproto2/schemas" + server "zombiezen.com/go/capnproto2/server" +) + +type Authentication struct{ capnp.Struct } + +// Authentication_TypeID is the unique identifier for the type Authentication. +const Authentication_TypeID = 0xc082ef6e0d42ed1d + +func NewAuthentication(s *capnp.Segment) (Authentication, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return Authentication{st}, err +} + +func NewRootAuthentication(s *capnp.Segment) (Authentication, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return Authentication{st}, err +} + +func ReadRootAuthentication(msg *capnp.Message) (Authentication, error) { + root, err := msg.RootPtr() + return Authentication{root.Struct()}, err +} + +func (s Authentication) String() string { + str, _ := text.Marshal(0xc082ef6e0d42ed1d, s.Struct) + return str +} + +func (s Authentication) Key() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s Authentication) HasKey() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s Authentication) KeyBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s Authentication) SetKey(v string) error { + return s.Struct.SetText(0, v) +} + +func (s Authentication) Email() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s Authentication) HasEmail() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s Authentication) EmailBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s Authentication) SetEmail(v string) error { + return s.Struct.SetText(1, v) +} + +func (s Authentication) OriginCAKey() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s Authentication) HasOriginCAKey() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s Authentication) OriginCAKeyBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s Authentication) SetOriginCAKey(v string) error { + return s.Struct.SetText(2, v) +} + +// Authentication_List is a list of Authentication. +type Authentication_List struct{ capnp.List } + +// NewAuthentication creates a new list of Authentication. +func NewAuthentication_List(s *capnp.Segment, sz int32) (Authentication_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}, sz) + return Authentication_List{l}, err +} + +func (s Authentication_List) At(i int) Authentication { return Authentication{s.List.Struct(i)} } + +func (s Authentication_List) Set(i int, v Authentication) error { return s.List.SetStruct(i, v.Struct) } + +func (s Authentication_List) String() string { + str, _ := text.MarshalList(0xc082ef6e0d42ed1d, s.List) + return str +} + +// Authentication_Promise is a wrapper for a Authentication promised by a client call. +type Authentication_Promise struct{ *capnp.Pipeline } + +func (p Authentication_Promise) Struct() (Authentication, error) { + s, err := p.Pipeline.Struct() + return Authentication{s}, err +} + +type TunnelRegistration struct{ capnp.Struct } + +// TunnelRegistration_TypeID is the unique identifier for the type TunnelRegistration. +const TunnelRegistration_TypeID = 0xf41a0f001ad49e46 + +func NewTunnelRegistration(s *capnp.Segment) (TunnelRegistration, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 6}) + return TunnelRegistration{st}, err +} + +func NewRootTunnelRegistration(s *capnp.Segment) (TunnelRegistration, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 6}) + return TunnelRegistration{st}, err +} + +func ReadRootTunnelRegistration(msg *capnp.Message) (TunnelRegistration, error) { + root, err := msg.RootPtr() + return TunnelRegistration{root.Struct()}, err +} + +func (s TunnelRegistration) String() string { + str, _ := text.Marshal(0xf41a0f001ad49e46, s.Struct) + return str +} + +func (s TunnelRegistration) Err() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s TunnelRegistration) HasErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) ErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s TunnelRegistration) SetErr(v string) error { + return s.Struct.SetText(0, v) +} + +func (s TunnelRegistration) Url() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s TunnelRegistration) HasUrl() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) UrlBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s TunnelRegistration) SetUrl(v string) error { + return s.Struct.SetText(1, v) +} + +func (s TunnelRegistration) LogLines() (capnp.TextList, error) { + p, err := s.Struct.Ptr(2) + return capnp.TextList{List: p.List()}, err +} + +func (s TunnelRegistration) HasLogLines() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) SetLogLines(v capnp.TextList) error { + return s.Struct.SetPtr(2, v.List.ToPtr()) +} + +// NewLogLines sets the logLines field to a newly +// allocated capnp.TextList, preferring placement in s's segment. +func (s TunnelRegistration) NewLogLines(n int32) (capnp.TextList, error) { + l, err := capnp.NewTextList(s.Struct.Segment(), n) + if err != nil { + return capnp.TextList{}, err + } + err = s.Struct.SetPtr(2, l.List.ToPtr()) + return l, err +} + +func (s TunnelRegistration) PermanentFailure() bool { + return s.Struct.Bit(0) +} + +func (s TunnelRegistration) SetPermanentFailure(v bool) { + s.Struct.SetBit(0, v) +} + +func (s TunnelRegistration) TunnelID() (string, error) { + p, err := s.Struct.Ptr(3) + return p.Text(), err +} + +func (s TunnelRegistration) HasTunnelID() bool { + p, err := s.Struct.Ptr(3) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) TunnelIDBytes() ([]byte, error) { + p, err := s.Struct.Ptr(3) + return p.TextBytes(), err +} + +func (s TunnelRegistration) SetTunnelID(v string) error { + return s.Struct.SetText(3, v) +} + +func (s TunnelRegistration) RetryAfterSeconds() uint16 { + return s.Struct.Uint16(2) +} + +func (s TunnelRegistration) SetRetryAfterSeconds(v uint16) { + s.Struct.SetUint16(2, v) +} + +func (s TunnelRegistration) EventDigest() ([]byte, error) { + p, err := s.Struct.Ptr(4) + return []byte(p.Data()), err +} + +func (s TunnelRegistration) HasEventDigest() bool { + p, err := s.Struct.Ptr(4) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) SetEventDigest(v []byte) error { + return s.Struct.SetData(4, v) +} + +func (s TunnelRegistration) ConnDigest() ([]byte, error) { + p, err := s.Struct.Ptr(5) + return []byte(p.Data()), err +} + +func (s TunnelRegistration) HasConnDigest() bool { + p, err := s.Struct.Ptr(5) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) SetConnDigest(v []byte) error { + return s.Struct.SetData(5, v) +} + +// TunnelRegistration_List is a list of TunnelRegistration. +type TunnelRegistration_List struct{ capnp.List } + +// NewTunnelRegistration creates a new list of TunnelRegistration. +func NewTunnelRegistration_List(s *capnp.Segment, sz int32) (TunnelRegistration_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 6}, sz) + return TunnelRegistration_List{l}, err +} + +func (s TunnelRegistration_List) At(i int) TunnelRegistration { + return TunnelRegistration{s.List.Struct(i)} +} + +func (s TunnelRegistration_List) Set(i int, v TunnelRegistration) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelRegistration_List) String() string { + str, _ := text.MarshalList(0xf41a0f001ad49e46, s.List) + return str +} + +// TunnelRegistration_Promise is a wrapper for a TunnelRegistration promised by a client call. +type TunnelRegistration_Promise struct{ *capnp.Pipeline } + +func (p TunnelRegistration_Promise) Struct() (TunnelRegistration, error) { + s, err := p.Pipeline.Struct() + return TunnelRegistration{s}, err +} + +type RegistrationOptions struct{ capnp.Struct } + +// RegistrationOptions_TypeID is the unique identifier for the type RegistrationOptions. +const RegistrationOptions_TypeID = 0xc793e50592935b4a + +func NewRegistrationOptions(s *capnp.Segment) (RegistrationOptions, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 8}) + return RegistrationOptions{st}, err +} + +func NewRootRegistrationOptions(s *capnp.Segment) (RegistrationOptions, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 8}) + return RegistrationOptions{st}, err +} + +func ReadRootRegistrationOptions(msg *capnp.Message) (RegistrationOptions, error) { + root, err := msg.RootPtr() + return RegistrationOptions{root.Struct()}, err +} + +func (s RegistrationOptions) String() string { + str, _ := text.Marshal(0xc793e50592935b4a, s.Struct) + return str +} + +func (s RegistrationOptions) ClientId() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s RegistrationOptions) HasClientId() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) ClientIdBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetClientId(v string) error { + return s.Struct.SetText(0, v) +} + +func (s RegistrationOptions) Version() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s RegistrationOptions) HasVersion() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) VersionBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetVersion(v string) error { + return s.Struct.SetText(1, v) +} + +func (s RegistrationOptions) Os() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s RegistrationOptions) HasOs() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) OsBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetOs(v string) error { + return s.Struct.SetText(2, v) +} + +func (s RegistrationOptions) ExistingTunnelPolicy() ExistingTunnelPolicy { + return ExistingTunnelPolicy(s.Struct.Uint16(0)) +} + +func (s RegistrationOptions) SetExistingTunnelPolicy(v ExistingTunnelPolicy) { + s.Struct.SetUint16(0, uint16(v)) +} + +func (s RegistrationOptions) PoolName() (string, error) { + p, err := s.Struct.Ptr(3) + return p.Text(), err +} + +func (s RegistrationOptions) HasPoolName() bool { + p, err := s.Struct.Ptr(3) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) PoolNameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(3) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetPoolName(v string) error { + return s.Struct.SetText(3, v) +} + +func (s RegistrationOptions) Tags() (Tag_List, error) { + p, err := s.Struct.Ptr(4) + return Tag_List{List: p.List()}, err +} + +func (s RegistrationOptions) HasTags() bool { + p, err := s.Struct.Ptr(4) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) SetTags(v Tag_List) error { + return s.Struct.SetPtr(4, v.List.ToPtr()) +} + +// NewTags sets the tags field to a newly +// allocated Tag_List, preferring placement in s's segment. +func (s RegistrationOptions) NewTags(n int32) (Tag_List, error) { + l, err := NewTag_List(s.Struct.Segment(), n) + if err != nil { + return Tag_List{}, err + } + err = s.Struct.SetPtr(4, l.List.ToPtr()) + return l, err +} + +func (s RegistrationOptions) ConnectionId() uint8 { + return s.Struct.Uint8(2) +} + +func (s RegistrationOptions) SetConnectionId(v uint8) { + s.Struct.SetUint8(2, v) +} + +func (s RegistrationOptions) OriginLocalIp() (string, error) { + p, err := s.Struct.Ptr(5) + return p.Text(), err +} + +func (s RegistrationOptions) HasOriginLocalIp() bool { + p, err := s.Struct.Ptr(5) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) OriginLocalIpBytes() ([]byte, error) { + p, err := s.Struct.Ptr(5) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetOriginLocalIp(v string) error { + return s.Struct.SetText(5, v) +} + +func (s RegistrationOptions) IsAutoupdated() bool { + return s.Struct.Bit(24) +} + +func (s RegistrationOptions) SetIsAutoupdated(v bool) { + s.Struct.SetBit(24, v) +} + +func (s RegistrationOptions) RunFromTerminal() bool { + return s.Struct.Bit(25) +} + +func (s RegistrationOptions) SetRunFromTerminal(v bool) { + s.Struct.SetBit(25, v) +} + +func (s RegistrationOptions) CompressionQuality() uint64 { + return s.Struct.Uint64(8) +} + +func (s RegistrationOptions) SetCompressionQuality(v uint64) { + s.Struct.SetUint64(8, v) +} + +func (s RegistrationOptions) Uuid() (string, error) { + p, err := s.Struct.Ptr(6) + return p.Text(), err +} + +func (s RegistrationOptions) HasUuid() bool { + p, err := s.Struct.Ptr(6) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) UuidBytes() ([]byte, error) { + p, err := s.Struct.Ptr(6) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetUuid(v string) error { + return s.Struct.SetText(6, v) +} + +func (s RegistrationOptions) NumPreviousAttempts() uint8 { + return s.Struct.Uint8(4) +} + +func (s RegistrationOptions) SetNumPreviousAttempts(v uint8) { + s.Struct.SetUint8(4, v) +} + +func (s RegistrationOptions) Features() (capnp.TextList, error) { + p, err := s.Struct.Ptr(7) + return capnp.TextList{List: p.List()}, err +} + +func (s RegistrationOptions) HasFeatures() bool { + p, err := s.Struct.Ptr(7) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) SetFeatures(v capnp.TextList) error { + return s.Struct.SetPtr(7, v.List.ToPtr()) +} + +// NewFeatures sets the features field to a newly +// allocated capnp.TextList, preferring placement in s's segment. +func (s RegistrationOptions) NewFeatures(n int32) (capnp.TextList, error) { + l, err := capnp.NewTextList(s.Struct.Segment(), n) + if err != nil { + return capnp.TextList{}, err + } + err = s.Struct.SetPtr(7, l.List.ToPtr()) + return l, err +} + +// RegistrationOptions_List is a list of RegistrationOptions. +type RegistrationOptions_List struct{ capnp.List } + +// NewRegistrationOptions creates a new list of RegistrationOptions. +func NewRegistrationOptions_List(s *capnp.Segment, sz int32) (RegistrationOptions_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 16, PointerCount: 8}, sz) + return RegistrationOptions_List{l}, err +} + +func (s RegistrationOptions_List) At(i int) RegistrationOptions { + return RegistrationOptions{s.List.Struct(i)} +} + +func (s RegistrationOptions_List) Set(i int, v RegistrationOptions) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationOptions_List) String() string { + str, _ := text.MarshalList(0xc793e50592935b4a, s.List) + return str +} + +// RegistrationOptions_Promise is a wrapper for a RegistrationOptions promised by a client call. +type RegistrationOptions_Promise struct{ *capnp.Pipeline } + +func (p RegistrationOptions_Promise) Struct() (RegistrationOptions, error) { + s, err := p.Pipeline.Struct() + return RegistrationOptions{s}, err +} + +type ExistingTunnelPolicy uint16 + +// ExistingTunnelPolicy_TypeID is the unique identifier for the type ExistingTunnelPolicy. +const ExistingTunnelPolicy_TypeID = 0x84cb9536a2cf6d3c + +// Values of ExistingTunnelPolicy. +const ( + ExistingTunnelPolicy_ignore ExistingTunnelPolicy = 0 + ExistingTunnelPolicy_disconnect ExistingTunnelPolicy = 1 + ExistingTunnelPolicy_balance ExistingTunnelPolicy = 2 +) + +// String returns the enum's constant name. +func (c ExistingTunnelPolicy) String() string { + switch c { + case ExistingTunnelPolicy_ignore: + return "ignore" + case ExistingTunnelPolicy_disconnect: + return "disconnect" + case ExistingTunnelPolicy_balance: + return "balance" + + default: + return "" + } +} + +// ExistingTunnelPolicyFromString returns the enum value with a name, +// or the zero value if there's no such value. +func ExistingTunnelPolicyFromString(c string) ExistingTunnelPolicy { + switch c { + case "ignore": + return ExistingTunnelPolicy_ignore + case "disconnect": + return ExistingTunnelPolicy_disconnect + case "balance": + return ExistingTunnelPolicy_balance + + default: + return 0 + } +} + +type ExistingTunnelPolicy_List struct{ capnp.List } + +func NewExistingTunnelPolicy_List(s *capnp.Segment, sz int32) (ExistingTunnelPolicy_List, error) { + l, err := capnp.NewUInt16List(s, sz) + return ExistingTunnelPolicy_List{l.List}, err +} + +func (l ExistingTunnelPolicy_List) At(i int) ExistingTunnelPolicy { + ul := capnp.UInt16List{List: l.List} + return ExistingTunnelPolicy(ul.At(i)) +} + +func (l ExistingTunnelPolicy_List) Set(i int, v ExistingTunnelPolicy) { + ul := capnp.UInt16List{List: l.List} + ul.Set(i, uint16(v)) +} + +type ServerInfo struct{ capnp.Struct } + +// ServerInfo_TypeID is the unique identifier for the type ServerInfo. +const ServerInfo_TypeID = 0xf2c68e2547ec3866 + +func NewServerInfo(s *capnp.Segment) (ServerInfo, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ServerInfo{st}, err +} + +func NewRootServerInfo(s *capnp.Segment) (ServerInfo, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ServerInfo{st}, err +} + +func ReadRootServerInfo(msg *capnp.Message) (ServerInfo, error) { + root, err := msg.RootPtr() + return ServerInfo{root.Struct()}, err +} + +func (s ServerInfo) String() string { + str, _ := text.Marshal(0xf2c68e2547ec3866, s.Struct) + return str +} + +func (s ServerInfo) LocationName() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s ServerInfo) HasLocationName() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ServerInfo) LocationNameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s ServerInfo) SetLocationName(v string) error { + return s.Struct.SetText(0, v) +} + +// ServerInfo_List is a list of ServerInfo. +type ServerInfo_List struct{ capnp.List } + +// NewServerInfo creates a new list of ServerInfo. +func NewServerInfo_List(s *capnp.Segment, sz int32) (ServerInfo_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return ServerInfo_List{l}, err +} + +func (s ServerInfo_List) At(i int) ServerInfo { return ServerInfo{s.List.Struct(i)} } + +func (s ServerInfo_List) Set(i int, v ServerInfo) error { return s.List.SetStruct(i, v.Struct) } + +func (s ServerInfo_List) String() string { + str, _ := text.MarshalList(0xf2c68e2547ec3866, s.List) + return str +} + +// ServerInfo_Promise is a wrapper for a ServerInfo promised by a client call. +type ServerInfo_Promise struct{ *capnp.Pipeline } + +func (p ServerInfo_Promise) Struct() (ServerInfo, error) { + s, err := p.Pipeline.Struct() + return ServerInfo{s}, err +} + +type AuthenticateResponse struct{ capnp.Struct } + +// AuthenticateResponse_TypeID is the unique identifier for the type AuthenticateResponse. +const AuthenticateResponse_TypeID = 0x82c325a07ad22a65 + +func NewAuthenticateResponse(s *capnp.Segment) (AuthenticateResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return AuthenticateResponse{st}, err +} + +func NewRootAuthenticateResponse(s *capnp.Segment) (AuthenticateResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return AuthenticateResponse{st}, err +} + +func ReadRootAuthenticateResponse(msg *capnp.Message) (AuthenticateResponse, error) { + root, err := msg.RootPtr() + return AuthenticateResponse{root.Struct()}, err +} + +func (s AuthenticateResponse) String() string { + str, _ := text.Marshal(0x82c325a07ad22a65, s.Struct) + return str +} + +func (s AuthenticateResponse) PermanentErr() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s AuthenticateResponse) HasPermanentErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s AuthenticateResponse) PermanentErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s AuthenticateResponse) SetPermanentErr(v string) error { + return s.Struct.SetText(0, v) +} + +func (s AuthenticateResponse) RetryableErr() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s AuthenticateResponse) HasRetryableErr() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s AuthenticateResponse) RetryableErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s AuthenticateResponse) SetRetryableErr(v string) error { + return s.Struct.SetText(1, v) +} + +func (s AuthenticateResponse) Jwt() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return []byte(p.Data()), err +} + +func (s AuthenticateResponse) HasJwt() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s AuthenticateResponse) SetJwt(v []byte) error { + return s.Struct.SetData(2, v) +} + +func (s AuthenticateResponse) HoursUntilRefresh() uint8 { + return s.Struct.Uint8(0) +} + +func (s AuthenticateResponse) SetHoursUntilRefresh(v uint8) { + s.Struct.SetUint8(0, v) +} + +// AuthenticateResponse_List is a list of AuthenticateResponse. +type AuthenticateResponse_List struct{ capnp.List } + +// NewAuthenticateResponse creates a new list of AuthenticateResponse. +func NewAuthenticateResponse_List(s *capnp.Segment, sz int32) (AuthenticateResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}, sz) + return AuthenticateResponse_List{l}, err +} + +func (s AuthenticateResponse_List) At(i int) AuthenticateResponse { + return AuthenticateResponse{s.List.Struct(i)} +} + +func (s AuthenticateResponse_List) Set(i int, v AuthenticateResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s AuthenticateResponse_List) String() string { + str, _ := text.MarshalList(0x82c325a07ad22a65, s.List) + return str +} + +// AuthenticateResponse_Promise is a wrapper for a AuthenticateResponse promised by a client call. +type AuthenticateResponse_Promise struct{ *capnp.Pipeline } + +func (p AuthenticateResponse_Promise) Struct() (AuthenticateResponse, error) { + s, err := p.Pipeline.Struct() + return AuthenticateResponse{s}, err +} + +type TunnelServer struct{ Client capnp.Client } + +// TunnelServer_TypeID is the unique identifier for the type TunnelServer. +const TunnelServer_TypeID = 0xea58385c65416035 + +func (c TunnelServer) RegisterTunnel(ctx context.Context, params func(TunnelServer_registerTunnel_Params) error, opts ...capnp.CallOption) TunnelServer_registerTunnel_Results_Promise { + if c.Client == nil { + return TunnelServer_registerTunnel_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "registerTunnel", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_registerTunnel_Params{Struct: s}) } + } + return TunnelServer_registerTunnel_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) GetServerInfo(ctx context.Context, params func(TunnelServer_getServerInfo_Params) error, opts ...capnp.CallOption) TunnelServer_getServerInfo_Results_Promise { + if c.Client == nil { + return TunnelServer_getServerInfo_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "getServerInfo", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_getServerInfo_Params{Struct: s}) } + } + return TunnelServer_getServerInfo_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) UnregisterTunnel(ctx context.Context, params func(TunnelServer_unregisterTunnel_Params) error, opts ...capnp.CallOption) TunnelServer_unregisterTunnel_Results_Promise { + if c.Client == nil { + return TunnelServer_unregisterTunnel_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "unregisterTunnel", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_unregisterTunnel_Params{Struct: s}) } + } + return TunnelServer_unregisterTunnel_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) ObsoleteDeclarativeTunnelConnect(ctx context.Context, params func(TunnelServer_obsoleteDeclarativeTunnelConnect_Params) error, opts ...capnp.CallOption) TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise { + if c.Client == nil { + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 3, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "obsoleteDeclarativeTunnelConnect", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { + return params(TunnelServer_obsoleteDeclarativeTunnelConnect_Params{Struct: s}) + } + } + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) Authenticate(ctx context.Context, params func(TunnelServer_authenticate_Params) error, opts ...capnp.CallOption) TunnelServer_authenticate_Results_Promise { + if c.Client == nil { + return TunnelServer_authenticate_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 4, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "authenticate", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_authenticate_Params{Struct: s}) } + } + return TunnelServer_authenticate_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) ReconnectTunnel(ctx context.Context, params func(TunnelServer_reconnectTunnel_Params) error, opts ...capnp.CallOption) TunnelServer_reconnectTunnel_Results_Promise { + if c.Client == nil { + return TunnelServer_reconnectTunnel_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 5, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "reconnectTunnel", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 5} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_reconnectTunnel_Params{Struct: s}) } + } + return TunnelServer_reconnectTunnel_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) RegisterConnection(ctx context.Context, params func(RegistrationServer_registerConnection_Params) error, opts ...capnp.CallOption) RegistrationServer_registerConnection_Results_Promise { + if c.Client == nil { + return RegistrationServer_registerConnection_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "registerConnection", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(RegistrationServer_registerConnection_Params{Struct: s}) } + } + return RegistrationServer_registerConnection_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) UnregisterConnection(ctx context.Context, params func(RegistrationServer_unregisterConnection_Params) error, opts ...capnp.CallOption) RegistrationServer_unregisterConnection_Results_Promise { + if c.Client == nil { + return RegistrationServer_unregisterConnection_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "unregisterConnection", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { return params(RegistrationServer_unregisterConnection_Params{Struct: s}) } + } + return RegistrationServer_unregisterConnection_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) UpdateLocalConfiguration(ctx context.Context, params func(RegistrationServer_updateLocalConfiguration_Params) error, opts ...capnp.CallOption) RegistrationServer_updateLocalConfiguration_Results_Promise { + if c.Client == nil { + return RegistrationServer_updateLocalConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "updateLocalConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { + return params(RegistrationServer_updateLocalConfiguration_Params{Struct: s}) + } + } + return RegistrationServer_updateLocalConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type TunnelServer_Server interface { + RegisterTunnel(TunnelServer_registerTunnel) error + + GetServerInfo(TunnelServer_getServerInfo) error + + UnregisterTunnel(TunnelServer_unregisterTunnel) error + + ObsoleteDeclarativeTunnelConnect(TunnelServer_obsoleteDeclarativeTunnelConnect) error + + Authenticate(TunnelServer_authenticate) error + + ReconnectTunnel(TunnelServer_reconnectTunnel) error + + RegisterConnection(RegistrationServer_registerConnection) error + + UnregisterConnection(RegistrationServer_unregisterConnection) error + + UpdateLocalConfiguration(RegistrationServer_updateLocalConfiguration) error +} + +func TunnelServer_ServerToClient(s TunnelServer_Server) TunnelServer { + c, _ := s.(server.Closer) + return TunnelServer{Client: server.New(TunnelServer_Methods(nil, s), c)} +} + +func TunnelServer_Methods(methods []server.Method, s TunnelServer_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 9) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "registerTunnel", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_registerTunnel{c, opts, TunnelServer_registerTunnel_Params{Struct: p}, TunnelServer_registerTunnel_Results{Struct: r}} + return s.RegisterTunnel(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "getServerInfo", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_getServerInfo{c, opts, TunnelServer_getServerInfo_Params{Struct: p}, TunnelServer_getServerInfo_Results{Struct: r}} + return s.GetServerInfo(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "unregisterTunnel", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_unregisterTunnel{c, opts, TunnelServer_unregisterTunnel_Params{Struct: p}, TunnelServer_unregisterTunnel_Results{Struct: r}} + return s.UnregisterTunnel(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 3, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "obsoleteDeclarativeTunnelConnect", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_obsoleteDeclarativeTunnelConnect{c, opts, TunnelServer_obsoleteDeclarativeTunnelConnect_Params{Struct: p}, TunnelServer_obsoleteDeclarativeTunnelConnect_Results{Struct: r}} + return s.ObsoleteDeclarativeTunnelConnect(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 4, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "authenticate", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_authenticate{c, opts, TunnelServer_authenticate_Params{Struct: p}, TunnelServer_authenticate_Results{Struct: r}} + return s.Authenticate(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 5, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "reconnectTunnel", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_reconnectTunnel{c, opts, TunnelServer_reconnectTunnel_Params{Struct: p}, TunnelServer_reconnectTunnel_Results{Struct: r}} + return s.ReconnectTunnel(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "registerConnection", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_registerConnection{c, opts, RegistrationServer_registerConnection_Params{Struct: p}, RegistrationServer_registerConnection_Results{Struct: r}} + return s.RegisterConnection(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "unregisterConnection", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_unregisterConnection{c, opts, RegistrationServer_unregisterConnection_Params{Struct: p}, RegistrationServer_unregisterConnection_Results{Struct: r}} + return s.UnregisterConnection(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "updateLocalConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_updateLocalConfiguration{c, opts, RegistrationServer_updateLocalConfiguration_Params{Struct: p}, RegistrationServer_updateLocalConfiguration_Results{Struct: r}} + return s.UpdateLocalConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + return methods +} + +// TunnelServer_registerTunnel holds the arguments for a server call to TunnelServer.registerTunnel. +type TunnelServer_registerTunnel struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_registerTunnel_Params + Results TunnelServer_registerTunnel_Results +} + +// TunnelServer_getServerInfo holds the arguments for a server call to TunnelServer.getServerInfo. +type TunnelServer_getServerInfo struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_getServerInfo_Params + Results TunnelServer_getServerInfo_Results +} + +// TunnelServer_unregisterTunnel holds the arguments for a server call to TunnelServer.unregisterTunnel. +type TunnelServer_unregisterTunnel struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_unregisterTunnel_Params + Results TunnelServer_unregisterTunnel_Results +} + +// TunnelServer_obsoleteDeclarativeTunnelConnect holds the arguments for a server call to TunnelServer.obsoleteDeclarativeTunnelConnect. +type TunnelServer_obsoleteDeclarativeTunnelConnect struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_obsoleteDeclarativeTunnelConnect_Params + Results TunnelServer_obsoleteDeclarativeTunnelConnect_Results +} + +// TunnelServer_authenticate holds the arguments for a server call to TunnelServer.authenticate. +type TunnelServer_authenticate struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_authenticate_Params + Results TunnelServer_authenticate_Results +} + +// TunnelServer_reconnectTunnel holds the arguments for a server call to TunnelServer.reconnectTunnel. +type TunnelServer_reconnectTunnel struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_reconnectTunnel_Params + Results TunnelServer_reconnectTunnel_Results +} + +type TunnelServer_registerTunnel_Params struct{ capnp.Struct } + +// TunnelServer_registerTunnel_Params_TypeID is the unique identifier for the type TunnelServer_registerTunnel_Params. +const TunnelServer_registerTunnel_Params_TypeID = 0xb70431c0dc014915 + +func NewTunnelServer_registerTunnel_Params(s *capnp.Segment) (TunnelServer_registerTunnel_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_registerTunnel_Params{st}, err +} + +func NewRootTunnelServer_registerTunnel_Params(s *capnp.Segment) (TunnelServer_registerTunnel_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_registerTunnel_Params{st}, err +} + +func ReadRootTunnelServer_registerTunnel_Params(msg *capnp.Message) (TunnelServer_registerTunnel_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_registerTunnel_Params{root.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Params) String() string { + str, _ := text.Marshal(0xb70431c0dc014915, s.Struct) + return str +} + +func (s TunnelServer_registerTunnel_Params) OriginCert() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s TunnelServer_registerTunnel_Params) HasOriginCert() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) SetOriginCert(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s TunnelServer_registerTunnel_Params) Hostname() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s TunnelServer_registerTunnel_Params) HasHostname() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) HostnameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s TunnelServer_registerTunnel_Params) SetHostname(v string) error { + return s.Struct.SetText(1, v) +} + +func (s TunnelServer_registerTunnel_Params) Options() (RegistrationOptions, error) { + p, err := s.Struct.Ptr(2) + return RegistrationOptions{Struct: p.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Params) HasOptions() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) SetOptions(v RegistrationOptions) error { + return s.Struct.SetPtr(2, v.Struct.ToPtr()) +} + +// NewOptions sets the options field to a newly +// allocated RegistrationOptions struct, preferring placement in s's segment. +func (s TunnelServer_registerTunnel_Params) NewOptions() (RegistrationOptions, error) { + ss, err := NewRegistrationOptions(s.Struct.Segment()) + if err != nil { + return RegistrationOptions{}, err + } + err = s.Struct.SetPtr(2, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_registerTunnel_Params_List is a list of TunnelServer_registerTunnel_Params. +type TunnelServer_registerTunnel_Params_List struct{ capnp.List } + +// NewTunnelServer_registerTunnel_Params creates a new list of TunnelServer_registerTunnel_Params. +func NewTunnelServer_registerTunnel_Params_List(s *capnp.Segment, sz int32) (TunnelServer_registerTunnel_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}, sz) + return TunnelServer_registerTunnel_Params_List{l}, err +} + +func (s TunnelServer_registerTunnel_Params_List) At(i int) TunnelServer_registerTunnel_Params { + return TunnelServer_registerTunnel_Params{s.List.Struct(i)} +} + +func (s TunnelServer_registerTunnel_Params_List) Set(i int, v TunnelServer_registerTunnel_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_registerTunnel_Params_List) String() string { + str, _ := text.MarshalList(0xb70431c0dc014915, s.List) + return str +} + +// TunnelServer_registerTunnel_Params_Promise is a wrapper for a TunnelServer_registerTunnel_Params promised by a client call. +type TunnelServer_registerTunnel_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_registerTunnel_Params_Promise) Struct() (TunnelServer_registerTunnel_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_registerTunnel_Params{s}, err +} + +func (p TunnelServer_registerTunnel_Params_Promise) Options() RegistrationOptions_Promise { + return RegistrationOptions_Promise{Pipeline: p.Pipeline.GetPipeline(2)} +} + +type TunnelServer_registerTunnel_Results struct{ capnp.Struct } + +// TunnelServer_registerTunnel_Results_TypeID is the unique identifier for the type TunnelServer_registerTunnel_Results. +const TunnelServer_registerTunnel_Results_TypeID = 0xf2c122394f447e8e + +func NewTunnelServer_registerTunnel_Results(s *capnp.Segment) (TunnelServer_registerTunnel_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_registerTunnel_Results{st}, err +} + +func NewRootTunnelServer_registerTunnel_Results(s *capnp.Segment) (TunnelServer_registerTunnel_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_registerTunnel_Results{st}, err +} + +func ReadRootTunnelServer_registerTunnel_Results(msg *capnp.Message) (TunnelServer_registerTunnel_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_registerTunnel_Results{root.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Results) String() string { + str, _ := text.Marshal(0xf2c122394f447e8e, s.Struct) + return str +} + +func (s TunnelServer_registerTunnel_Results) Result() (TunnelRegistration, error) { + p, err := s.Struct.Ptr(0) + return TunnelRegistration{Struct: p.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Results) SetResult(v TunnelRegistration) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated TunnelRegistration struct, preferring placement in s's segment. +func (s TunnelServer_registerTunnel_Results) NewResult() (TunnelRegistration, error) { + ss, err := NewTunnelRegistration(s.Struct.Segment()) + if err != nil { + return TunnelRegistration{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_registerTunnel_Results_List is a list of TunnelServer_registerTunnel_Results. +type TunnelServer_registerTunnel_Results_List struct{ capnp.List } + +// NewTunnelServer_registerTunnel_Results creates a new list of TunnelServer_registerTunnel_Results. +func NewTunnelServer_registerTunnel_Results_List(s *capnp.Segment, sz int32) (TunnelServer_registerTunnel_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_registerTunnel_Results_List{l}, err +} + +func (s TunnelServer_registerTunnel_Results_List) At(i int) TunnelServer_registerTunnel_Results { + return TunnelServer_registerTunnel_Results{s.List.Struct(i)} +} + +func (s TunnelServer_registerTunnel_Results_List) Set(i int, v TunnelServer_registerTunnel_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_registerTunnel_Results_List) String() string { + str, _ := text.MarshalList(0xf2c122394f447e8e, s.List) + return str +} + +// TunnelServer_registerTunnel_Results_Promise is a wrapper for a TunnelServer_registerTunnel_Results promised by a client call. +type TunnelServer_registerTunnel_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_registerTunnel_Results_Promise) Struct() (TunnelServer_registerTunnel_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_registerTunnel_Results{s}, err +} + +func (p TunnelServer_registerTunnel_Results_Promise) Result() TunnelRegistration_Promise { + return TunnelRegistration_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type TunnelServer_getServerInfo_Params struct{ capnp.Struct } + +// TunnelServer_getServerInfo_Params_TypeID is the unique identifier for the type TunnelServer_getServerInfo_Params. +const TunnelServer_getServerInfo_Params_TypeID = 0xdc3ed6801961e502 + +func NewTunnelServer_getServerInfo_Params(s *capnp.Segment) (TunnelServer_getServerInfo_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_getServerInfo_Params{st}, err +} + +func NewRootTunnelServer_getServerInfo_Params(s *capnp.Segment) (TunnelServer_getServerInfo_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_getServerInfo_Params{st}, err +} + +func ReadRootTunnelServer_getServerInfo_Params(msg *capnp.Message) (TunnelServer_getServerInfo_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_getServerInfo_Params{root.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Params) String() string { + str, _ := text.Marshal(0xdc3ed6801961e502, s.Struct) + return str +} + +// TunnelServer_getServerInfo_Params_List is a list of TunnelServer_getServerInfo_Params. +type TunnelServer_getServerInfo_Params_List struct{ capnp.List } + +// NewTunnelServer_getServerInfo_Params creates a new list of TunnelServer_getServerInfo_Params. +func NewTunnelServer_getServerInfo_Params_List(s *capnp.Segment, sz int32) (TunnelServer_getServerInfo_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return TunnelServer_getServerInfo_Params_List{l}, err +} + +func (s TunnelServer_getServerInfo_Params_List) At(i int) TunnelServer_getServerInfo_Params { + return TunnelServer_getServerInfo_Params{s.List.Struct(i)} +} + +func (s TunnelServer_getServerInfo_Params_List) Set(i int, v TunnelServer_getServerInfo_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_getServerInfo_Params_List) String() string { + str, _ := text.MarshalList(0xdc3ed6801961e502, s.List) + return str +} + +// TunnelServer_getServerInfo_Params_Promise is a wrapper for a TunnelServer_getServerInfo_Params promised by a client call. +type TunnelServer_getServerInfo_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_getServerInfo_Params_Promise) Struct() (TunnelServer_getServerInfo_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_getServerInfo_Params{s}, err +} + +type TunnelServer_getServerInfo_Results struct{ capnp.Struct } + +// TunnelServer_getServerInfo_Results_TypeID is the unique identifier for the type TunnelServer_getServerInfo_Results. +const TunnelServer_getServerInfo_Results_TypeID = 0xe3e37d096a5b564e + +func NewTunnelServer_getServerInfo_Results(s *capnp.Segment) (TunnelServer_getServerInfo_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_getServerInfo_Results{st}, err +} + +func NewRootTunnelServer_getServerInfo_Results(s *capnp.Segment) (TunnelServer_getServerInfo_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_getServerInfo_Results{st}, err +} + +func ReadRootTunnelServer_getServerInfo_Results(msg *capnp.Message) (TunnelServer_getServerInfo_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_getServerInfo_Results{root.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Results) String() string { + str, _ := text.Marshal(0xe3e37d096a5b564e, s.Struct) + return str +} + +func (s TunnelServer_getServerInfo_Results) Result() (ServerInfo, error) { + p, err := s.Struct.Ptr(0) + return ServerInfo{Struct: p.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_getServerInfo_Results) SetResult(v ServerInfo) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated ServerInfo struct, preferring placement in s's segment. +func (s TunnelServer_getServerInfo_Results) NewResult() (ServerInfo, error) { + ss, err := NewServerInfo(s.Struct.Segment()) + if err != nil { + return ServerInfo{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_getServerInfo_Results_List is a list of TunnelServer_getServerInfo_Results. +type TunnelServer_getServerInfo_Results_List struct{ capnp.List } + +// NewTunnelServer_getServerInfo_Results creates a new list of TunnelServer_getServerInfo_Results. +func NewTunnelServer_getServerInfo_Results_List(s *capnp.Segment, sz int32) (TunnelServer_getServerInfo_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_getServerInfo_Results_List{l}, err +} + +func (s TunnelServer_getServerInfo_Results_List) At(i int) TunnelServer_getServerInfo_Results { + return TunnelServer_getServerInfo_Results{s.List.Struct(i)} +} + +func (s TunnelServer_getServerInfo_Results_List) Set(i int, v TunnelServer_getServerInfo_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_getServerInfo_Results_List) String() string { + str, _ := text.MarshalList(0xe3e37d096a5b564e, s.List) + return str +} + +// TunnelServer_getServerInfo_Results_Promise is a wrapper for a TunnelServer_getServerInfo_Results promised by a client call. +type TunnelServer_getServerInfo_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_getServerInfo_Results_Promise) Struct() (TunnelServer_getServerInfo_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_getServerInfo_Results{s}, err +} + +func (p TunnelServer_getServerInfo_Results_Promise) Result() ServerInfo_Promise { + return ServerInfo_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type TunnelServer_unregisterTunnel_Params struct{ capnp.Struct } + +// TunnelServer_unregisterTunnel_Params_TypeID is the unique identifier for the type TunnelServer_unregisterTunnel_Params. +const TunnelServer_unregisterTunnel_Params_TypeID = 0x9b87b390babc2ccf + +func NewTunnelServer_unregisterTunnel_Params(s *capnp.Segment) (TunnelServer_unregisterTunnel_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 0}) + return TunnelServer_unregisterTunnel_Params{st}, err +} + +func NewRootTunnelServer_unregisterTunnel_Params(s *capnp.Segment) (TunnelServer_unregisterTunnel_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 0}) + return TunnelServer_unregisterTunnel_Params{st}, err +} + +func ReadRootTunnelServer_unregisterTunnel_Params(msg *capnp.Message) (TunnelServer_unregisterTunnel_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_unregisterTunnel_Params{root.Struct()}, err +} + +func (s TunnelServer_unregisterTunnel_Params) String() string { + str, _ := text.Marshal(0x9b87b390babc2ccf, s.Struct) + return str +} + +func (s TunnelServer_unregisterTunnel_Params) GracePeriodNanoSec() int64 { + return int64(s.Struct.Uint64(0)) +} + +func (s TunnelServer_unregisterTunnel_Params) SetGracePeriodNanoSec(v int64) { + s.Struct.SetUint64(0, uint64(v)) +} + +// TunnelServer_unregisterTunnel_Params_List is a list of TunnelServer_unregisterTunnel_Params. +type TunnelServer_unregisterTunnel_Params_List struct{ capnp.List } + +// NewTunnelServer_unregisterTunnel_Params creates a new list of TunnelServer_unregisterTunnel_Params. +func NewTunnelServer_unregisterTunnel_Params_List(s *capnp.Segment, sz int32) (TunnelServer_unregisterTunnel_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 0}, sz) + return TunnelServer_unregisterTunnel_Params_List{l}, err +} + +func (s TunnelServer_unregisterTunnel_Params_List) At(i int) TunnelServer_unregisterTunnel_Params { + return TunnelServer_unregisterTunnel_Params{s.List.Struct(i)} +} + +func (s TunnelServer_unregisterTunnel_Params_List) Set(i int, v TunnelServer_unregisterTunnel_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_unregisterTunnel_Params_List) String() string { + str, _ := text.MarshalList(0x9b87b390babc2ccf, s.List) + return str +} + +// TunnelServer_unregisterTunnel_Params_Promise is a wrapper for a TunnelServer_unregisterTunnel_Params promised by a client call. +type TunnelServer_unregisterTunnel_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_unregisterTunnel_Params_Promise) Struct() (TunnelServer_unregisterTunnel_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_unregisterTunnel_Params{s}, err +} + +type TunnelServer_unregisterTunnel_Results struct{ capnp.Struct } + +// TunnelServer_unregisterTunnel_Results_TypeID is the unique identifier for the type TunnelServer_unregisterTunnel_Results. +const TunnelServer_unregisterTunnel_Results_TypeID = 0xa29a916d4ebdd894 + +func NewTunnelServer_unregisterTunnel_Results(s *capnp.Segment) (TunnelServer_unregisterTunnel_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_unregisterTunnel_Results{st}, err +} + +func NewRootTunnelServer_unregisterTunnel_Results(s *capnp.Segment) (TunnelServer_unregisterTunnel_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_unregisterTunnel_Results{st}, err +} + +func ReadRootTunnelServer_unregisterTunnel_Results(msg *capnp.Message) (TunnelServer_unregisterTunnel_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_unregisterTunnel_Results{root.Struct()}, err +} + +func (s TunnelServer_unregisterTunnel_Results) String() string { + str, _ := text.Marshal(0xa29a916d4ebdd894, s.Struct) + return str +} + +// TunnelServer_unregisterTunnel_Results_List is a list of TunnelServer_unregisterTunnel_Results. +type TunnelServer_unregisterTunnel_Results_List struct{ capnp.List } + +// NewTunnelServer_unregisterTunnel_Results creates a new list of TunnelServer_unregisterTunnel_Results. +func NewTunnelServer_unregisterTunnel_Results_List(s *capnp.Segment, sz int32) (TunnelServer_unregisterTunnel_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return TunnelServer_unregisterTunnel_Results_List{l}, err +} + +func (s TunnelServer_unregisterTunnel_Results_List) At(i int) TunnelServer_unregisterTunnel_Results { + return TunnelServer_unregisterTunnel_Results{s.List.Struct(i)} +} + +func (s TunnelServer_unregisterTunnel_Results_List) Set(i int, v TunnelServer_unregisterTunnel_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_unregisterTunnel_Results_List) String() string { + str, _ := text.MarshalList(0xa29a916d4ebdd894, s.List) + return str +} + +// TunnelServer_unregisterTunnel_Results_Promise is a wrapper for a TunnelServer_unregisterTunnel_Results promised by a client call. +type TunnelServer_unregisterTunnel_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_unregisterTunnel_Results_Promise) Struct() (TunnelServer_unregisterTunnel_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_unregisterTunnel_Results{s}, err +} + +type TunnelServer_obsoleteDeclarativeTunnelConnect_Params struct{ capnp.Struct } + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Params_TypeID is the unique identifier for the type TunnelServer_obsoleteDeclarativeTunnelConnect_Params. +const TunnelServer_obsoleteDeclarativeTunnelConnect_Params_TypeID = 0xa766b24d4fe5da35 + +func NewTunnelServer_obsoleteDeclarativeTunnelConnect_Params(s *capnp.Segment) (TunnelServer_obsoleteDeclarativeTunnelConnect_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params{st}, err +} + +func NewRootTunnelServer_obsoleteDeclarativeTunnelConnect_Params(s *capnp.Segment) (TunnelServer_obsoleteDeclarativeTunnelConnect_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params{st}, err +} + +func ReadRootTunnelServer_obsoleteDeclarativeTunnelConnect_Params(msg *capnp.Message) (TunnelServer_obsoleteDeclarativeTunnelConnect_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params{root.Struct()}, err +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Params) String() string { + str, _ := text.Marshal(0xa766b24d4fe5da35, s.Struct) + return str +} + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List is a list of TunnelServer_obsoleteDeclarativeTunnelConnect_Params. +type TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List struct{ capnp.List } + +// NewTunnelServer_obsoleteDeclarativeTunnelConnect_Params creates a new list of TunnelServer_obsoleteDeclarativeTunnelConnect_Params. +func NewTunnelServer_obsoleteDeclarativeTunnelConnect_Params_List(s *capnp.Segment, sz int32) (TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List{l}, err +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List) At(i int) TunnelServer_obsoleteDeclarativeTunnelConnect_Params { + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params{s.List.Struct(i)} +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List) Set(i int, v TunnelServer_obsoleteDeclarativeTunnelConnect_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Params_List) String() string { + str, _ := text.MarshalList(0xa766b24d4fe5da35, s.List) + return str +} + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Params_Promise is a wrapper for a TunnelServer_obsoleteDeclarativeTunnelConnect_Params promised by a client call. +type TunnelServer_obsoleteDeclarativeTunnelConnect_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_obsoleteDeclarativeTunnelConnect_Params_Promise) Struct() (TunnelServer_obsoleteDeclarativeTunnelConnect_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_obsoleteDeclarativeTunnelConnect_Params{s}, err +} + +type TunnelServer_obsoleteDeclarativeTunnelConnect_Results struct{ capnp.Struct } + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Results_TypeID is the unique identifier for the type TunnelServer_obsoleteDeclarativeTunnelConnect_Results. +const TunnelServer_obsoleteDeclarativeTunnelConnect_Results_TypeID = 0xfeac5c8f4899ef7c + +func NewTunnelServer_obsoleteDeclarativeTunnelConnect_Results(s *capnp.Segment) (TunnelServer_obsoleteDeclarativeTunnelConnect_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results{st}, err +} + +func NewRootTunnelServer_obsoleteDeclarativeTunnelConnect_Results(s *capnp.Segment) (TunnelServer_obsoleteDeclarativeTunnelConnect_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results{st}, err +} + +func ReadRootTunnelServer_obsoleteDeclarativeTunnelConnect_Results(msg *capnp.Message) (TunnelServer_obsoleteDeclarativeTunnelConnect_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results{root.Struct()}, err +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Results) String() string { + str, _ := text.Marshal(0xfeac5c8f4899ef7c, s.Struct) + return str +} + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List is a list of TunnelServer_obsoleteDeclarativeTunnelConnect_Results. +type TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List struct{ capnp.List } + +// NewTunnelServer_obsoleteDeclarativeTunnelConnect_Results creates a new list of TunnelServer_obsoleteDeclarativeTunnelConnect_Results. +func NewTunnelServer_obsoleteDeclarativeTunnelConnect_Results_List(s *capnp.Segment, sz int32) (TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List{l}, err +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List) At(i int) TunnelServer_obsoleteDeclarativeTunnelConnect_Results { + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results{s.List.Struct(i)} +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List) Set(i int, v TunnelServer_obsoleteDeclarativeTunnelConnect_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_obsoleteDeclarativeTunnelConnect_Results_List) String() string { + str, _ := text.MarshalList(0xfeac5c8f4899ef7c, s.List) + return str +} + +// TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise is a wrapper for a TunnelServer_obsoleteDeclarativeTunnelConnect_Results promised by a client call. +type TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_obsoleteDeclarativeTunnelConnect_Results_Promise) Struct() (TunnelServer_obsoleteDeclarativeTunnelConnect_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_obsoleteDeclarativeTunnelConnect_Results{s}, err +} + +type TunnelServer_authenticate_Params struct{ capnp.Struct } + +// TunnelServer_authenticate_Params_TypeID is the unique identifier for the type TunnelServer_authenticate_Params. +const TunnelServer_authenticate_Params_TypeID = 0x85c8cea1ab1894f3 + +func NewTunnelServer_authenticate_Params(s *capnp.Segment) (TunnelServer_authenticate_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_authenticate_Params{st}, err +} + +func NewRootTunnelServer_authenticate_Params(s *capnp.Segment) (TunnelServer_authenticate_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_authenticate_Params{st}, err +} + +func ReadRootTunnelServer_authenticate_Params(msg *capnp.Message) (TunnelServer_authenticate_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_authenticate_Params{root.Struct()}, err +} + +func (s TunnelServer_authenticate_Params) String() string { + str, _ := text.Marshal(0x85c8cea1ab1894f3, s.Struct) + return str +} + +func (s TunnelServer_authenticate_Params) OriginCert() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s TunnelServer_authenticate_Params) HasOriginCert() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_authenticate_Params) SetOriginCert(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s TunnelServer_authenticate_Params) Hostname() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s TunnelServer_authenticate_Params) HasHostname() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelServer_authenticate_Params) HostnameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s TunnelServer_authenticate_Params) SetHostname(v string) error { + return s.Struct.SetText(1, v) +} + +func (s TunnelServer_authenticate_Params) Options() (RegistrationOptions, error) { + p, err := s.Struct.Ptr(2) + return RegistrationOptions{Struct: p.Struct()}, err +} + +func (s TunnelServer_authenticate_Params) HasOptions() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelServer_authenticate_Params) SetOptions(v RegistrationOptions) error { + return s.Struct.SetPtr(2, v.Struct.ToPtr()) +} + +// NewOptions sets the options field to a newly +// allocated RegistrationOptions struct, preferring placement in s's segment. +func (s TunnelServer_authenticate_Params) NewOptions() (RegistrationOptions, error) { + ss, err := NewRegistrationOptions(s.Struct.Segment()) + if err != nil { + return RegistrationOptions{}, err + } + err = s.Struct.SetPtr(2, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_authenticate_Params_List is a list of TunnelServer_authenticate_Params. +type TunnelServer_authenticate_Params_List struct{ capnp.List } + +// NewTunnelServer_authenticate_Params creates a new list of TunnelServer_authenticate_Params. +func NewTunnelServer_authenticate_Params_List(s *capnp.Segment, sz int32) (TunnelServer_authenticate_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}, sz) + return TunnelServer_authenticate_Params_List{l}, err +} + +func (s TunnelServer_authenticate_Params_List) At(i int) TunnelServer_authenticate_Params { + return TunnelServer_authenticate_Params{s.List.Struct(i)} +} + +func (s TunnelServer_authenticate_Params_List) Set(i int, v TunnelServer_authenticate_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_authenticate_Params_List) String() string { + str, _ := text.MarshalList(0x85c8cea1ab1894f3, s.List) + return str +} + +// TunnelServer_authenticate_Params_Promise is a wrapper for a TunnelServer_authenticate_Params promised by a client call. +type TunnelServer_authenticate_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_authenticate_Params_Promise) Struct() (TunnelServer_authenticate_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_authenticate_Params{s}, err +} + +func (p TunnelServer_authenticate_Params_Promise) Options() RegistrationOptions_Promise { + return RegistrationOptions_Promise{Pipeline: p.Pipeline.GetPipeline(2)} +} + +type TunnelServer_authenticate_Results struct{ capnp.Struct } + +// TunnelServer_authenticate_Results_TypeID is the unique identifier for the type TunnelServer_authenticate_Results. +const TunnelServer_authenticate_Results_TypeID = 0xfc5edf80e39c0796 + +func NewTunnelServer_authenticate_Results(s *capnp.Segment) (TunnelServer_authenticate_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_authenticate_Results{st}, err +} + +func NewRootTunnelServer_authenticate_Results(s *capnp.Segment) (TunnelServer_authenticate_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_authenticate_Results{st}, err +} + +func ReadRootTunnelServer_authenticate_Results(msg *capnp.Message) (TunnelServer_authenticate_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_authenticate_Results{root.Struct()}, err +} + +func (s TunnelServer_authenticate_Results) String() string { + str, _ := text.Marshal(0xfc5edf80e39c0796, s.Struct) + return str +} + +func (s TunnelServer_authenticate_Results) Result() (AuthenticateResponse, error) { + p, err := s.Struct.Ptr(0) + return AuthenticateResponse{Struct: p.Struct()}, err +} + +func (s TunnelServer_authenticate_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_authenticate_Results) SetResult(v AuthenticateResponse) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated AuthenticateResponse struct, preferring placement in s's segment. +func (s TunnelServer_authenticate_Results) NewResult() (AuthenticateResponse, error) { + ss, err := NewAuthenticateResponse(s.Struct.Segment()) + if err != nil { + return AuthenticateResponse{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_authenticate_Results_List is a list of TunnelServer_authenticate_Results. +type TunnelServer_authenticate_Results_List struct{ capnp.List } + +// NewTunnelServer_authenticate_Results creates a new list of TunnelServer_authenticate_Results. +func NewTunnelServer_authenticate_Results_List(s *capnp.Segment, sz int32) (TunnelServer_authenticate_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_authenticate_Results_List{l}, err +} + +func (s TunnelServer_authenticate_Results_List) At(i int) TunnelServer_authenticate_Results { + return TunnelServer_authenticate_Results{s.List.Struct(i)} +} + +func (s TunnelServer_authenticate_Results_List) Set(i int, v TunnelServer_authenticate_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_authenticate_Results_List) String() string { + str, _ := text.MarshalList(0xfc5edf80e39c0796, s.List) + return str +} + +// TunnelServer_authenticate_Results_Promise is a wrapper for a TunnelServer_authenticate_Results promised by a client call. +type TunnelServer_authenticate_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_authenticate_Results_Promise) Struct() (TunnelServer_authenticate_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_authenticate_Results{s}, err +} + +func (p TunnelServer_authenticate_Results_Promise) Result() AuthenticateResponse_Promise { + return AuthenticateResponse_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type TunnelServer_reconnectTunnel_Params struct{ capnp.Struct } + +// TunnelServer_reconnectTunnel_Params_TypeID is the unique identifier for the type TunnelServer_reconnectTunnel_Params. +const TunnelServer_reconnectTunnel_Params_TypeID = 0xa353a3556df74984 + +func NewTunnelServer_reconnectTunnel_Params(s *capnp.Segment) (TunnelServer_reconnectTunnel_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 5}) + return TunnelServer_reconnectTunnel_Params{st}, err +} + +func NewRootTunnelServer_reconnectTunnel_Params(s *capnp.Segment) (TunnelServer_reconnectTunnel_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 5}) + return TunnelServer_reconnectTunnel_Params{st}, err +} + +func ReadRootTunnelServer_reconnectTunnel_Params(msg *capnp.Message) (TunnelServer_reconnectTunnel_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_reconnectTunnel_Params{root.Struct()}, err +} + +func (s TunnelServer_reconnectTunnel_Params) String() string { + str, _ := text.Marshal(0xa353a3556df74984, s.Struct) + return str +} + +func (s TunnelServer_reconnectTunnel_Params) Jwt() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s TunnelServer_reconnectTunnel_Params) HasJwt() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Params) SetJwt(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s TunnelServer_reconnectTunnel_Params) EventDigest() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s TunnelServer_reconnectTunnel_Params) HasEventDigest() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Params) SetEventDigest(v []byte) error { + return s.Struct.SetData(1, v) +} + +func (s TunnelServer_reconnectTunnel_Params) ConnDigest() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return []byte(p.Data()), err +} + +func (s TunnelServer_reconnectTunnel_Params) HasConnDigest() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Params) SetConnDigest(v []byte) error { + return s.Struct.SetData(2, v) +} + +func (s TunnelServer_reconnectTunnel_Params) Hostname() (string, error) { + p, err := s.Struct.Ptr(3) + return p.Text(), err +} + +func (s TunnelServer_reconnectTunnel_Params) HasHostname() bool { + p, err := s.Struct.Ptr(3) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Params) HostnameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(3) + return p.TextBytes(), err +} + +func (s TunnelServer_reconnectTunnel_Params) SetHostname(v string) error { + return s.Struct.SetText(3, v) +} + +func (s TunnelServer_reconnectTunnel_Params) Options() (RegistrationOptions, error) { + p, err := s.Struct.Ptr(4) + return RegistrationOptions{Struct: p.Struct()}, err +} + +func (s TunnelServer_reconnectTunnel_Params) HasOptions() bool { + p, err := s.Struct.Ptr(4) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Params) SetOptions(v RegistrationOptions) error { + return s.Struct.SetPtr(4, v.Struct.ToPtr()) +} + +// NewOptions sets the options field to a newly +// allocated RegistrationOptions struct, preferring placement in s's segment. +func (s TunnelServer_reconnectTunnel_Params) NewOptions() (RegistrationOptions, error) { + ss, err := NewRegistrationOptions(s.Struct.Segment()) + if err != nil { + return RegistrationOptions{}, err + } + err = s.Struct.SetPtr(4, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_reconnectTunnel_Params_List is a list of TunnelServer_reconnectTunnel_Params. +type TunnelServer_reconnectTunnel_Params_List struct{ capnp.List } + +// NewTunnelServer_reconnectTunnel_Params creates a new list of TunnelServer_reconnectTunnel_Params. +func NewTunnelServer_reconnectTunnel_Params_List(s *capnp.Segment, sz int32) (TunnelServer_reconnectTunnel_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 5}, sz) + return TunnelServer_reconnectTunnel_Params_List{l}, err +} + +func (s TunnelServer_reconnectTunnel_Params_List) At(i int) TunnelServer_reconnectTunnel_Params { + return TunnelServer_reconnectTunnel_Params{s.List.Struct(i)} +} + +func (s TunnelServer_reconnectTunnel_Params_List) Set(i int, v TunnelServer_reconnectTunnel_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_reconnectTunnel_Params_List) String() string { + str, _ := text.MarshalList(0xa353a3556df74984, s.List) + return str +} + +// TunnelServer_reconnectTunnel_Params_Promise is a wrapper for a TunnelServer_reconnectTunnel_Params promised by a client call. +type TunnelServer_reconnectTunnel_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_reconnectTunnel_Params_Promise) Struct() (TunnelServer_reconnectTunnel_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_reconnectTunnel_Params{s}, err +} + +func (p TunnelServer_reconnectTunnel_Params_Promise) Options() RegistrationOptions_Promise { + return RegistrationOptions_Promise{Pipeline: p.Pipeline.GetPipeline(4)} +} + +type TunnelServer_reconnectTunnel_Results struct{ capnp.Struct } + +// TunnelServer_reconnectTunnel_Results_TypeID is the unique identifier for the type TunnelServer_reconnectTunnel_Results. +const TunnelServer_reconnectTunnel_Results_TypeID = 0xd4d18de97bb12de3 + +func NewTunnelServer_reconnectTunnel_Results(s *capnp.Segment) (TunnelServer_reconnectTunnel_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_reconnectTunnel_Results{st}, err +} + +func NewRootTunnelServer_reconnectTunnel_Results(s *capnp.Segment) (TunnelServer_reconnectTunnel_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_reconnectTunnel_Results{st}, err +} + +func ReadRootTunnelServer_reconnectTunnel_Results(msg *capnp.Message) (TunnelServer_reconnectTunnel_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_reconnectTunnel_Results{root.Struct()}, err +} + +func (s TunnelServer_reconnectTunnel_Results) String() string { + str, _ := text.Marshal(0xd4d18de97bb12de3, s.Struct) + return str +} + +func (s TunnelServer_reconnectTunnel_Results) Result() (TunnelRegistration, error) { + p, err := s.Struct.Ptr(0) + return TunnelRegistration{Struct: p.Struct()}, err +} + +func (s TunnelServer_reconnectTunnel_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_reconnectTunnel_Results) SetResult(v TunnelRegistration) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated TunnelRegistration struct, preferring placement in s's segment. +func (s TunnelServer_reconnectTunnel_Results) NewResult() (TunnelRegistration, error) { + ss, err := NewTunnelRegistration(s.Struct.Segment()) + if err != nil { + return TunnelRegistration{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_reconnectTunnel_Results_List is a list of TunnelServer_reconnectTunnel_Results. +type TunnelServer_reconnectTunnel_Results_List struct{ capnp.List } + +// NewTunnelServer_reconnectTunnel_Results creates a new list of TunnelServer_reconnectTunnel_Results. +func NewTunnelServer_reconnectTunnel_Results_List(s *capnp.Segment, sz int32) (TunnelServer_reconnectTunnel_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_reconnectTunnel_Results_List{l}, err +} + +func (s TunnelServer_reconnectTunnel_Results_List) At(i int) TunnelServer_reconnectTunnel_Results { + return TunnelServer_reconnectTunnel_Results{s.List.Struct(i)} +} + +func (s TunnelServer_reconnectTunnel_Results_List) Set(i int, v TunnelServer_reconnectTunnel_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s TunnelServer_reconnectTunnel_Results_List) String() string { + str, _ := text.MarshalList(0xd4d18de97bb12de3, s.List) + return str +} + +// TunnelServer_reconnectTunnel_Results_Promise is a wrapper for a TunnelServer_reconnectTunnel_Results promised by a client call. +type TunnelServer_reconnectTunnel_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_reconnectTunnel_Results_Promise) Struct() (TunnelServer_reconnectTunnel_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_reconnectTunnel_Results{s}, err +} + +func (p TunnelServer_reconnectTunnel_Results_Promise) Result() TunnelRegistration_Promise { + return TunnelRegistration_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type Tag struct{ capnp.Struct } + +// Tag_TypeID is the unique identifier for the type Tag. +const Tag_TypeID = 0xcbd96442ae3bb01a + +func NewTag(s *capnp.Segment) (Tag, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Tag{st}, err +} + +func NewRootTag(s *capnp.Segment) (Tag, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Tag{st}, err +} + +func ReadRootTag(msg *capnp.Message) (Tag, error) { + root, err := msg.RootPtr() + return Tag{root.Struct()}, err +} + +func (s Tag) String() string { + str, _ := text.Marshal(0xcbd96442ae3bb01a, s.Struct) + return str +} + +func (s Tag) Name() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s Tag) HasName() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s Tag) NameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s Tag) SetName(v string) error { + return s.Struct.SetText(0, v) +} + +func (s Tag) Value() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s Tag) HasValue() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s Tag) ValueBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s Tag) SetValue(v string) error { + return s.Struct.SetText(1, v) +} + +// Tag_List is a list of Tag. +type Tag_List struct{ capnp.List } + +// NewTag creates a new list of Tag. +func NewTag_List(s *capnp.Segment, sz int32) (Tag_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return Tag_List{l}, err +} + +func (s Tag_List) At(i int) Tag { return Tag{s.List.Struct(i)} } + +func (s Tag_List) Set(i int, v Tag) error { return s.List.SetStruct(i, v.Struct) } + +func (s Tag_List) String() string { + str, _ := text.MarshalList(0xcbd96442ae3bb01a, s.List) + return str +} + +// Tag_Promise is a wrapper for a Tag promised by a client call. +type Tag_Promise struct{ *capnp.Pipeline } + +func (p Tag_Promise) Struct() (Tag, error) { + s, err := p.Pipeline.Struct() + return Tag{s}, err +} + +type ClientInfo struct{ capnp.Struct } + +// ClientInfo_TypeID is the unique identifier for the type ClientInfo. +const ClientInfo_TypeID = 0x83ced0145b2f114b + +func NewClientInfo(s *capnp.Segment) (ClientInfo, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 4}) + return ClientInfo{st}, err +} + +func NewRootClientInfo(s *capnp.Segment) (ClientInfo, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 4}) + return ClientInfo{st}, err +} + +func ReadRootClientInfo(msg *capnp.Message) (ClientInfo, error) { + root, err := msg.RootPtr() + return ClientInfo{root.Struct()}, err +} + +func (s ClientInfo) String() string { + str, _ := text.Marshal(0x83ced0145b2f114b, s.Struct) + return str +} + +func (s ClientInfo) ClientId() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s ClientInfo) HasClientId() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ClientInfo) SetClientId(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s ClientInfo) Features() (capnp.TextList, error) { + p, err := s.Struct.Ptr(1) + return capnp.TextList{List: p.List()}, err +} + +func (s ClientInfo) HasFeatures() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s ClientInfo) SetFeatures(v capnp.TextList) error { + return s.Struct.SetPtr(1, v.List.ToPtr()) +} + +// NewFeatures sets the features field to a newly +// allocated capnp.TextList, preferring placement in s's segment. +func (s ClientInfo) NewFeatures(n int32) (capnp.TextList, error) { + l, err := capnp.NewTextList(s.Struct.Segment(), n) + if err != nil { + return capnp.TextList{}, err + } + err = s.Struct.SetPtr(1, l.List.ToPtr()) + return l, err +} + +func (s ClientInfo) Version() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s ClientInfo) HasVersion() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s ClientInfo) VersionBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s ClientInfo) SetVersion(v string) error { + return s.Struct.SetText(2, v) +} + +func (s ClientInfo) Arch() (string, error) { + p, err := s.Struct.Ptr(3) + return p.Text(), err +} + +func (s ClientInfo) HasArch() bool { + p, err := s.Struct.Ptr(3) + return p.IsValid() || err != nil +} + +func (s ClientInfo) ArchBytes() ([]byte, error) { + p, err := s.Struct.Ptr(3) + return p.TextBytes(), err +} + +func (s ClientInfo) SetArch(v string) error { + return s.Struct.SetText(3, v) +} + +// ClientInfo_List is a list of ClientInfo. +type ClientInfo_List struct{ capnp.List } + +// NewClientInfo creates a new list of ClientInfo. +func NewClientInfo_List(s *capnp.Segment, sz int32) (ClientInfo_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 4}, sz) + return ClientInfo_List{l}, err +} + +func (s ClientInfo_List) At(i int) ClientInfo { return ClientInfo{s.List.Struct(i)} } + +func (s ClientInfo_List) Set(i int, v ClientInfo) error { return s.List.SetStruct(i, v.Struct) } + +func (s ClientInfo_List) String() string { + str, _ := text.MarshalList(0x83ced0145b2f114b, s.List) + return str +} + +// ClientInfo_Promise is a wrapper for a ClientInfo promised by a client call. +type ClientInfo_Promise struct{ *capnp.Pipeline } + +func (p ClientInfo_Promise) Struct() (ClientInfo, error) { + s, err := p.Pipeline.Struct() + return ClientInfo{s}, err +} + +type ConnectionOptions struct{ capnp.Struct } + +// ConnectionOptions_TypeID is the unique identifier for the type ConnectionOptions. +const ConnectionOptions_TypeID = 0xb4bf9861fe035d04 + +func NewConnectionOptions(s *capnp.Segment) (ConnectionOptions, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectionOptions{st}, err +} + +func NewRootConnectionOptions(s *capnp.Segment) (ConnectionOptions, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectionOptions{st}, err +} + +func ReadRootConnectionOptions(msg *capnp.Message) (ConnectionOptions, error) { + root, err := msg.RootPtr() + return ConnectionOptions{root.Struct()}, err +} + +func (s ConnectionOptions) String() string { + str, _ := text.Marshal(0xb4bf9861fe035d04, s.Struct) + return str +} + +func (s ConnectionOptions) Client() (ClientInfo, error) { + p, err := s.Struct.Ptr(0) + return ClientInfo{Struct: p.Struct()}, err +} + +func (s ConnectionOptions) HasClient() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectionOptions) SetClient(v ClientInfo) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewClient sets the client field to a newly +// allocated ClientInfo struct, preferring placement in s's segment. +func (s ConnectionOptions) NewClient() (ClientInfo, error) { + ss, err := NewClientInfo(s.Struct.Segment()) + if err != nil { + return ClientInfo{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +func (s ConnectionOptions) OriginLocalIp() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s ConnectionOptions) HasOriginLocalIp() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s ConnectionOptions) SetOriginLocalIp(v []byte) error { + return s.Struct.SetData(1, v) +} + +func (s ConnectionOptions) ReplaceExisting() bool { + return s.Struct.Bit(0) +} + +func (s ConnectionOptions) SetReplaceExisting(v bool) { + s.Struct.SetBit(0, v) +} + +func (s ConnectionOptions) CompressionQuality() uint8 { + return s.Struct.Uint8(1) +} + +func (s ConnectionOptions) SetCompressionQuality(v uint8) { + s.Struct.SetUint8(1, v) +} + +func (s ConnectionOptions) NumPreviousAttempts() uint8 { + return s.Struct.Uint8(2) +} + +func (s ConnectionOptions) SetNumPreviousAttempts(v uint8) { + s.Struct.SetUint8(2, v) +} + +// ConnectionOptions_List is a list of ConnectionOptions. +type ConnectionOptions_List struct{ capnp.List } + +// NewConnectionOptions creates a new list of ConnectionOptions. +func NewConnectionOptions_List(s *capnp.Segment, sz int32) (ConnectionOptions_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}, sz) + return ConnectionOptions_List{l}, err +} + +func (s ConnectionOptions_List) At(i int) ConnectionOptions { + return ConnectionOptions{s.List.Struct(i)} +} + +func (s ConnectionOptions_List) Set(i int, v ConnectionOptions) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConnectionOptions_List) String() string { + str, _ := text.MarshalList(0xb4bf9861fe035d04, s.List) + return str +} + +// ConnectionOptions_Promise is a wrapper for a ConnectionOptions promised by a client call. +type ConnectionOptions_Promise struct{ *capnp.Pipeline } + +func (p ConnectionOptions_Promise) Struct() (ConnectionOptions, error) { + s, err := p.Pipeline.Struct() + return ConnectionOptions{s}, err +} + +func (p ConnectionOptions_Promise) Client() ClientInfo_Promise { + return ClientInfo_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type ConnectionResponse struct{ capnp.Struct } +type ConnectionResponse_result ConnectionResponse +type ConnectionResponse_result_Which uint16 + +const ( + ConnectionResponse_result_Which_error ConnectionResponse_result_Which = 0 + ConnectionResponse_result_Which_connectionDetails ConnectionResponse_result_Which = 1 +) + +func (w ConnectionResponse_result_Which) String() string { + const s = "errorconnectionDetails" + switch w { + case ConnectionResponse_result_Which_error: + return s[0:5] + case ConnectionResponse_result_Which_connectionDetails: + return s[5:22] + + } + return "ConnectionResponse_result_Which(" + strconv.FormatUint(uint64(w), 10) + ")" +} + +// ConnectionResponse_TypeID is the unique identifier for the type ConnectionResponse. +const ConnectionResponse_TypeID = 0xdbaa9d03d52b62dc + +func NewConnectionResponse(s *capnp.Segment) (ConnectionResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConnectionResponse{st}, err +} + +func NewRootConnectionResponse(s *capnp.Segment) (ConnectionResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConnectionResponse{st}, err +} + +func ReadRootConnectionResponse(msg *capnp.Message) (ConnectionResponse, error) { + root, err := msg.RootPtr() + return ConnectionResponse{root.Struct()}, err +} + +func (s ConnectionResponse) String() string { + str, _ := text.Marshal(0xdbaa9d03d52b62dc, s.Struct) + return str +} + +func (s ConnectionResponse) Result() ConnectionResponse_result { return ConnectionResponse_result(s) } + +func (s ConnectionResponse_result) Which() ConnectionResponse_result_Which { + return ConnectionResponse_result_Which(s.Struct.Uint16(0)) +} +func (s ConnectionResponse_result) Error() (ConnectionError, error) { + if s.Struct.Uint16(0) != 0 { + panic("Which() != error") + } + p, err := s.Struct.Ptr(0) + return ConnectionError{Struct: p.Struct()}, err +} + +func (s ConnectionResponse_result) HasError() bool { + if s.Struct.Uint16(0) != 0 { + return false + } + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectionResponse_result) SetError(v ConnectionError) error { + s.Struct.SetUint16(0, 0) + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewError sets the error field to a newly +// allocated ConnectionError struct, preferring placement in s's segment. +func (s ConnectionResponse_result) NewError() (ConnectionError, error) { + s.Struct.SetUint16(0, 0) + ss, err := NewConnectionError(s.Struct.Segment()) + if err != nil { + return ConnectionError{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +func (s ConnectionResponse_result) ConnectionDetails() (ConnectionDetails, error) { + if s.Struct.Uint16(0) != 1 { + panic("Which() != connectionDetails") + } + p, err := s.Struct.Ptr(0) + return ConnectionDetails{Struct: p.Struct()}, err +} + +func (s ConnectionResponse_result) HasConnectionDetails() bool { + if s.Struct.Uint16(0) != 1 { + return false + } + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectionResponse_result) SetConnectionDetails(v ConnectionDetails) error { + s.Struct.SetUint16(0, 1) + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewConnectionDetails sets the connectionDetails field to a newly +// allocated ConnectionDetails struct, preferring placement in s's segment. +func (s ConnectionResponse_result) NewConnectionDetails() (ConnectionDetails, error) { + s.Struct.SetUint16(0, 1) + ss, err := NewConnectionDetails(s.Struct.Segment()) + if err != nil { + return ConnectionDetails{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// ConnectionResponse_List is a list of ConnectionResponse. +type ConnectionResponse_List struct{ capnp.List } + +// NewConnectionResponse creates a new list of ConnectionResponse. +func NewConnectionResponse_List(s *capnp.Segment, sz int32) (ConnectionResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}, sz) + return ConnectionResponse_List{l}, err +} + +func (s ConnectionResponse_List) At(i int) ConnectionResponse { + return ConnectionResponse{s.List.Struct(i)} +} + +func (s ConnectionResponse_List) Set(i int, v ConnectionResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConnectionResponse_List) String() string { + str, _ := text.MarshalList(0xdbaa9d03d52b62dc, s.List) + return str +} + +// ConnectionResponse_Promise is a wrapper for a ConnectionResponse promised by a client call. +type ConnectionResponse_Promise struct{ *capnp.Pipeline } + +func (p ConnectionResponse_Promise) Struct() (ConnectionResponse, error) { + s, err := p.Pipeline.Struct() + return ConnectionResponse{s}, err +} + +func (p ConnectionResponse_Promise) Result() ConnectionResponse_result_Promise { + return ConnectionResponse_result_Promise{p.Pipeline} +} + +// ConnectionResponse_result_Promise is a wrapper for a ConnectionResponse_result promised by a client call. +type ConnectionResponse_result_Promise struct{ *capnp.Pipeline } + +func (p ConnectionResponse_result_Promise) Struct() (ConnectionResponse_result, error) { + s, err := p.Pipeline.Struct() + return ConnectionResponse_result{s}, err +} + +func (p ConnectionResponse_result_Promise) Error() ConnectionError_Promise { + return ConnectionError_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +func (p ConnectionResponse_result_Promise) ConnectionDetails() ConnectionDetails_Promise { + return ConnectionDetails_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type ConnectionError struct{ capnp.Struct } + +// ConnectionError_TypeID is the unique identifier for the type ConnectionError. +const ConnectionError_TypeID = 0xf5f383d2785edb86 + +func NewConnectionError(s *capnp.Segment) (ConnectionError, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 1}) + return ConnectionError{st}, err +} + +func NewRootConnectionError(s *capnp.Segment) (ConnectionError, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 1}) + return ConnectionError{st}, err +} + +func ReadRootConnectionError(msg *capnp.Message) (ConnectionError, error) { + root, err := msg.RootPtr() + return ConnectionError{root.Struct()}, err +} + +func (s ConnectionError) String() string { + str, _ := text.Marshal(0xf5f383d2785edb86, s.Struct) + return str +} + +func (s ConnectionError) Cause() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s ConnectionError) HasCause() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectionError) CauseBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s ConnectionError) SetCause(v string) error { + return s.Struct.SetText(0, v) +} + +func (s ConnectionError) RetryAfter() int64 { + return int64(s.Struct.Uint64(0)) +} + +func (s ConnectionError) SetRetryAfter(v int64) { + s.Struct.SetUint64(0, uint64(v)) +} + +func (s ConnectionError) ShouldRetry() bool { + return s.Struct.Bit(64) +} + +func (s ConnectionError) SetShouldRetry(v bool) { + s.Struct.SetBit(64, v) +} + +// ConnectionError_List is a list of ConnectionError. +type ConnectionError_List struct{ capnp.List } + +// NewConnectionError creates a new list of ConnectionError. +func NewConnectionError_List(s *capnp.Segment, sz int32) (ConnectionError_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 16, PointerCount: 1}, sz) + return ConnectionError_List{l}, err +} + +func (s ConnectionError_List) At(i int) ConnectionError { return ConnectionError{s.List.Struct(i)} } + +func (s ConnectionError_List) Set(i int, v ConnectionError) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConnectionError_List) String() string { + str, _ := text.MarshalList(0xf5f383d2785edb86, s.List) + return str +} + +// ConnectionError_Promise is a wrapper for a ConnectionError promised by a client call. +type ConnectionError_Promise struct{ *capnp.Pipeline } + +func (p ConnectionError_Promise) Struct() (ConnectionError, error) { + s, err := p.Pipeline.Struct() + return ConnectionError{s}, err +} + +type ConnectionDetails struct{ capnp.Struct } + +// ConnectionDetails_TypeID is the unique identifier for the type ConnectionDetails. +const ConnectionDetails_TypeID = 0xb5f39f082b9ac18a + +func NewConnectionDetails(s *capnp.Segment) (ConnectionDetails, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectionDetails{st}, err +} + +func NewRootConnectionDetails(s *capnp.Segment) (ConnectionDetails, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}) + return ConnectionDetails{st}, err +} + +func ReadRootConnectionDetails(msg *capnp.Message) (ConnectionDetails, error) { + root, err := msg.RootPtr() + return ConnectionDetails{root.Struct()}, err +} + +func (s ConnectionDetails) String() string { + str, _ := text.Marshal(0xb5f39f082b9ac18a, s.Struct) + return str +} + +func (s ConnectionDetails) Uuid() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s ConnectionDetails) HasUuid() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConnectionDetails) SetUuid(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s ConnectionDetails) LocationName() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s ConnectionDetails) HasLocationName() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s ConnectionDetails) LocationNameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s ConnectionDetails) SetLocationName(v string) error { + return s.Struct.SetText(1, v) +} + +func (s ConnectionDetails) TunnelIsRemotelyManaged() bool { + return s.Struct.Bit(0) +} + +func (s ConnectionDetails) SetTunnelIsRemotelyManaged(v bool) { + s.Struct.SetBit(0, v) +} + +// ConnectionDetails_List is a list of ConnectionDetails. +type ConnectionDetails_List struct{ capnp.List } + +// NewConnectionDetails creates a new list of ConnectionDetails. +func NewConnectionDetails_List(s *capnp.Segment, sz int32) (ConnectionDetails_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}, sz) + return ConnectionDetails_List{l}, err +} + +func (s ConnectionDetails_List) At(i int) ConnectionDetails { + return ConnectionDetails{s.List.Struct(i)} +} + +func (s ConnectionDetails_List) Set(i int, v ConnectionDetails) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConnectionDetails_List) String() string { + str, _ := text.MarshalList(0xb5f39f082b9ac18a, s.List) + return str +} + +// ConnectionDetails_Promise is a wrapper for a ConnectionDetails promised by a client call. +type ConnectionDetails_Promise struct{ *capnp.Pipeline } + +func (p ConnectionDetails_Promise) Struct() (ConnectionDetails, error) { + s, err := p.Pipeline.Struct() + return ConnectionDetails{s}, err +} + +type TunnelAuth struct{ capnp.Struct } + +// TunnelAuth_TypeID is the unique identifier for the type TunnelAuth. +const TunnelAuth_TypeID = 0x9496331ab9cd463f + +func NewTunnelAuth(s *capnp.Segment) (TunnelAuth, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return TunnelAuth{st}, err +} + +func NewRootTunnelAuth(s *capnp.Segment) (TunnelAuth, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return TunnelAuth{st}, err +} + +func ReadRootTunnelAuth(msg *capnp.Message) (TunnelAuth, error) { + root, err := msg.RootPtr() + return TunnelAuth{root.Struct()}, err +} + +func (s TunnelAuth) String() string { + str, _ := text.Marshal(0x9496331ab9cd463f, s.Struct) + return str +} + +func (s TunnelAuth) AccountTag() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s TunnelAuth) HasAccountTag() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelAuth) AccountTagBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s TunnelAuth) SetAccountTag(v string) error { + return s.Struct.SetText(0, v) +} + +func (s TunnelAuth) TunnelSecret() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s TunnelAuth) HasTunnelSecret() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelAuth) SetTunnelSecret(v []byte) error { + return s.Struct.SetData(1, v) +} + +// TunnelAuth_List is a list of TunnelAuth. +type TunnelAuth_List struct{ capnp.List } + +// NewTunnelAuth creates a new list of TunnelAuth. +func NewTunnelAuth_List(s *capnp.Segment, sz int32) (TunnelAuth_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return TunnelAuth_List{l}, err +} + +func (s TunnelAuth_List) At(i int) TunnelAuth { return TunnelAuth{s.List.Struct(i)} } + +func (s TunnelAuth_List) Set(i int, v TunnelAuth) error { return s.List.SetStruct(i, v.Struct) } + +func (s TunnelAuth_List) String() string { + str, _ := text.MarshalList(0x9496331ab9cd463f, s.List) + return str +} + +// TunnelAuth_Promise is a wrapper for a TunnelAuth promised by a client call. +type TunnelAuth_Promise struct{ *capnp.Pipeline } + +func (p TunnelAuth_Promise) Struct() (TunnelAuth, error) { + s, err := p.Pipeline.Struct() + return TunnelAuth{s}, err +} + +type RegistrationServer struct{ Client capnp.Client } + +// RegistrationServer_TypeID is the unique identifier for the type RegistrationServer. +const RegistrationServer_TypeID = 0xf71695ec7fe85497 + +func (c RegistrationServer) RegisterConnection(ctx context.Context, params func(RegistrationServer_registerConnection_Params) error, opts ...capnp.CallOption) RegistrationServer_registerConnection_Results_Promise { + if c.Client == nil { + return RegistrationServer_registerConnection_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "registerConnection", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(RegistrationServer_registerConnection_Params{Struct: s}) } + } + return RegistrationServer_registerConnection_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c RegistrationServer) UnregisterConnection(ctx context.Context, params func(RegistrationServer_unregisterConnection_Params) error, opts ...capnp.CallOption) RegistrationServer_unregisterConnection_Results_Promise { + if c.Client == nil { + return RegistrationServer_unregisterConnection_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "unregisterConnection", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { return params(RegistrationServer_unregisterConnection_Params{Struct: s}) } + } + return RegistrationServer_unregisterConnection_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c RegistrationServer) UpdateLocalConfiguration(ctx context.Context, params func(RegistrationServer_updateLocalConfiguration_Params) error, opts ...capnp.CallOption) RegistrationServer_updateLocalConfiguration_Results_Promise { + if c.Client == nil { + return RegistrationServer_updateLocalConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "updateLocalConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { + return params(RegistrationServer_updateLocalConfiguration_Params{Struct: s}) + } + } + return RegistrationServer_updateLocalConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type RegistrationServer_Server interface { + RegisterConnection(RegistrationServer_registerConnection) error + + UnregisterConnection(RegistrationServer_unregisterConnection) error + + UpdateLocalConfiguration(RegistrationServer_updateLocalConfiguration) error +} + +func RegistrationServer_ServerToClient(s RegistrationServer_Server) RegistrationServer { + c, _ := s.(server.Closer) + return RegistrationServer{Client: server.New(RegistrationServer_Methods(nil, s), c)} +} + +func RegistrationServer_Methods(methods []server.Method, s RegistrationServer_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 3) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "registerConnection", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_registerConnection{c, opts, RegistrationServer_registerConnection_Params{Struct: p}, RegistrationServer_registerConnection_Results{Struct: r}} + return s.RegisterConnection(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "unregisterConnection", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_unregisterConnection{c, opts, RegistrationServer_unregisterConnection_Params{Struct: p}, RegistrationServer_unregisterConnection_Results{Struct: r}} + return s.UnregisterConnection(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xf71695ec7fe85497, + MethodID: 2, + InterfaceName: "tunnelrpc.capnp:RegistrationServer", + MethodName: "updateLocalConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := RegistrationServer_updateLocalConfiguration{c, opts, RegistrationServer_updateLocalConfiguration_Params{Struct: p}, RegistrationServer_updateLocalConfiguration_Results{Struct: r}} + return s.UpdateLocalConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + return methods +} + +// RegistrationServer_registerConnection holds the arguments for a server call to RegistrationServer.registerConnection. +type RegistrationServer_registerConnection struct { + Ctx context.Context + Options capnp.CallOptions + Params RegistrationServer_registerConnection_Params + Results RegistrationServer_registerConnection_Results +} + +// RegistrationServer_unregisterConnection holds the arguments for a server call to RegistrationServer.unregisterConnection. +type RegistrationServer_unregisterConnection struct { + Ctx context.Context + Options capnp.CallOptions + Params RegistrationServer_unregisterConnection_Params + Results RegistrationServer_unregisterConnection_Results +} + +// RegistrationServer_updateLocalConfiguration holds the arguments for a server call to RegistrationServer.updateLocalConfiguration. +type RegistrationServer_updateLocalConfiguration struct { + Ctx context.Context + Options capnp.CallOptions + Params RegistrationServer_updateLocalConfiguration_Params + Results RegistrationServer_updateLocalConfiguration_Results +} + +type RegistrationServer_registerConnection_Params struct{ capnp.Struct } + +// RegistrationServer_registerConnection_Params_TypeID is the unique identifier for the type RegistrationServer_registerConnection_Params. +const RegistrationServer_registerConnection_Params_TypeID = 0xe6646dec8feaa6ee + +func NewRegistrationServer_registerConnection_Params(s *capnp.Segment) (RegistrationServer_registerConnection_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return RegistrationServer_registerConnection_Params{st}, err +} + +func NewRootRegistrationServer_registerConnection_Params(s *capnp.Segment) (RegistrationServer_registerConnection_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return RegistrationServer_registerConnection_Params{st}, err +} + +func ReadRootRegistrationServer_registerConnection_Params(msg *capnp.Message) (RegistrationServer_registerConnection_Params, error) { + root, err := msg.RootPtr() + return RegistrationServer_registerConnection_Params{root.Struct()}, err +} + +func (s RegistrationServer_registerConnection_Params) String() string { + str, _ := text.Marshal(0xe6646dec8feaa6ee, s.Struct) + return str +} + +func (s RegistrationServer_registerConnection_Params) Auth() (TunnelAuth, error) { + p, err := s.Struct.Ptr(0) + return TunnelAuth{Struct: p.Struct()}, err +} + +func (s RegistrationServer_registerConnection_Params) HasAuth() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegistrationServer_registerConnection_Params) SetAuth(v TunnelAuth) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewAuth sets the auth field to a newly +// allocated TunnelAuth struct, preferring placement in s's segment. +func (s RegistrationServer_registerConnection_Params) NewAuth() (TunnelAuth, error) { + ss, err := NewTunnelAuth(s.Struct.Segment()) + if err != nil { + return TunnelAuth{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +func (s RegistrationServer_registerConnection_Params) TunnelId() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s RegistrationServer_registerConnection_Params) HasTunnelId() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s RegistrationServer_registerConnection_Params) SetTunnelId(v []byte) error { + return s.Struct.SetData(1, v) +} + +func (s RegistrationServer_registerConnection_Params) ConnIndex() uint8 { + return s.Struct.Uint8(0) +} + +func (s RegistrationServer_registerConnection_Params) SetConnIndex(v uint8) { + s.Struct.SetUint8(0, v) +} + +func (s RegistrationServer_registerConnection_Params) Options() (ConnectionOptions, error) { + p, err := s.Struct.Ptr(2) + return ConnectionOptions{Struct: p.Struct()}, err +} + +func (s RegistrationServer_registerConnection_Params) HasOptions() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s RegistrationServer_registerConnection_Params) SetOptions(v ConnectionOptions) error { + return s.Struct.SetPtr(2, v.Struct.ToPtr()) +} + +// NewOptions sets the options field to a newly +// allocated ConnectionOptions struct, preferring placement in s's segment. +func (s RegistrationServer_registerConnection_Params) NewOptions() (ConnectionOptions, error) { + ss, err := NewConnectionOptions(s.Struct.Segment()) + if err != nil { + return ConnectionOptions{}, err + } + err = s.Struct.SetPtr(2, ss.Struct.ToPtr()) + return ss, err +} + +// RegistrationServer_registerConnection_Params_List is a list of RegistrationServer_registerConnection_Params. +type RegistrationServer_registerConnection_Params_List struct{ capnp.List } + +// NewRegistrationServer_registerConnection_Params creates a new list of RegistrationServer_registerConnection_Params. +func NewRegistrationServer_registerConnection_Params_List(s *capnp.Segment, sz int32) (RegistrationServer_registerConnection_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}, sz) + return RegistrationServer_registerConnection_Params_List{l}, err +} + +func (s RegistrationServer_registerConnection_Params_List) At(i int) RegistrationServer_registerConnection_Params { + return RegistrationServer_registerConnection_Params{s.List.Struct(i)} +} + +func (s RegistrationServer_registerConnection_Params_List) Set(i int, v RegistrationServer_registerConnection_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_registerConnection_Params_List) String() string { + str, _ := text.MarshalList(0xe6646dec8feaa6ee, s.List) + return str +} + +// RegistrationServer_registerConnection_Params_Promise is a wrapper for a RegistrationServer_registerConnection_Params promised by a client call. +type RegistrationServer_registerConnection_Params_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_registerConnection_Params_Promise) Struct() (RegistrationServer_registerConnection_Params, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_registerConnection_Params{s}, err +} + +func (p RegistrationServer_registerConnection_Params_Promise) Auth() TunnelAuth_Promise { + return TunnelAuth_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +func (p RegistrationServer_registerConnection_Params_Promise) Options() ConnectionOptions_Promise { + return ConnectionOptions_Promise{Pipeline: p.Pipeline.GetPipeline(2)} +} + +type RegistrationServer_registerConnection_Results struct{ capnp.Struct } + +// RegistrationServer_registerConnection_Results_TypeID is the unique identifier for the type RegistrationServer_registerConnection_Results. +const RegistrationServer_registerConnection_Results_TypeID = 0xea50d822450d1f17 + +func NewRegistrationServer_registerConnection_Results(s *capnp.Segment) (RegistrationServer_registerConnection_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return RegistrationServer_registerConnection_Results{st}, err +} + +func NewRootRegistrationServer_registerConnection_Results(s *capnp.Segment) (RegistrationServer_registerConnection_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return RegistrationServer_registerConnection_Results{st}, err +} + +func ReadRootRegistrationServer_registerConnection_Results(msg *capnp.Message) (RegistrationServer_registerConnection_Results, error) { + root, err := msg.RootPtr() + return RegistrationServer_registerConnection_Results{root.Struct()}, err +} + +func (s RegistrationServer_registerConnection_Results) String() string { + str, _ := text.Marshal(0xea50d822450d1f17, s.Struct) + return str +} + +func (s RegistrationServer_registerConnection_Results) Result() (ConnectionResponse, error) { + p, err := s.Struct.Ptr(0) + return ConnectionResponse{Struct: p.Struct()}, err +} + +func (s RegistrationServer_registerConnection_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegistrationServer_registerConnection_Results) SetResult(v ConnectionResponse) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated ConnectionResponse struct, preferring placement in s's segment. +func (s RegistrationServer_registerConnection_Results) NewResult() (ConnectionResponse, error) { + ss, err := NewConnectionResponse(s.Struct.Segment()) + if err != nil { + return ConnectionResponse{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// RegistrationServer_registerConnection_Results_List is a list of RegistrationServer_registerConnection_Results. +type RegistrationServer_registerConnection_Results_List struct{ capnp.List } + +// NewRegistrationServer_registerConnection_Results creates a new list of RegistrationServer_registerConnection_Results. +func NewRegistrationServer_registerConnection_Results_List(s *capnp.Segment, sz int32) (RegistrationServer_registerConnection_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return RegistrationServer_registerConnection_Results_List{l}, err +} + +func (s RegistrationServer_registerConnection_Results_List) At(i int) RegistrationServer_registerConnection_Results { + return RegistrationServer_registerConnection_Results{s.List.Struct(i)} +} + +func (s RegistrationServer_registerConnection_Results_List) Set(i int, v RegistrationServer_registerConnection_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_registerConnection_Results_List) String() string { + str, _ := text.MarshalList(0xea50d822450d1f17, s.List) + return str +} + +// RegistrationServer_registerConnection_Results_Promise is a wrapper for a RegistrationServer_registerConnection_Results promised by a client call. +type RegistrationServer_registerConnection_Results_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_registerConnection_Results_Promise) Struct() (RegistrationServer_registerConnection_Results, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_registerConnection_Results{s}, err +} + +func (p RegistrationServer_registerConnection_Results_Promise) Result() ConnectionResponse_Promise { + return ConnectionResponse_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type RegistrationServer_unregisterConnection_Params struct{ capnp.Struct } + +// RegistrationServer_unregisterConnection_Params_TypeID is the unique identifier for the type RegistrationServer_unregisterConnection_Params. +const RegistrationServer_unregisterConnection_Params_TypeID = 0xf9cb7f4431a307d0 + +func NewRegistrationServer_unregisterConnection_Params(s *capnp.Segment) (RegistrationServer_unregisterConnection_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_unregisterConnection_Params{st}, err +} + +func NewRootRegistrationServer_unregisterConnection_Params(s *capnp.Segment) (RegistrationServer_unregisterConnection_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_unregisterConnection_Params{st}, err +} + +func ReadRootRegistrationServer_unregisterConnection_Params(msg *capnp.Message) (RegistrationServer_unregisterConnection_Params, error) { + root, err := msg.RootPtr() + return RegistrationServer_unregisterConnection_Params{root.Struct()}, err +} + +func (s RegistrationServer_unregisterConnection_Params) String() string { + str, _ := text.Marshal(0xf9cb7f4431a307d0, s.Struct) + return str +} + +// RegistrationServer_unregisterConnection_Params_List is a list of RegistrationServer_unregisterConnection_Params. +type RegistrationServer_unregisterConnection_Params_List struct{ capnp.List } + +// NewRegistrationServer_unregisterConnection_Params creates a new list of RegistrationServer_unregisterConnection_Params. +func NewRegistrationServer_unregisterConnection_Params_List(s *capnp.Segment, sz int32) (RegistrationServer_unregisterConnection_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return RegistrationServer_unregisterConnection_Params_List{l}, err +} + +func (s RegistrationServer_unregisterConnection_Params_List) At(i int) RegistrationServer_unregisterConnection_Params { + return RegistrationServer_unregisterConnection_Params{s.List.Struct(i)} +} + +func (s RegistrationServer_unregisterConnection_Params_List) Set(i int, v RegistrationServer_unregisterConnection_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_unregisterConnection_Params_List) String() string { + str, _ := text.MarshalList(0xf9cb7f4431a307d0, s.List) + return str +} + +// RegistrationServer_unregisterConnection_Params_Promise is a wrapper for a RegistrationServer_unregisterConnection_Params promised by a client call. +type RegistrationServer_unregisterConnection_Params_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_unregisterConnection_Params_Promise) Struct() (RegistrationServer_unregisterConnection_Params, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_unregisterConnection_Params{s}, err +} + +type RegistrationServer_unregisterConnection_Results struct{ capnp.Struct } + +// RegistrationServer_unregisterConnection_Results_TypeID is the unique identifier for the type RegistrationServer_unregisterConnection_Results. +const RegistrationServer_unregisterConnection_Results_TypeID = 0xb046e578094b1ead + +func NewRegistrationServer_unregisterConnection_Results(s *capnp.Segment) (RegistrationServer_unregisterConnection_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_unregisterConnection_Results{st}, err +} + +func NewRootRegistrationServer_unregisterConnection_Results(s *capnp.Segment) (RegistrationServer_unregisterConnection_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_unregisterConnection_Results{st}, err +} + +func ReadRootRegistrationServer_unregisterConnection_Results(msg *capnp.Message) (RegistrationServer_unregisterConnection_Results, error) { + root, err := msg.RootPtr() + return RegistrationServer_unregisterConnection_Results{root.Struct()}, err +} + +func (s RegistrationServer_unregisterConnection_Results) String() string { + str, _ := text.Marshal(0xb046e578094b1ead, s.Struct) + return str +} + +// RegistrationServer_unregisterConnection_Results_List is a list of RegistrationServer_unregisterConnection_Results. +type RegistrationServer_unregisterConnection_Results_List struct{ capnp.List } + +// NewRegistrationServer_unregisterConnection_Results creates a new list of RegistrationServer_unregisterConnection_Results. +func NewRegistrationServer_unregisterConnection_Results_List(s *capnp.Segment, sz int32) (RegistrationServer_unregisterConnection_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return RegistrationServer_unregisterConnection_Results_List{l}, err +} + +func (s RegistrationServer_unregisterConnection_Results_List) At(i int) RegistrationServer_unregisterConnection_Results { + return RegistrationServer_unregisterConnection_Results{s.List.Struct(i)} +} + +func (s RegistrationServer_unregisterConnection_Results_List) Set(i int, v RegistrationServer_unregisterConnection_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_unregisterConnection_Results_List) String() string { + str, _ := text.MarshalList(0xb046e578094b1ead, s.List) + return str +} + +// RegistrationServer_unregisterConnection_Results_Promise is a wrapper for a RegistrationServer_unregisterConnection_Results promised by a client call. +type RegistrationServer_unregisterConnection_Results_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_unregisterConnection_Results_Promise) Struct() (RegistrationServer_unregisterConnection_Results, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_unregisterConnection_Results{s}, err +} + +type RegistrationServer_updateLocalConfiguration_Params struct{ capnp.Struct } + +// RegistrationServer_updateLocalConfiguration_Params_TypeID is the unique identifier for the type RegistrationServer_updateLocalConfiguration_Params. +const RegistrationServer_updateLocalConfiguration_Params_TypeID = 0xc5d6e311876a3604 + +func NewRegistrationServer_updateLocalConfiguration_Params(s *capnp.Segment) (RegistrationServer_updateLocalConfiguration_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return RegistrationServer_updateLocalConfiguration_Params{st}, err +} + +func NewRootRegistrationServer_updateLocalConfiguration_Params(s *capnp.Segment) (RegistrationServer_updateLocalConfiguration_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return RegistrationServer_updateLocalConfiguration_Params{st}, err +} + +func ReadRootRegistrationServer_updateLocalConfiguration_Params(msg *capnp.Message) (RegistrationServer_updateLocalConfiguration_Params, error) { + root, err := msg.RootPtr() + return RegistrationServer_updateLocalConfiguration_Params{root.Struct()}, err +} + +func (s RegistrationServer_updateLocalConfiguration_Params) String() string { + str, _ := text.Marshal(0xc5d6e311876a3604, s.Struct) + return str +} + +func (s RegistrationServer_updateLocalConfiguration_Params) Config() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s RegistrationServer_updateLocalConfiguration_Params) HasConfig() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegistrationServer_updateLocalConfiguration_Params) SetConfig(v []byte) error { + return s.Struct.SetData(0, v) +} + +// RegistrationServer_updateLocalConfiguration_Params_List is a list of RegistrationServer_updateLocalConfiguration_Params. +type RegistrationServer_updateLocalConfiguration_Params_List struct{ capnp.List } + +// NewRegistrationServer_updateLocalConfiguration_Params creates a new list of RegistrationServer_updateLocalConfiguration_Params. +func NewRegistrationServer_updateLocalConfiguration_Params_List(s *capnp.Segment, sz int32) (RegistrationServer_updateLocalConfiguration_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return RegistrationServer_updateLocalConfiguration_Params_List{l}, err +} + +func (s RegistrationServer_updateLocalConfiguration_Params_List) At(i int) RegistrationServer_updateLocalConfiguration_Params { + return RegistrationServer_updateLocalConfiguration_Params{s.List.Struct(i)} +} + +func (s RegistrationServer_updateLocalConfiguration_Params_List) Set(i int, v RegistrationServer_updateLocalConfiguration_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_updateLocalConfiguration_Params_List) String() string { + str, _ := text.MarshalList(0xc5d6e311876a3604, s.List) + return str +} + +// RegistrationServer_updateLocalConfiguration_Params_Promise is a wrapper for a RegistrationServer_updateLocalConfiguration_Params promised by a client call. +type RegistrationServer_updateLocalConfiguration_Params_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_updateLocalConfiguration_Params_Promise) Struct() (RegistrationServer_updateLocalConfiguration_Params, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_updateLocalConfiguration_Params{s}, err +} + +type RegistrationServer_updateLocalConfiguration_Results struct{ capnp.Struct } + +// RegistrationServer_updateLocalConfiguration_Results_TypeID is the unique identifier for the type RegistrationServer_updateLocalConfiguration_Results. +const RegistrationServer_updateLocalConfiguration_Results_TypeID = 0xe5ceae5d6897d7be + +func NewRegistrationServer_updateLocalConfiguration_Results(s *capnp.Segment) (RegistrationServer_updateLocalConfiguration_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_updateLocalConfiguration_Results{st}, err +} + +func NewRootRegistrationServer_updateLocalConfiguration_Results(s *capnp.Segment) (RegistrationServer_updateLocalConfiguration_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return RegistrationServer_updateLocalConfiguration_Results{st}, err +} + +func ReadRootRegistrationServer_updateLocalConfiguration_Results(msg *capnp.Message) (RegistrationServer_updateLocalConfiguration_Results, error) { + root, err := msg.RootPtr() + return RegistrationServer_updateLocalConfiguration_Results{root.Struct()}, err +} + +func (s RegistrationServer_updateLocalConfiguration_Results) String() string { + str, _ := text.Marshal(0xe5ceae5d6897d7be, s.Struct) + return str +} + +// RegistrationServer_updateLocalConfiguration_Results_List is a list of RegistrationServer_updateLocalConfiguration_Results. +type RegistrationServer_updateLocalConfiguration_Results_List struct{ capnp.List } + +// NewRegistrationServer_updateLocalConfiguration_Results creates a new list of RegistrationServer_updateLocalConfiguration_Results. +func NewRegistrationServer_updateLocalConfiguration_Results_List(s *capnp.Segment, sz int32) (RegistrationServer_updateLocalConfiguration_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return RegistrationServer_updateLocalConfiguration_Results_List{l}, err +} + +func (s RegistrationServer_updateLocalConfiguration_Results_List) At(i int) RegistrationServer_updateLocalConfiguration_Results { + return RegistrationServer_updateLocalConfiguration_Results{s.List.Struct(i)} +} + +func (s RegistrationServer_updateLocalConfiguration_Results_List) Set(i int, v RegistrationServer_updateLocalConfiguration_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegistrationServer_updateLocalConfiguration_Results_List) String() string { + str, _ := text.MarshalList(0xe5ceae5d6897d7be, s.List) + return str +} + +// RegistrationServer_updateLocalConfiguration_Results_Promise is a wrapper for a RegistrationServer_updateLocalConfiguration_Results promised by a client call. +type RegistrationServer_updateLocalConfiguration_Results_Promise struct{ *capnp.Pipeline } + +func (p RegistrationServer_updateLocalConfiguration_Results_Promise) Struct() (RegistrationServer_updateLocalConfiguration_Results, error) { + s, err := p.Pipeline.Struct() + return RegistrationServer_updateLocalConfiguration_Results{s}, err +} + +type RegisterUdpSessionResponse struct{ capnp.Struct } + +// RegisterUdpSessionResponse_TypeID is the unique identifier for the type RegisterUdpSessionResponse. +const RegisterUdpSessionResponse_TypeID = 0xab6d5210c1f26687 + +func NewRegisterUdpSessionResponse(s *capnp.Segment) (RegisterUdpSessionResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return RegisterUdpSessionResponse{st}, err +} + +func NewRootRegisterUdpSessionResponse(s *capnp.Segment) (RegisterUdpSessionResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return RegisterUdpSessionResponse{st}, err +} + +func ReadRootRegisterUdpSessionResponse(msg *capnp.Message) (RegisterUdpSessionResponse, error) { + root, err := msg.RootPtr() + return RegisterUdpSessionResponse{root.Struct()}, err +} + +func (s RegisterUdpSessionResponse) String() string { + str, _ := text.Marshal(0xab6d5210c1f26687, s.Struct) + return str +} + +func (s RegisterUdpSessionResponse) Err() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s RegisterUdpSessionResponse) HasErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegisterUdpSessionResponse) ErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s RegisterUdpSessionResponse) SetErr(v string) error { + return s.Struct.SetText(0, v) +} + +func (s RegisterUdpSessionResponse) Spans() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s RegisterUdpSessionResponse) HasSpans() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s RegisterUdpSessionResponse) SetSpans(v []byte) error { + return s.Struct.SetData(1, v) +} + +// RegisterUdpSessionResponse_List is a list of RegisterUdpSessionResponse. +type RegisterUdpSessionResponse_List struct{ capnp.List } + +// NewRegisterUdpSessionResponse creates a new list of RegisterUdpSessionResponse. +func NewRegisterUdpSessionResponse_List(s *capnp.Segment, sz int32) (RegisterUdpSessionResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return RegisterUdpSessionResponse_List{l}, err +} + +func (s RegisterUdpSessionResponse_List) At(i int) RegisterUdpSessionResponse { + return RegisterUdpSessionResponse{s.List.Struct(i)} +} + +func (s RegisterUdpSessionResponse_List) Set(i int, v RegisterUdpSessionResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s RegisterUdpSessionResponse_List) String() string { + str, _ := text.MarshalList(0xab6d5210c1f26687, s.List) + return str +} + +// RegisterUdpSessionResponse_Promise is a wrapper for a RegisterUdpSessionResponse promised by a client call. +type RegisterUdpSessionResponse_Promise struct{ *capnp.Pipeline } + +func (p RegisterUdpSessionResponse_Promise) Struct() (RegisterUdpSessionResponse, error) { + s, err := p.Pipeline.Struct() + return RegisterUdpSessionResponse{s}, err +} + +type SessionManager struct{ Client capnp.Client } + +// SessionManager_TypeID is the unique identifier for the type SessionManager. +const SessionManager_TypeID = 0x839445a59fb01686 + +func (c SessionManager) RegisterUdpSession(ctx context.Context, params func(SessionManager_registerUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_registerUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 16, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_registerUdpSession_Params{Struct: s}) } + } + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c SessionManager) UnregisterUdpSession(ctx context.Context, params func(SessionManager_unregisterUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_unregisterUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 2} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_unregisterUdpSession_Params{Struct: s}) } + } + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type SessionManager_Server interface { + RegisterUdpSession(SessionManager_registerUdpSession) error + + UnregisterUdpSession(SessionManager_unregisterUdpSession) error +} + +func SessionManager_ServerToClient(s SessionManager_Server) SessionManager { + c, _ := s.(server.Closer) + return SessionManager{Client: server.New(SessionManager_Methods(nil, s), c)} +} + +func SessionManager_Methods(methods []server.Method, s SessionManager_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 2) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_registerUdpSession{c, opts, SessionManager_registerUdpSession_Params{Struct: p}, SessionManager_registerUdpSession_Results{Struct: r}} + return s.RegisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_unregisterUdpSession{c, opts, SessionManager_unregisterUdpSession_Params{Struct: p}, SessionManager_unregisterUdpSession_Results{Struct: r}} + return s.UnregisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + return methods +} + +// SessionManager_registerUdpSession holds the arguments for a server call to SessionManager.registerUdpSession. +type SessionManager_registerUdpSession struct { + Ctx context.Context + Options capnp.CallOptions + Params SessionManager_registerUdpSession_Params + Results SessionManager_registerUdpSession_Results +} + +// SessionManager_unregisterUdpSession holds the arguments for a server call to SessionManager.unregisterUdpSession. +type SessionManager_unregisterUdpSession struct { + Ctx context.Context + Options capnp.CallOptions + Params SessionManager_unregisterUdpSession_Params + Results SessionManager_unregisterUdpSession_Results +} + +type SessionManager_registerUdpSession_Params struct{ capnp.Struct } + +// SessionManager_registerUdpSession_Params_TypeID is the unique identifier for the type SessionManager_registerUdpSession_Params. +const SessionManager_registerUdpSession_Params_TypeID = 0x904e297b87fbecea + +func NewSessionManager_registerUdpSession_Params(s *capnp.Segment) (SessionManager_registerUdpSession_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 3}) + return SessionManager_registerUdpSession_Params{st}, err +} + +func NewRootSessionManager_registerUdpSession_Params(s *capnp.Segment) (SessionManager_registerUdpSession_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 16, PointerCount: 3}) + return SessionManager_registerUdpSession_Params{st}, err +} + +func ReadRootSessionManager_registerUdpSession_Params(msg *capnp.Message) (SessionManager_registerUdpSession_Params, error) { + root, err := msg.RootPtr() + return SessionManager_registerUdpSession_Params{root.Struct()}, err +} + +func (s SessionManager_registerUdpSession_Params) String() string { + str, _ := text.Marshal(0x904e297b87fbecea, s.Struct) + return str +} + +func (s SessionManager_registerUdpSession_Params) SessionId() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s SessionManager_registerUdpSession_Params) HasSessionId() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s SessionManager_registerUdpSession_Params) SetSessionId(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s SessionManager_registerUdpSession_Params) DstIp() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return []byte(p.Data()), err +} + +func (s SessionManager_registerUdpSession_Params) HasDstIp() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s SessionManager_registerUdpSession_Params) SetDstIp(v []byte) error { + return s.Struct.SetData(1, v) +} + +func (s SessionManager_registerUdpSession_Params) DstPort() uint16 { + return s.Struct.Uint16(0) +} + +func (s SessionManager_registerUdpSession_Params) SetDstPort(v uint16) { + s.Struct.SetUint16(0, v) +} + +func (s SessionManager_registerUdpSession_Params) CloseAfterIdleHint() int64 { + return int64(s.Struct.Uint64(8)) +} + +func (s SessionManager_registerUdpSession_Params) SetCloseAfterIdleHint(v int64) { + s.Struct.SetUint64(8, uint64(v)) +} + +func (s SessionManager_registerUdpSession_Params) TraceContext() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s SessionManager_registerUdpSession_Params) HasTraceContext() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s SessionManager_registerUdpSession_Params) TraceContextBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s SessionManager_registerUdpSession_Params) SetTraceContext(v string) error { + return s.Struct.SetText(2, v) +} + +// SessionManager_registerUdpSession_Params_List is a list of SessionManager_registerUdpSession_Params. +type SessionManager_registerUdpSession_Params_List struct{ capnp.List } + +// NewSessionManager_registerUdpSession_Params creates a new list of SessionManager_registerUdpSession_Params. +func NewSessionManager_registerUdpSession_Params_List(s *capnp.Segment, sz int32) (SessionManager_registerUdpSession_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 16, PointerCount: 3}, sz) + return SessionManager_registerUdpSession_Params_List{l}, err +} + +func (s SessionManager_registerUdpSession_Params_List) At(i int) SessionManager_registerUdpSession_Params { + return SessionManager_registerUdpSession_Params{s.List.Struct(i)} +} + +func (s SessionManager_registerUdpSession_Params_List) Set(i int, v SessionManager_registerUdpSession_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s SessionManager_registerUdpSession_Params_List) String() string { + str, _ := text.MarshalList(0x904e297b87fbecea, s.List) + return str +} + +// SessionManager_registerUdpSession_Params_Promise is a wrapper for a SessionManager_registerUdpSession_Params promised by a client call. +type SessionManager_registerUdpSession_Params_Promise struct{ *capnp.Pipeline } + +func (p SessionManager_registerUdpSession_Params_Promise) Struct() (SessionManager_registerUdpSession_Params, error) { + s, err := p.Pipeline.Struct() + return SessionManager_registerUdpSession_Params{s}, err +} + +type SessionManager_registerUdpSession_Results struct{ capnp.Struct } + +// SessionManager_registerUdpSession_Results_TypeID is the unique identifier for the type SessionManager_registerUdpSession_Results. +const SessionManager_registerUdpSession_Results_TypeID = 0x8635c6b4f45bf5cd + +func NewSessionManager_registerUdpSession_Results(s *capnp.Segment) (SessionManager_registerUdpSession_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return SessionManager_registerUdpSession_Results{st}, err +} + +func NewRootSessionManager_registerUdpSession_Results(s *capnp.Segment) (SessionManager_registerUdpSession_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return SessionManager_registerUdpSession_Results{st}, err +} + +func ReadRootSessionManager_registerUdpSession_Results(msg *capnp.Message) (SessionManager_registerUdpSession_Results, error) { + root, err := msg.RootPtr() + return SessionManager_registerUdpSession_Results{root.Struct()}, err +} + +func (s SessionManager_registerUdpSession_Results) String() string { + str, _ := text.Marshal(0x8635c6b4f45bf5cd, s.Struct) + return str +} + +func (s SessionManager_registerUdpSession_Results) Result() (RegisterUdpSessionResponse, error) { + p, err := s.Struct.Ptr(0) + return RegisterUdpSessionResponse{Struct: p.Struct()}, err +} + +func (s SessionManager_registerUdpSession_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s SessionManager_registerUdpSession_Results) SetResult(v RegisterUdpSessionResponse) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated RegisterUdpSessionResponse struct, preferring placement in s's segment. +func (s SessionManager_registerUdpSession_Results) NewResult() (RegisterUdpSessionResponse, error) { + ss, err := NewRegisterUdpSessionResponse(s.Struct.Segment()) + if err != nil { + return RegisterUdpSessionResponse{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// SessionManager_registerUdpSession_Results_List is a list of SessionManager_registerUdpSession_Results. +type SessionManager_registerUdpSession_Results_List struct{ capnp.List } + +// NewSessionManager_registerUdpSession_Results creates a new list of SessionManager_registerUdpSession_Results. +func NewSessionManager_registerUdpSession_Results_List(s *capnp.Segment, sz int32) (SessionManager_registerUdpSession_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return SessionManager_registerUdpSession_Results_List{l}, err +} + +func (s SessionManager_registerUdpSession_Results_List) At(i int) SessionManager_registerUdpSession_Results { + return SessionManager_registerUdpSession_Results{s.List.Struct(i)} +} + +func (s SessionManager_registerUdpSession_Results_List) Set(i int, v SessionManager_registerUdpSession_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s SessionManager_registerUdpSession_Results_List) String() string { + str, _ := text.MarshalList(0x8635c6b4f45bf5cd, s.List) + return str +} + +// SessionManager_registerUdpSession_Results_Promise is a wrapper for a SessionManager_registerUdpSession_Results promised by a client call. +type SessionManager_registerUdpSession_Results_Promise struct{ *capnp.Pipeline } + +func (p SessionManager_registerUdpSession_Results_Promise) Struct() (SessionManager_registerUdpSession_Results, error) { + s, err := p.Pipeline.Struct() + return SessionManager_registerUdpSession_Results{s}, err +} + +func (p SessionManager_registerUdpSession_Results_Promise) Result() RegisterUdpSessionResponse_Promise { + return RegisterUdpSessionResponse_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type SessionManager_unregisterUdpSession_Params struct{ capnp.Struct } + +// SessionManager_unregisterUdpSession_Params_TypeID is the unique identifier for the type SessionManager_unregisterUdpSession_Params. +const SessionManager_unregisterUdpSession_Params_TypeID = 0x96b74375ce9b0ef6 + +func NewSessionManager_unregisterUdpSession_Params(s *capnp.Segment) (SessionManager_unregisterUdpSession_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return SessionManager_unregisterUdpSession_Params{st}, err +} + +func NewRootSessionManager_unregisterUdpSession_Params(s *capnp.Segment) (SessionManager_unregisterUdpSession_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return SessionManager_unregisterUdpSession_Params{st}, err +} + +func ReadRootSessionManager_unregisterUdpSession_Params(msg *capnp.Message) (SessionManager_unregisterUdpSession_Params, error) { + root, err := msg.RootPtr() + return SessionManager_unregisterUdpSession_Params{root.Struct()}, err +} + +func (s SessionManager_unregisterUdpSession_Params) String() string { + str, _ := text.Marshal(0x96b74375ce9b0ef6, s.Struct) + return str +} + +func (s SessionManager_unregisterUdpSession_Params) SessionId() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s SessionManager_unregisterUdpSession_Params) HasSessionId() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s SessionManager_unregisterUdpSession_Params) SetSessionId(v []byte) error { + return s.Struct.SetData(0, v) +} + +func (s SessionManager_unregisterUdpSession_Params) Message() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s SessionManager_unregisterUdpSession_Params) HasMessage() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s SessionManager_unregisterUdpSession_Params) MessageBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s SessionManager_unregisterUdpSession_Params) SetMessage(v string) error { + return s.Struct.SetText(1, v) +} + +// SessionManager_unregisterUdpSession_Params_List is a list of SessionManager_unregisterUdpSession_Params. +type SessionManager_unregisterUdpSession_Params_List struct{ capnp.List } + +// NewSessionManager_unregisterUdpSession_Params creates a new list of SessionManager_unregisterUdpSession_Params. +func NewSessionManager_unregisterUdpSession_Params_List(s *capnp.Segment, sz int32) (SessionManager_unregisterUdpSession_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return SessionManager_unregisterUdpSession_Params_List{l}, err +} + +func (s SessionManager_unregisterUdpSession_Params_List) At(i int) SessionManager_unregisterUdpSession_Params { + return SessionManager_unregisterUdpSession_Params{s.List.Struct(i)} +} + +func (s SessionManager_unregisterUdpSession_Params_List) Set(i int, v SessionManager_unregisterUdpSession_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s SessionManager_unregisterUdpSession_Params_List) String() string { + str, _ := text.MarshalList(0x96b74375ce9b0ef6, s.List) + return str +} + +// SessionManager_unregisterUdpSession_Params_Promise is a wrapper for a SessionManager_unregisterUdpSession_Params promised by a client call. +type SessionManager_unregisterUdpSession_Params_Promise struct{ *capnp.Pipeline } + +func (p SessionManager_unregisterUdpSession_Params_Promise) Struct() (SessionManager_unregisterUdpSession_Params, error) { + s, err := p.Pipeline.Struct() + return SessionManager_unregisterUdpSession_Params{s}, err +} + +type SessionManager_unregisterUdpSession_Results struct{ capnp.Struct } + +// SessionManager_unregisterUdpSession_Results_TypeID is the unique identifier for the type SessionManager_unregisterUdpSession_Results. +const SessionManager_unregisterUdpSession_Results_TypeID = 0xf24ec4ab5891b676 + +func NewSessionManager_unregisterUdpSession_Results(s *capnp.Segment) (SessionManager_unregisterUdpSession_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return SessionManager_unregisterUdpSession_Results{st}, err +} + +func NewRootSessionManager_unregisterUdpSession_Results(s *capnp.Segment) (SessionManager_unregisterUdpSession_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return SessionManager_unregisterUdpSession_Results{st}, err +} + +func ReadRootSessionManager_unregisterUdpSession_Results(msg *capnp.Message) (SessionManager_unregisterUdpSession_Results, error) { + root, err := msg.RootPtr() + return SessionManager_unregisterUdpSession_Results{root.Struct()}, err +} + +func (s SessionManager_unregisterUdpSession_Results) String() string { + str, _ := text.Marshal(0xf24ec4ab5891b676, s.Struct) + return str +} + +// SessionManager_unregisterUdpSession_Results_List is a list of SessionManager_unregisterUdpSession_Results. +type SessionManager_unregisterUdpSession_Results_List struct{ capnp.List } + +// NewSessionManager_unregisterUdpSession_Results creates a new list of SessionManager_unregisterUdpSession_Results. +func NewSessionManager_unregisterUdpSession_Results_List(s *capnp.Segment, sz int32) (SessionManager_unregisterUdpSession_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return SessionManager_unregisterUdpSession_Results_List{l}, err +} + +func (s SessionManager_unregisterUdpSession_Results_List) At(i int) SessionManager_unregisterUdpSession_Results { + return SessionManager_unregisterUdpSession_Results{s.List.Struct(i)} +} + +func (s SessionManager_unregisterUdpSession_Results_List) Set(i int, v SessionManager_unregisterUdpSession_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s SessionManager_unregisterUdpSession_Results_List) String() string { + str, _ := text.MarshalList(0xf24ec4ab5891b676, s.List) + return str +} + +// SessionManager_unregisterUdpSession_Results_Promise is a wrapper for a SessionManager_unregisterUdpSession_Results promised by a client call. +type SessionManager_unregisterUdpSession_Results_Promise struct{ *capnp.Pipeline } + +func (p SessionManager_unregisterUdpSession_Results_Promise) Struct() (SessionManager_unregisterUdpSession_Results, error) { + s, err := p.Pipeline.Struct() + return SessionManager_unregisterUdpSession_Results{s}, err +} + +type UpdateConfigurationResponse struct{ capnp.Struct } + +// UpdateConfigurationResponse_TypeID is the unique identifier for the type UpdateConfigurationResponse. +const UpdateConfigurationResponse_TypeID = 0xdb58ff694ba05cf9 + +func NewUpdateConfigurationResponse(s *capnp.Segment) (UpdateConfigurationResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return UpdateConfigurationResponse{st}, err +} + +func NewRootUpdateConfigurationResponse(s *capnp.Segment) (UpdateConfigurationResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return UpdateConfigurationResponse{st}, err +} + +func ReadRootUpdateConfigurationResponse(msg *capnp.Message) (UpdateConfigurationResponse, error) { + root, err := msg.RootPtr() + return UpdateConfigurationResponse{root.Struct()}, err +} + +func (s UpdateConfigurationResponse) String() string { + str, _ := text.Marshal(0xdb58ff694ba05cf9, s.Struct) + return str +} + +func (s UpdateConfigurationResponse) LatestAppliedVersion() int32 { + return int32(s.Struct.Uint32(0)) +} + +func (s UpdateConfigurationResponse) SetLatestAppliedVersion(v int32) { + s.Struct.SetUint32(0, uint32(v)) +} + +func (s UpdateConfigurationResponse) Err() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s UpdateConfigurationResponse) HasErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s UpdateConfigurationResponse) ErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s UpdateConfigurationResponse) SetErr(v string) error { + return s.Struct.SetText(0, v) +} + +// UpdateConfigurationResponse_List is a list of UpdateConfigurationResponse. +type UpdateConfigurationResponse_List struct{ capnp.List } + +// NewUpdateConfigurationResponse creates a new list of UpdateConfigurationResponse. +func NewUpdateConfigurationResponse_List(s *capnp.Segment, sz int32) (UpdateConfigurationResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}, sz) + return UpdateConfigurationResponse_List{l}, err +} + +func (s UpdateConfigurationResponse_List) At(i int) UpdateConfigurationResponse { + return UpdateConfigurationResponse{s.List.Struct(i)} +} + +func (s UpdateConfigurationResponse_List) Set(i int, v UpdateConfigurationResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s UpdateConfigurationResponse_List) String() string { + str, _ := text.MarshalList(0xdb58ff694ba05cf9, s.List) + return str +} + +// UpdateConfigurationResponse_Promise is a wrapper for a UpdateConfigurationResponse promised by a client call. +type UpdateConfigurationResponse_Promise struct{ *capnp.Pipeline } + +func (p UpdateConfigurationResponse_Promise) Struct() (UpdateConfigurationResponse, error) { + s, err := p.Pipeline.Struct() + return UpdateConfigurationResponse{s}, err +} + +type ConfigurationManager struct{ Client capnp.Client } + +// ConfigurationManager_TypeID is the unique identifier for the type ConfigurationManager. +const ConfigurationManager_TypeID = 0xb48edfbdaa25db04 + +func (c ConfigurationManager) UpdateConfiguration(ctx context.Context, params func(ConfigurationManager_updateConfiguration_Params) error, opts ...capnp.CallOption) ConfigurationManager_updateConfiguration_Results_Promise { + if c.Client == nil { + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { return params(ConfigurationManager_updateConfiguration_Params{Struct: s}) } + } + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type ConfigurationManager_Server interface { + UpdateConfiguration(ConfigurationManager_updateConfiguration) error +} + +func ConfigurationManager_ServerToClient(s ConfigurationManager_Server) ConfigurationManager { + c, _ := s.(server.Closer) + return ConfigurationManager{Client: server.New(ConfigurationManager_Methods(nil, s), c)} +} + +func ConfigurationManager_Methods(methods []server.Method, s ConfigurationManager_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 1) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := ConfigurationManager_updateConfiguration{c, opts, ConfigurationManager_updateConfiguration_Params{Struct: p}, ConfigurationManager_updateConfiguration_Results{Struct: r}} + return s.UpdateConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + return methods +} + +// ConfigurationManager_updateConfiguration holds the arguments for a server call to ConfigurationManager.updateConfiguration. +type ConfigurationManager_updateConfiguration struct { + Ctx context.Context + Options capnp.CallOptions + Params ConfigurationManager_updateConfiguration_Params + Results ConfigurationManager_updateConfiguration_Results +} + +type ConfigurationManager_updateConfiguration_Params struct{ capnp.Struct } + +// ConfigurationManager_updateConfiguration_Params_TypeID is the unique identifier for the type ConfigurationManager_updateConfiguration_Params. +const ConfigurationManager_updateConfiguration_Params_TypeID = 0xb177ca2526a3ca76 + +func NewConfigurationManager_updateConfiguration_Params(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Params{st}, err +} + +func NewRootConfigurationManager_updateConfiguration_Params(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Params{st}, err +} + +func ReadRootConfigurationManager_updateConfiguration_Params(msg *capnp.Message) (ConfigurationManager_updateConfiguration_Params, error) { + root, err := msg.RootPtr() + return ConfigurationManager_updateConfiguration_Params{root.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Params) String() string { + str, _ := text.Marshal(0xb177ca2526a3ca76, s.Struct) + return str +} + +func (s ConfigurationManager_updateConfiguration_Params) Version() int32 { + return int32(s.Struct.Uint32(0)) +} + +func (s ConfigurationManager_updateConfiguration_Params) SetVersion(v int32) { + s.Struct.SetUint32(0, uint32(v)) +} + +func (s ConfigurationManager_updateConfiguration_Params) Config() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s ConfigurationManager_updateConfiguration_Params) HasConfig() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConfigurationManager_updateConfiguration_Params) SetConfig(v []byte) error { + return s.Struct.SetData(0, v) +} + +// ConfigurationManager_updateConfiguration_Params_List is a list of ConfigurationManager_updateConfiguration_Params. +type ConfigurationManager_updateConfiguration_Params_List struct{ capnp.List } + +// NewConfigurationManager_updateConfiguration_Params creates a new list of ConfigurationManager_updateConfiguration_Params. +func NewConfigurationManager_updateConfiguration_Params_List(s *capnp.Segment, sz int32) (ConfigurationManager_updateConfiguration_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}, sz) + return ConfigurationManager_updateConfiguration_Params_List{l}, err +} + +func (s ConfigurationManager_updateConfiguration_Params_List) At(i int) ConfigurationManager_updateConfiguration_Params { + return ConfigurationManager_updateConfiguration_Params{s.List.Struct(i)} +} + +func (s ConfigurationManager_updateConfiguration_Params_List) Set(i int, v ConfigurationManager_updateConfiguration_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConfigurationManager_updateConfiguration_Params_List) String() string { + str, _ := text.MarshalList(0xb177ca2526a3ca76, s.List) + return str +} + +// ConfigurationManager_updateConfiguration_Params_Promise is a wrapper for a ConfigurationManager_updateConfiguration_Params promised by a client call. +type ConfigurationManager_updateConfiguration_Params_Promise struct{ *capnp.Pipeline } + +func (p ConfigurationManager_updateConfiguration_Params_Promise) Struct() (ConfigurationManager_updateConfiguration_Params, error) { + s, err := p.Pipeline.Struct() + return ConfigurationManager_updateConfiguration_Params{s}, err +} + +type ConfigurationManager_updateConfiguration_Results struct{ capnp.Struct } + +// ConfigurationManager_updateConfiguration_Results_TypeID is the unique identifier for the type ConfigurationManager_updateConfiguration_Results. +const ConfigurationManager_updateConfiguration_Results_TypeID = 0x958096448eb3373e + +func NewConfigurationManager_updateConfiguration_Results(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Results{st}, err +} + +func NewRootConfigurationManager_updateConfiguration_Results(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Results{st}, err +} + +func ReadRootConfigurationManager_updateConfiguration_Results(msg *capnp.Message) (ConfigurationManager_updateConfiguration_Results, error) { + root, err := msg.RootPtr() + return ConfigurationManager_updateConfiguration_Results{root.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Results) String() string { + str, _ := text.Marshal(0x958096448eb3373e, s.Struct) + return str +} + +func (s ConfigurationManager_updateConfiguration_Results) Result() (UpdateConfigurationResponse, error) { + p, err := s.Struct.Ptr(0) + return UpdateConfigurationResponse{Struct: p.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConfigurationManager_updateConfiguration_Results) SetResult(v UpdateConfigurationResponse) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated UpdateConfigurationResponse struct, preferring placement in s's segment. +func (s ConfigurationManager_updateConfiguration_Results) NewResult() (UpdateConfigurationResponse, error) { + ss, err := NewUpdateConfigurationResponse(s.Struct.Segment()) + if err != nil { + return UpdateConfigurationResponse{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// ConfigurationManager_updateConfiguration_Results_List is a list of ConfigurationManager_updateConfiguration_Results. +type ConfigurationManager_updateConfiguration_Results_List struct{ capnp.List } + +// NewConfigurationManager_updateConfiguration_Results creates a new list of ConfigurationManager_updateConfiguration_Results. +func NewConfigurationManager_updateConfiguration_Results_List(s *capnp.Segment, sz int32) (ConfigurationManager_updateConfiguration_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return ConfigurationManager_updateConfiguration_Results_List{l}, err +} + +func (s ConfigurationManager_updateConfiguration_Results_List) At(i int) ConfigurationManager_updateConfiguration_Results { + return ConfigurationManager_updateConfiguration_Results{s.List.Struct(i)} +} + +func (s ConfigurationManager_updateConfiguration_Results_List) Set(i int, v ConfigurationManager_updateConfiguration_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConfigurationManager_updateConfiguration_Results_List) String() string { + str, _ := text.MarshalList(0x958096448eb3373e, s.List) + return str +} + +// ConfigurationManager_updateConfiguration_Results_Promise is a wrapper for a ConfigurationManager_updateConfiguration_Results promised by a client call. +type ConfigurationManager_updateConfiguration_Results_Promise struct{ *capnp.Pipeline } + +func (p ConfigurationManager_updateConfiguration_Results_Promise) Struct() (ConfigurationManager_updateConfiguration_Results, error) { + s, err := p.Pipeline.Struct() + return ConfigurationManager_updateConfiguration_Results{s}, err +} + +func (p ConfigurationManager_updateConfiguration_Results_Promise) Result() UpdateConfigurationResponse_Promise { + return UpdateConfigurationResponse_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type CloudflaredServer struct{ Client capnp.Client } + +// CloudflaredServer_TypeID is the unique identifier for the type CloudflaredServer. +const CloudflaredServer_TypeID = 0xf548cef9dea2a4a1 + +func (c CloudflaredServer) RegisterUdpSession(ctx context.Context, params func(SessionManager_registerUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_registerUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 16, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_registerUdpSession_Params{Struct: s}) } + } + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c CloudflaredServer) UnregisterUdpSession(ctx context.Context, params func(SessionManager_unregisterUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_unregisterUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 2} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_unregisterUdpSession_Params{Struct: s}) } + } + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c CloudflaredServer) UpdateConfiguration(ctx context.Context, params func(ConfigurationManager_updateConfiguration_Params) error, opts ...capnp.CallOption) ConfigurationManager_updateConfiguration_Results_Promise { + if c.Client == nil { + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { return params(ConfigurationManager_updateConfiguration_Params{Struct: s}) } + } + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type CloudflaredServer_Server interface { + RegisterUdpSession(SessionManager_registerUdpSession) error + + UnregisterUdpSession(SessionManager_unregisterUdpSession) error + + UpdateConfiguration(ConfigurationManager_updateConfiguration) error +} + +func CloudflaredServer_ServerToClient(s CloudflaredServer_Server) CloudflaredServer { + c, _ := s.(server.Closer) + return CloudflaredServer{Client: server.New(CloudflaredServer_Methods(nil, s), c)} +} + +func CloudflaredServer_Methods(methods []server.Method, s CloudflaredServer_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 3) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_registerUdpSession{c, opts, SessionManager_registerUdpSession_Params{Struct: p}, SessionManager_registerUdpSession_Results{Struct: r}} + return s.RegisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_unregisterUdpSession{c, opts, SessionManager_unregisterUdpSession_Params{Struct: p}, SessionManager_unregisterUdpSession_Results{Struct: r}} + return s.UnregisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := ConfigurationManager_updateConfiguration{c, opts, ConfigurationManager_updateConfiguration_Params{Struct: p}, ConfigurationManager_updateConfiguration_Results{Struct: r}} + return s.UpdateConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + return methods +} + +const schema_db8274f9144abc7e = "x\xda\xccZ}t\x1c\xd5u\xbfw\xdeJ#)\x92" + + "G\xe3Y,i\x8fU5\x8a\xd56NL\x91U\xa7" + + "\xe0\x86H2\xb2\x83\x8cm4Z;\x87\x1a\x93\xc3h" + + "\xf7I\x1awwf33+\xec\x04\xe2\x8f\xd8\x188" + + "@\xb0c\x03v\xe2\xc6\x98\xa4=5I\x8a\x83i\x9a" + + "\x1ehq\x1aB\x80\xe0\x00\x07R\x7f\x90\xa6\xae\xe3\xb6" + + "\xf1\xb1KmpsL\x03\xd3sgv>\xb4\xbbF" + + "6\xed\x1f\xf9ou\xe7\xce{\xf7\xfe\xde\xef~\xbc;" + + "\xba\xaaP\xdf't\xd7\xac\x93\x00\xd4\x035\xb5.\x9f" + + "\xfd\xea\xe7\xf7t\xfd\xe3FP\x9b\x11\xdd/>\xb58" + + "y\xc1\xd9x\x0cj\x98\x08\xd0\xf3Mq6*O\x8a" + + "\"\x80\xb2_\xfcw@\xf7\xce\x19\x8f\x7f\xfd\x9b\x0b\xb7" + + "\x7f\x09\xe4f\x16)\x03*;\xea\xdeQ\x1e\xa9#\xc5" + + "\xddu[\x94w\xe9\x97{\x83\xfc\x877'_>D" + + "\xca\xf1\x95\x13\xa4u\xb2\xee\xa8r\xd6\xd3?SG\x0b" + + "\x7f2\xff\xd3\xbd\x9f\xd8\xf1\xe2&\x90\x9b\x85\xf8\xc2=" + + "G\xeag\xa3\xf2\xabz\xef\x9d\xfa\x1b\x01\xdd\xb7\xb6\xb7" + + ">\xf6\xc8\xa1\x1fo\x06\xb9\x1d\xa1d\xe7\xbb\xf5\xf7 " + + "\xa0\xd2\xd4\xf0\xd7\x80\xeeK\xe7o~\xfb\xc0\x8f\xe6\xdd" + + "\x09r\x07) )|\xbb\xe1UR8\xd8\xd0\x0b\xe8" + + "\x9e:\xfd?[\xbe\xf0\xd1e\x0f\x80\xda\x81B\xb0\xc4" + + "\xf1\x86\x17\x10\xb0\xe7|C\x07\x02\xba\xbd\x8b^\xfa~" + + "\xaa\xe7\xc1\xede\xa6\x0bdF[\xe3Q\xa5\xab\x91~" + + "}\xb8\xf16@\xf7S\x7f\xfc\xc4\xfd\x03\x0f\xae\xdf\x01" + + "rW\xb8\xdf\xe6\xc6\x06\x01P\xd9\xddH\xfb\xfd\xf7\xb4" + + "\xaf\x1e*^\xf7\xbd\x07K\x06\xd1\"=O7\x1e%" + + "\x83^\xf1V\xe8\x9c\xe8\xba\xf5\x07\xcf>\xf1\x10\xa8\x1f" + + "At\xdf\x18\xf9\xd8\xebl\xf7\xbec\xb0\x02E\xb2\xaf" + + "\xa7\xbbi\x98t\xafm\"\xdd\x9f~\xfc\xa9\xbf{\xe0" + + "\x89-_\x05\xb5\x1d\x11\x80\xc0\xecy\xa4i\x1f)\xec" + + "o\xa2\xdd\xb6\x1f~zY~\xeb\xae\xbd>>\xde\xf3" + + "#M\xdfEH\xb8\x9b\x06\x7f\x9d_\xf1h\xfa\xd1\x12" + + "r5\xf4\xe8\xf9\xa6\xbd\xe4\xf6\x91&\xcf\xedyGO" + + "\xde\xb8\xf4\xbb\xa3\x7f\x19{\xf7\xfc\xb4\xd9\x02$\xdc-" + + "\xa3\xe7\x0e6\x0f\xe7\x1f\xab\x02H\xcf\x99i+QA" + + "\x89\x10yw\x1a\xd9\xf8\xed\xdf\xb9\xa1~\xcd\xc9E\x8f" + + "\x83\xfc\x91`\x99?\x95\x04Zf\xe2\x85G\x7f\xaf\xeb" + + "\x85\xdb\xf6\x83\xda\x85!XK\xe9\x19*\x9aD\xef&" + + "\x8eu\xed{\xfa\x17\xf7\x1f(gX\xcf\xd3\xd2lT" + + "^\xf2vy^\xfa\xb4r\x81~\xb9\x89[\xd8{\xda" + + "\xc3\xffp\xa0\x9c\xbd\x9e]\xc7\xa5\xe9\xa8\x9c\x95<\x13" + + "%\xcf\xbf{\x0e\xee\xfaX\xdd\xd7\xdfz\xb2\xaa\xfa\x15" + + "\xf2tT\xbad\xef`e\"\xd2\x15\x83\xf8\xc63\xdd" + + "\x89\xef\xc5\x99\xf6\xac\xbc\x8b\xa0~\xddSh?\xb3\xa0" + + "\xc9xs\xe33e\xa0\x90\xa2R\x9c\xfe\x8e\xb2a:" + + "\xfd\xbac:\xe9&>\xb1z\x8b|\xe2g\xcf\xfa\x98" + + "\xf8\x8e\xb7))r\xbc[\xa1s[|\xf3W\xb6\xd5" + + "\x9c\xfc\xcasd[,\x04j(HzT\xa5\x13\x15" + + "M\xa1\x9f\xb7(-\x0c\xd0M=\xfe'\xdfY\x90=" + + "\xf2b5\x86\xfe|\xc6^\xe5\xe4\x0c\xfau|\x06a" + + "zb\xce\xfe/\xfc\xea\xbeW^+9\xe2\xed}m" + + "\x8b\xc7\x99\xa5-\xb4\xf7\x85U{n\xd0\xdd\x9b\x8e\x95" + + "\xe3\xe2i\xe6[FP\xd9\xd0\xe2\xb9\xd2B\xcb\x85\x04" + + "\xad\xa6}\xbc%\x85\xcaYO\xfb\x8c\xb7\xb6pRk" + + "[\xff\xb3O\xbd\x11\xe3\xd4\xd9\x96m\xc4\xc7e\x9f\xb9" + + "yu\xfd\x1d'N\xc4\xcd:\xde\xe2\xe1{\xd6{\xf5" + + "\xef\xff\xe9\xa1\xf1[\xbes\xe8d\x8cGrk'\xf1" + + "\xe8?\xff\xe2\xd4\x97O\xe7\xb3\xff\xe6ELp6\xf5" + + "\xad\xa7\xe8\xdd\xf6VJ(-\x1dM\x0b;\x0f\x0f\x9d" + + "\x8a\xe3}\xa1\xf5\x9c\x97&\xdah\xf1y\xb7\xf6\xf3U" + + "W\xdft\xaa\"\x95\xcdi;\xa5\\\xd3F\xfa\xf3\xda" + + "\xb6\xa0\xb24\xd5\x02\xe0N\xfc\xcd\xd6\x9b\x1e\xfb\xe1\xb2" + + "s~\x0c{\xa6\\\x9b\xfaW\xf2\xe2\xfe/\x0e\xdcx" + + "M\xe7\xc1sq/\xbaS\x14UJ\x7f\x8a6\x1a\xbd" + + "\xfa\xf4\xa7\xbb\xee\xff\xd1\xb9\xb2\x93\"EEK\x1dU" + + "\xf2)\xfa\xa5\x93\xee\x9b\x8b\xfe\xfc\xb5\x94\x94z\xbb\x0c" + + "\xd8ZZ\xf3\xbeT\x0a\x95\xdd\xa4\xdb\xb33\xf5\x1c\xb1" + + "\xf9\x91o\xec\xfd\xe7\x0b\x87\xae?_\x11*\xf7\xcd\x9c" + + "\x8e\xca\xee\x99\xb4\xec\xce\x99\xa2\xb2s\xe6\xef\x03\xb8w" + + "\x1e\xfb\xec\x9aW\xbf\xf4\xd6\xf9r~y\x06o\x9d)" + + "Do\x10]\x1fZ\xfe\x1f\xebN\xef\x98\xf1\xeb\x8a\xb5" + + "\xafiO\xa12\xd8N\x9a\x0b\xdb\x9fS\x8e\xd0/\xf7" + + "e\xf1\xd1\xee\x81u/^\x88\x1d\xd4\xc1\xf6w\x08\x9d" + + "\x07\xc5\xaf\x9dX\xff\x8b\xcf\xfe&\x8e\xce\x93\xed\xdb\x08" + + "\x9dg\xdb\x09\x9d\xdb\xdf\xdcy\xfd\x97W}\xeb\xbd\x18" + + "=N\xb6\xcf\xa53v\x8a\x86\xc1sV\x81e\xae\xcc" + + "h\x05\xa30\xbf\xbf\xe8\x8cs\xc3\xd13\x9a\xc3\x87\xb9" + + "]\x90L\xc3\xe6C\x88j3K\x00$\x10@\xd6V" + + "\x03\xa8\xb72Ts\x02\xca\x88I\"\x87\xac\x93p\x9c" + + "\xa1\xea\x08(\x0bB\x92\x92\xaa\xfc\xb9N\x005\xc7P" + + "]# \xb2$2\x00\xb9\xb8\x0d@]\xc3P\xdd$" + + "\xa0[\xe0V^3\xb8\x01\x92\xb3\xd0\xb2\xb0\x11\x04l" + + "\x04t-\xeeXk\xb5\x91\x1cH<&\x16W\xdf\xe6" + + "`\x13\x08\xd8\x04\xe8\x8e\x9bE\xcb^a8\xa8\xe7\x86" + + "\xf9\xa8\xc5m\x1c\xc7Z\x10\xb0\x160tJ\x08\x9cJ" + + "s\xdb\xd6Mci\xaffhc\xdc\"w\xeaX\x0d" + + "@X\xa50\xa8gr\xf7.\x10\xe49\"F\x15\x05" + + "\x03Z\xca\x1f\xde\x07\x82\xdc.\xba\x16\x1f\xd3m\x87[" + + "\xb8\"[\xf0\x96f\xa6\xd1\x87n\xd1\xf0\x1f \xb7\xfc" + + "\x07\x12m\xda\x87CX\xc5\xa4\xebr:7\x9cA\x83" + + "\x8d\x9ae\xe8.\xae\x86\xee\xe2\x12\xba\x9bb\xe8nX" + + "\x00\xa0\xde\xceP\xbdK@\x99\x95\xe0\xdd<\x1b@]" + + "\xcfP\xbdW@7\xe3o\x92\x05\x80\x10\xb8Q\xae9" + + "E\x8b\xdb$\x9b\x068\xc4\xd0\xc3w\x1a\xe0\xba\x09n" + + "\x91\xc5\x01\xde\x92fe\xc6\xc33\xa9`\xca\xc25\xba" + + "\xed\xe8\xc6\xd8r\xef\xc1\x90)\xe5\xf4\xccZ\xf2\xa5\xd1" + + "\xb3\xae}>\x00\xa2|\xc5J\x00\x14dy\x01@\xaf" + + ">f\x98\x16w\xb3\xba\x9d1\x0d\x83\x03\xcb8\xebF" + + "\xb4\x9cfdx\xb8|M\xb0\xbc\xbfl\x9a[\x13\xdc" + + "\xbaR\x8b\xb1r\xd6\x90fiy\x1b@m\x0cA[" + + "\xb8\x12@\x1d`\xa8\x0e\xc5@[J\xa0-a\xa8\xde" + + "\x14\x03m\x05\x816\xc4P]%\xa0kZ\xfa\x98n" + + "\\\xc7\x81Yqf\xd9\x8e\xa1\xe59\x01Tr~\x9d" + + "Ypt\xd3\xb0\xb19\xaa \x80\xd8\x1c\x83\xa5\xb6\x9c" + + "k>\xd5\xae\x0c\xb8\x12P\xc54f\x0ds\xbb\x98s" + + "\xd0V\x13\xa1\xfdM\xf3\x01\xd4:\x86jR\xc0^\xcb" + + "\x7f\xde\x1c\xf5\x03\x1f|\xaf\x10\xabd\xb8\xd7\x1d\xc31" + + "\xda\x04Xm\x9e\x1b\xd1\x06KP\xddMPmb\xa8" + + ">@\xfcB\x9f_\xf7\xed\x02P\x1f`\xa8~M@" + + "9!$1\x81(\xef\xa4\xe0\x7f\x98\xa1\xfa\x0d\x01]" + + "\xdb\xdfz\x100\x1b`\xda\x91\xb5\x9d\xc1B\xf0\xd7\xba" + + "\xac\xed\x0c\x99\x96\x83\"\x08(\x02\xd1\xd4\xb4y\xff(" + + "\x05\xce`6\xc7\xaf\xd7\x99\xe1`\x0d\x08XCN[" + + "Z\x86_gR\x8a\xe0k\x9c\xd2\x89\x80\x8c\x0d\x00\x95" + + "Q\xe5\x93\xa6\xbf\xc8\x9cq?\xc8\x03\xa7?J\x04\xf9" + + "\x03\x86\xea\x1f\xc5\x9c\xee&\xb3\xafb\xa8~R@W" + + "\xcbd\xcc\xa2\xe1,\x07\xa6\x8d\x95\x91>\xcdA\xcaX" + + "<\xa2H\xb0\xaf\x18F\xb3i\x8c\xeacEKsb" + + "\xc7Q,d5\x87OzT:{:\x90\xa9\x0e?" + + "l\x17.\xf1\xf0\x83\xecSq\xfc,o\xc7\x81\x18\xae" + + "\x06\x04\x9d\xf4\xc7\x19\xaaWW?\xbfuyn\xdb\xda" + + "\x18\xafH\x06\x89\x18\x00\x06\xcf\x90\x8bT4\xa8f\\" + + "\xe9y\x82\x0e\xed\xdd\xe8\xba\xfe\xe6\xc4\xb2Y\x0c\xd5\xab" + + "\x04l\xc2\xf7\\\x7f\xf79\xdb\xa2c\xe8\xe0\x96eZ" + + "\xd8\x1cU\xd2\x92\xfb\x99\xd2\x06h\x1a\x03\xdc\xd1\xf4\x1c" + + "R4\x86\xbdf\x19H\xd5\xb3H\x04\x91/\x9e5\xa4" + + "I\x14\x1e\xf1\xb3 z73Tg\x0a\xe8\x8e\x11\xf5" + + "\x86\xb8\x85\xba\x99]\xa6\x19f\x9a\xf1L\xc4\xcb\xcb\xdb" + + "j\x98wx'?\xc5{\x16/9\x1aZh\x89d" + + "a,|;\xa3\xfa\x19\x1e\xe0\x86\x91(|\xc3Tw" + + "7q\xfe.\x86\xea\xf6X}\xd8\xba8\x1e\xbf\x89$" + + "&\x00\xe4\x9dt\xfe\xdb\x19\xaa{\x84\xc9U\x96Op" + + "\xc3\x19\xd0\xc7@\xe4v$%\x13\x07\xf41\x0e\xcc\xfe" + + "\xbf\xa6M\xb1*\x0a\xe6\x88m\xe6\xb8\xc3\x07x&\xa7" + + "Q\xe4Lp\xffy\x89f\xc1\xc1U\xf2p\xb8\"\x06" + + "|>2\xbf\x89\x89\xc5AgD\xc5\x10\xc69s\xa3" + + "\xe0\x10y\xd4yt\xd8\x05\xcd\xb0+\xc2_\x9a\xbc\xab" + + "\x1f\xe2\x15\x04\x88B#\x8c\xfe`\x81\xcbN\x1f\xa5t" + + "\x1e\xf7cA\xe4G\xe8\xc6\xfc\xc8\x8d\xb0\xa8'@\xc0" + + "\x04`o\xc6[\xb0\xc2\x17v1[\xa4\xa0cJx" + + "\x1dSp\xb9\xc4\xe0F.\xcb{A\x90\x9bD7\xb0" + + "\x17\x83\xf7\xc5\x8a\xee\x87U\xa6\x8b\x1b=\x8a\xa0M;" + + "\xc4H>\xbf\x1a\xc9\xad*5jc\x9c\xe3\xa5\x1a\xb5" + + "uWDg\xbfF\x01\xc8\xbb\xf7\x02\xa8{\x18\xaa\xdf" + + "\x12\xb0\xd7o\x8c\xb09\x1a\x9f\x94x\xe9w\x04KL" + + "\xe8\xc8h\xb9\xa8d\xb9\x16/\xe4\xb4\x0c_\x88\xa5\xa6" + + "\x07\x10A@\xf4\x82!_\xb0\xb8m\xa3n\x1ajQ" + + "\xcb\xe9\xccY\x1bv\xa2F1?d\xf1\x09\x1d\xcd\xa2" + + "\xdd\xef8\xa4a\xf0\xbdLVw\x81 /\xa5\x1c" + + "\x16|\xd5\xc1\xe0{\xae\xdc\xbf\x0f\x04\xf9Z\xcaa\xc1" + + "\xe7L\x0c\xbe\xd1\xc9\xdd/\x80 w\xc7\xbe@\x04\xa8" + + "T|\x81\xf0\x1fH\x8e\xee?(\x15;\xa1\xbc\xdaQ" + + "n\x89\xdf\xd0k/w\xa4\xd1\xeb\x8f$.g~\x7f" + + "\xc9\xf3\xef\xf0\xdf&\xfe\x7f\x86FA\xad\xfa\xdf\x00\x00" + + "\x00\xff\xff\xe0\xef7\xb3" + +func init() { + schemas.Register(schema_db8274f9144abc7e, + 0x82c325a07ad22a65, + 0x839445a59fb01686, + 0x83ced0145b2f114b, + 0x84cb9536a2cf6d3c, + 0x85c8cea1ab1894f3, + 0x8635c6b4f45bf5cd, + 0x904e297b87fbecea, + 0x9496331ab9cd463f, + 0x958096448eb3373e, + 0x96b74375ce9b0ef6, + 0x97b3c5c260257622, + 0x9b87b390babc2ccf, + 0xa29a916d4ebdd894, + 0xa353a3556df74984, + 0xa766b24d4fe5da35, + 0xab6d5210c1f26687, + 0xb046e578094b1ead, + 0xb177ca2526a3ca76, + 0xb48edfbdaa25db04, + 0xb4bf9861fe035d04, + 0xb5f39f082b9ac18a, + 0xb70431c0dc014915, + 0xc082ef6e0d42ed1d, + 0xc5d6e311876a3604, + 0xc793e50592935b4a, + 0xcbd96442ae3bb01a, + 0xd4d18de97bb12de3, + 0xdb58ff694ba05cf9, + 0xdbaa9d03d52b62dc, + 0xdc3ed6801961e502, + 0xe3e37d096a5b564e, + 0xe5ceae5d6897d7be, + 0xe6646dec8feaa6ee, + 0xea50d822450d1f17, + 0xea58385c65416035, + 0xf24ec4ab5891b676, + 0xf2c122394f447e8e, + 0xf2c68e2547ec3866, + 0xf41a0f001ad49e46, + 0xf548cef9dea2a4a1, + 0xf5f383d2785edb86, + 0xf71695ec7fe85497, + 0xf9cb7f4431a307d0, + 0xfc5edf80e39c0796, + 0xfeac5c8f4899ef7c) +} diff --git a/release/DEFAULT_BUILD_TAGS b/release/DEFAULT_BUILD_TAGS index 4374ea93b6..cc2c039d8e 100644 --- a/release/DEFAULT_BUILD_TAGS +++ b/release/DEFAULT_BUILD_TAGS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/DEFAULT_BUILD_TAGS_OTHERS b/release/DEFAULT_BUILD_TAGS_OTHERS index 814b53f063..7100c5ad58 100644 --- a/release/DEFAULT_BUILD_TAGS_OTHERS +++ b/release/DEFAULT_BUILD_TAGS_OTHERS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/DEFAULT_BUILD_TAGS_WINDOWS b/release/DEFAULT_BUILD_TAGS_WINDOWS index 746827a736..7d5dd55ad8 100644 --- a/release/DEFAULT_BUILD_TAGS_WINDOWS +++ b/release/DEFAULT_BUILD_TAGS_WINDOWS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file From 87a2f4c33621060fe69a2eb6200fa007bc512ebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 10:10:32 +0800 Subject: [PATCH 10/41] Fix cloudflared registration parameter inconsistencies - Set QUIC InitialPacketSize per IP family (IPv4: 1252, IPv6: 1232) - Set MaxIncomingStreams/MaxIncomingUniStreams to 1<<60 - Populate OriginLocalIP from local socket address in both QUIC and HTTP/2 - Pass NumPreviousAttempts from retry counter to registration - Include version number in client version string - Use OS_GOARCH format for Arch field --- protocol/cloudflare/connection_http2.go | 19 ++++--- protocol/cloudflare/connection_quic.go | 63 +++++++++++++-------- protocol/cloudflare/connection_quic_test.go | 25 ++++++++ protocol/cloudflare/control.go | 10 +++- protocol/cloudflare/inbound.go | 16 +++--- 5 files changed, 92 insertions(+), 41 deletions(-) create mode 100644 protocol/cloudflare/connection_quic_test.go diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 50e07d40ae..afe0699bb1 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -40,8 +40,9 @@ type HTTP2Connection struct { gracePeriod time.Duration inbound *Inbound - registrationClient *RegistrationClient - registrationResult *RegistrationResult + numPreviousAttempts uint8 + registrationClient *RegistrationClient + registrationResult *RegistrationResult activeRequests sync.WaitGroup closeOnce sync.Once @@ -55,6 +56,7 @@ func NewHTTP2Connection( credentials Credentials, connectorID uuid.UUID, features []string, + numPreviousAttempts uint8, gracePeriod time.Duration, inbound *Inbound, logger log.ContextLogger, @@ -92,10 +94,11 @@ func NewHTTP2Connection( edgeAddr: edgeAddr, connIndex: connIndex, credentials: credentials, - connectorID: connectorID, - features: features, - gracePeriod: gracePeriod, - inbound: inbound, + connectorID: connectorID, + features: features, + numPreviousAttempts: numPreviousAttempts, + gracePeriod: gracePeriod, + inbound: inbound, }, nil } @@ -149,7 +152,9 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque c.registrationClient = NewRegistrationClient(ctx, stream) - options := BuildConnectionOptions(c.connectorID, c.features, 0) + host, _, _ := net.SplitHostPort(c.conn.LocalAddr().String()) + originLocalIP := net.ParseIP(host) + options := BuildConnectionOptions(c.connectorID, c.features, c.numPreviousAttempts, originLocalIP) result, err := c.registrationClient.RegisterConnection( ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options, ) diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index e7cfc073dd..549a6eef47 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "io" + "net" "sync" "time" @@ -25,18 +26,27 @@ const ( quicKeepAlivePeriod = 1 * time.Second ) +func quicInitialPacketSize(ipVersion int) uint16 { + initialPacketSize := uint16(1252) + if ipVersion == 4 { + initialPacketSize = 1232 + } + return initialPacketSize +} + // QUICConnection manages a single QUIC connection to the Cloudflare edge. type QUICConnection struct { - conn *quic.Conn - logger log.ContextLogger - edgeAddr *EdgeAddr - connIndex uint8 - credentials Credentials - connectorID uuid.UUID - features []string - gracePeriod time.Duration - registrationClient *RegistrationClient - registrationResult *RegistrationResult + conn *quic.Conn + logger log.ContextLogger + edgeAddr *EdgeAddr + connIndex uint8 + credentials Credentials + connectorID uuid.UUID + features []string + numPreviousAttempts uint8 + gracePeriod time.Duration + registrationClient *RegistrationClient + registrationResult *RegistrationResult closeOnce sync.Once } @@ -49,6 +59,7 @@ func NewQUICConnection( credentials Credentials, connectorID uuid.UUID, features []string, + numPreviousAttempts uint8, gracePeriod time.Duration, logger log.ContextLogger, ) (*QUICConnection, error) { @@ -65,10 +76,13 @@ func NewQUICConnection( } quicConfig := &quic.Config{ - HandshakeIdleTimeout: quicHandshakeIdleTimeout, - MaxIdleTimeout: quicMaxIdleTimeout, - KeepAlivePeriod: quicKeepAlivePeriod, - EnableDatagrams: true, + HandshakeIdleTimeout: quicHandshakeIdleTimeout, + MaxIdleTimeout: quicMaxIdleTimeout, + KeepAlivePeriod: quicKeepAlivePeriod, + MaxIncomingStreams: 1 << 60, + MaxIncomingUniStreams: 1 << 60, + EnableDatagrams: true, + InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } conn, err := quic.DialAddr(ctx, edgeAddr.UDP.String(), tlsConfig, quicConfig) @@ -77,14 +91,15 @@ func NewQUICConnection( } return &QUICConnection{ - conn: conn, - logger: logger, - edgeAddr: edgeAddr, - connIndex: connIndex, - credentials: credentials, - connectorID: connectorID, - features: features, - gracePeriod: gracePeriod, + conn: conn, + logger: logger, + edgeAddr: edgeAddr, + connIndex: connIndex, + credentials: credentials, + connectorID: connectorID, + features: features, + numPreviousAttempts: numPreviousAttempts, + gracePeriod: gracePeriod, }, nil } @@ -128,7 +143,9 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) error { q.registrationClient = NewRegistrationClient(ctx, newStreamReadWriteCloser(stream)) - options := BuildConnectionOptions(q.connectorID, q.features, 0) + host, _, _ := net.SplitHostPort(q.conn.LocalAddr().String()) + originLocalIP := net.ParseIP(host) + options := BuildConnectionOptions(q.connectorID, q.features, q.numPreviousAttempts, originLocalIP) result, err := q.registrationClient.RegisterConnection( ctx, q.credentials.Auth(), q.credentials.TunnelID, q.connIndex, options, ) diff --git a/protocol/cloudflare/connection_quic_test.go b/protocol/cloudflare/connection_quic_test.go new file mode 100644 index 0000000000..7ea4a86906 --- /dev/null +++ b/protocol/cloudflare/connection_quic_test.go @@ -0,0 +1,25 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import "testing" + +func TestQUICInitialPacketSize(t *testing.T) { + testCases := []struct { + name string + ipVersion int + expected uint16 + }{ + {name: "ipv4", ipVersion: 4, expected: 1232}, + {name: "ipv6", ipVersion: 6, expected: 1252}, + {name: "default", ipVersion: 0, expected: 1252}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + if actual := quicInitialPacketSize(testCase.ipVersion); actual != testCase.expected { + t.Fatalf("quicInitialPacketSize(%d) = %d, want %d", testCase.ipVersion, actual, testCase.expected) + } + }) + } +} diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index 9f583f29e8..a4130f0969 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -5,9 +5,11 @@ package cloudflare import ( "context" "io" + "net" "runtime" "time" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" E "github.com/sagernet/sing/common/exceptions" @@ -18,9 +20,10 @@ import ( const ( registrationTimeout = 10 * time.Second - clientVersion = "sing-box" ) +var clientVersion = "sing-box " + C.Version + // RegistrationClient handles the Cap'n Proto RPC for tunnel registration. type RegistrationClient struct { client tunnelrpc.TunnelServer @@ -148,14 +151,15 @@ func (c *RegistrationClient) Close() error { } // BuildConnectionOptions creates the ConnectionOptions to send during registration. -func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8) *RegistrationConnectionOptions { +func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions { return &RegistrationConnectionOptions{ Client: RegistrationClientInfo{ ClientID: connectorID[:], Features: features, Version: clientVersion, - Arch: runtime.GOARCH, + Arch: runtime.GOOS + "_" + runtime.GOARCH, }, + OriginLocalIP: originLocalIP, NumPreviousAttempts: numPreviousAttempts, } } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 36c04310ee..dd0bd43985 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -267,7 +267,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))] - err := i.serveConnection(connIndex, edgeAddr, features) + err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries)) if err == nil || i.ctx.Err() != nil { return } @@ -284,7 +284,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } } -func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { +func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { protocol := i.protocol if protocol == "" { protocol = "quic" @@ -292,21 +292,21 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features switch protocol { case "quic": - return i.serveQUIC(connIndex, edgeAddr, features) + return i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts) case "http2": - return i.serveHTTP2(connIndex, edgeAddr, features) + return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) default: return E.New("unsupported protocol: ", protocol) } } -func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { +func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")") connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, i.gracePeriod, i.logger, + features, numPreviousAttempts, i.gracePeriod, i.logger, ) if err != nil { return E.Cause(err, "create QUIC connection") @@ -321,13 +321,13 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri return connection.Serve(i.ctx, i) } -func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string) error { +func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")") connection, err := NewHTTP2Connection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, i.gracePeriod, i, i.logger, + features, numPreviousAttempts, i.gracePeriod, i, i.logger, ) if err != nil { return E.Cause(err, "create HTTP/2 connection") From 01a8405069af9ba46fa53175df504c87896fd9ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 11:17:39 +0800 Subject: [PATCH 11/41] Implement router-backed cloudflare tunnel ingress config --- option/cloudflare_tunnel.go | 66 +- protocol/cloudflare/connection_http2.go | 16 +- protocol/cloudflare/control.go | 1 - protocol/cloudflare/datagram_v2.go | 9 +- protocol/cloudflare/dispatch.go | 295 ++++++--- protocol/cloudflare/helpers_test.go | 6 + protocol/cloudflare/inbound.go | 202 +++--- protocol/cloudflare/ingress_test.go | 197 +++--- protocol/cloudflare/runtime_config.go | 803 ++++++++++++++++++++++++ 9 files changed, 1290 insertions(+), 305 deletions(-) create mode 100644 protocol/cloudflare/runtime_config.go diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index a1a2c44425..c0fdbfa879 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -3,11 +3,63 @@ package option import "github.com/sagernet/sing/common/json/badoption" type CloudflareTunnelInboundOptions struct { - Token string `json:"token,omitempty"` - CredentialPath string `json:"credential_path,omitempty"` - HAConnections int `json:"ha_connections,omitempty"` - Protocol string `json:"protocol,omitempty"` - EdgeIPVersion int `json:"edge_ip_version,omitempty"` - DatagramVersion string `json:"datagram_version,omitempty"` - GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Token string `json:"token,omitempty"` + CredentialPath string `json:"credential_path,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Region string `json:"region,omitempty"` + Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"` + OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` + WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"` +} + +type CloudflareTunnelIngressRule struct { + Hostname string `json:"hostname,omitempty"` + Path string `json:"path,omitempty"` + Service string `json:"service,omitempty"` + OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` +} + +type CloudflareTunnelOriginRequestOptions struct { + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"` + TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` + NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"` + KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"` + KeepAliveConnections int `json:"keep_alive_connections,omitempty"` + HTTPHostHeader string `json:"http_host_header,omitempty"` + OriginServerName string `json:"origin_server_name,omitempty"` + MatchSNIToHost bool `json:"match_sni_to_host,omitempty"` + CAPool string `json:"ca_pool,omitempty"` + NoTLSVerify bool `json:"no_tls_verify,omitempty"` + DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"` + BastionMode bool `json:"bastion_mode,omitempty"` + ProxyAddress string `json:"proxy_address,omitempty"` + ProxyPort uint `json:"proxy_port,omitempty"` + ProxyType string `json:"proxy_type,omitempty"` + IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"` + HTTP2Origin bool `json:"http2_origin,omitempty"` + Access CloudflareTunnelAccessRule `json:"access,omitempty"` +} + +type CloudflareTunnelAccessRule struct { + Required bool `json:"required,omitempty"` + TeamName string `json:"team_name,omitempty"` + AudTag []string `json:"aud_tag,omitempty"` + Environment string `json:"environment,omitempty"` +} + +type CloudflareTunnelIPRule struct { + Prefix string `json:"prefix,omitempty"` + Ports []int `json:"ports,omitempty"` + Allow bool `json:"allow,omitempty"` +} + +type CloudflareTunnelWarpRoutingOptions struct { + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + MaxActiveFlows uint64 `json:"max_active_flows,omitempty"` + TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index afe0699bb1..42d9bffeaf 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -90,10 +90,10 @@ func NewHTTP2Connection( server: &http2.Server{ MaxConcurrentStreams: math.MaxUint32, }, - logger: logger, - edgeAddr: edgeAddr, - connIndex: connIndex, - credentials: credentials, + logger: logger, + edgeAddr: edgeAddr, + connIndex: connIndex, + credentials: credentials, connectorID: connectorID, features: features, numPreviousAttempts: numPreviousAttempts, @@ -244,9 +244,13 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp w.WriteHeader(http.StatusBadRequest) return } - c.inbound.UpdateIngress(body.Version, body.Config) + result := c.inbound.ApplyConfig(body.Version, body.Config) w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(body.Version), 10) + `,"err":null}`)) + if result.Err != nil { + w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`)) + return + } + w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`)) } func (c *HTTP2Connection) close() { diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index a4130f0969..6e98811142 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -171,7 +171,6 @@ func DefaultFeatures(datagramVersion string) []string { "support_datagram_v2", "support_quic_eof", "allow_remote_config", - "management_logs", } if datagramVersion == "v3" { features = append(features, "support_datagram_v3_2") diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 8159b04cca..2071d86e96 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -318,12 +318,17 @@ func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_u func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { version := call.Params.Version() configData, _ := call.Params.Config() - s.inbound.UpdateIngress(version, configData) + updateResult := s.inbound.ApplyConfig(version, configData) result, err := call.Results.NewResult() if err != nil { return err } - result.SetErr("") + result.SetLatestAppliedVersion(updateResult.LastAppliedVersion) + if updateResult.Err != nil { + result.SetErr(updateResult.Err.Error()) + } else { + result.SetErr("") + } return nil } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 7f949a7a40..9263e16c30 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -4,12 +4,16 @@ package cloudflare import ( "context" + "crypto/tls" + "crypto/x509" "io" "net" "net/http" "net/url" + "os" "strconv" "strings" + "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -124,19 +128,42 @@ func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser metadata.Destination = M.ParseSocksaddr(request.Dest) i.handleTCPStream(ctx, stream, respWriter, metadata) case ConnectionTypeHTTP, ConnectionTypeWebsocket: - originURL := i.ResolveOriginURL(request.Dest) - request.Dest = originURL - metadata.Destination = parseHTTPDestination(originURL) - if request.Type == ConnectionTypeHTTP { - i.handleHTTPStream(ctx, stream, respWriter, request, metadata) - } else { - i.handleWebSocketStream(ctx, stream, respWriter, request, metadata) + service, originURL, err := i.resolveHTTPService(request.Dest) + if err != nil { + i.logger.ErrorContext(ctx, "resolve origin service: ", err) + respWriter.WriteResponse(err, nil) + return } + request.Dest = originURL + i.handleHTTPService(ctx, stream, respWriter, request, metadata, service) default: i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type) } } +func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) { + parsedURL, err := url.Parse(requestURL) + if err != nil { + return ResolvedService{}, "", E.Cause(err, "parse request URL") + } + service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path) + if !loaded { + return ResolvedService{}, "", E.New("no ingress rule matched request host/path") + } + if service.Kind == ResolvedServiceHelloWorld { + helloURL, err := i.ensureHelloWorldURL() + if err != nil { + return ResolvedService{}, "", err + } + service.BaseURL = helloURL + } + originURL, err := service.BuildRequestURL(requestURL) + if err != nil { + return ResolvedService{}, "", E.Cause(err, "build origin request URL") + } + return service, originURL, nil +} + func parseHTTPDestination(dest string) M.Socksaddr { parsed, err := url.Parse(dest) if err != nil { @@ -172,10 +199,81 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser <-done } -func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { +func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + switch service.Kind { + case ResolvedServiceStatus: + err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) + if err != nil { + i.logger.ErrorContext(ctx, "write status service response: ", err) + } + return + case ResolvedServiceHTTP: + metadata.Destination = service.Destination + if request.Type == ConnectionTypeHTTP { + i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service) + } else { + i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service) + } + case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: + if request.Type == ConnectionTypeHTTP { + i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) + } else { + i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service) + } + default: + err := E.New("unsupported service kind for HTTP/WebSocket request") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + } +} + +func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) + + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest) + + transport, cleanup, err := i.newDirectOriginTransport(service) + if err != nil { + i.logger.ErrorContext(ctx, "build direct origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + metadata.Network = N.NetworkTCP + i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest) + + transport, cleanup, err := i.newDirectOriginTransport(service) + if err != nil { + i.logger.ErrorContext(ctx, "build direct origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) +} + +func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) { httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) if err != nil { i.logger.ErrorContext(ctx, "build HTTP request: ", err) @@ -183,23 +281,17 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose return } - input, output := pipe.Pipe() - var innerError error - - done := make(chan struct{}) - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - innerError = it - common.Close(input, output) - close(done) - })) + httpRequest = applyOriginRequest(httpRequest, service.OriginRequest) + requestCtx := httpRequest.Context() + if service.OriginRequest.ConnectTimeout > 0 { + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout) + defer cancel() + httpRequest = httpRequest.WithContext(requestCtx) + } httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return input, nil - }, - }, + Transport: transport, CheckRedirect: func(request *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, @@ -208,87 +300,146 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose response, err := httpClient.Do(httpRequest) if err != nil { - <-done - i.logger.ErrorContext(ctx, "HTTP request: ", E.Errors(innerError, err)) + i.logger.ErrorContext(ctx, "origin request: ", err) respWriter.WriteResponse(err, nil) return } + defer response.Body.Close() responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) err = respWriter.WriteResponse(nil, responseMetadata) if err != nil { - response.Body.Close() - i.logger.ErrorContext(ctx, "write HTTP response headers: ", err) - <-done + i.logger.ErrorContext(ctx, "write origin response headers: ", err) + return + } + + if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols { + rwc, ok := response.Body.(io.ReadWriteCloser) + if !ok { + i.logger.ErrorContext(ctx, "websocket origin response body is not duplex") + return + } + bidirectionalCopy(stream, rwc) return } _, err = io.Copy(stream, response.Body) - response.Body.Close() - common.Close(input, output) if err != nil && !E.IsClosedOrCanceled(err) { i.logger.DebugContext(ctx, "copy HTTP response body: ", err) } - <-done } -func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { - metadata.Network = N.NetworkTCP - i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) - - httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream) - if err != nil { - i.logger.ErrorContext(ctx, "build WebSocket request: ", err) - respWriter.WriteResponse(err, nil) - return - } - +func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) { input, output := pipe.Pipe() - var innerError error - done := make(chan struct{}) go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - innerError = it common.Close(input, output) close(done) })) - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return input, nil - }, - }, - CheckRedirect: func(request *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse + transport := &http.Transport{ + DisableCompression: true, + ForceAttemptHTTP2: originRequest.HTTP2Origin, + TLSHandshakeTimeout: originRequest.TLSTimeout, + IdleConnTimeout: originRequest.KeepAliveTimeout, + MaxIdleConns: originRequest.KeepAliveConnections, + MaxIdleConnsPerHost: originRequest.KeepAliveConnections, + TLSClientConfig: buildOriginTLSConfig(originRequest), + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return input, nil }, } - defer httpClient.CloseIdleConnections() + return transport, func() { + common.Close(input, output) + select { + case <-done: + case <-time.After(time.Second): + } + } +} - response, err := httpClient.Do(httpRequest) - if err != nil { - <-done - i.logger.ErrorContext(ctx, "WebSocket request: ", E.Errors(innerError, err)) - respWriter.WriteResponse(err, nil) - return +func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) { + transport := &http.Transport{ + DisableCompression: true, + ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, + TLSHandshakeTimeout: service.OriginRequest.TLSTimeout, + IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, + MaxIdleConns: service.OriginRequest.KeepAliveConnections, + MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, + TLSClientConfig: buildOriginTLSConfig(service.OriginRequest), + } + switch service.Kind { + case ResolvedServiceUnix, ResolvedServiceUnixTLS: + dialer := &net.Dialer{} + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", service.UnixPath) + } + case ResolvedServiceHelloWorld: + dialer := &net.Dialer{} + target := service.BaseURL.Host + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", target) + } + default: + return nil, nil, E.New("unsupported direct origin service") } + return transport, func() {}, nil +} - responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header) - err = respWriter.WriteResponse(nil, responseMetadata) +func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { + tlsConfig := &tls.Config{ + InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec + ServerName: originRequest.OriginServerName, + } + if originRequest.CAPool == "" { + return tlsConfig + } + pemData, err := os.ReadFile(originRequest.CAPool) if err != nil { - response.Body.Close() - i.logger.ErrorContext(ctx, "write WebSocket response headers: ", err) - <-done - return + return tlsConfig } + pool := x509.NewCertPool() + if pool.AppendCertsFromPEM(pemData) { + tlsConfig.RootCAs = pool + } + return tlsConfig +} - _, err = io.Copy(stream, response.Body) - response.Body.Close() - common.Close(input, output) - if err != nil && !E.IsClosedOrCanceled(err) { - i.logger.DebugContext(ctx, "copy WebSocket response body: ", err) +func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request { + request = request.Clone(request.Context()) + if originRequest.HTTPHostHeader != "" { + request.Header.Set("X-Forwarded-Host", request.Host) + request.Host = originRequest.HTTPHostHeader } + if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" { + if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil { + request.ContentLength = contentLength + request.TransferEncoding = nil + } + } + return request +} + +func bidirectionalCopy(left, right io.ReadWriteCloser) { + var closeOnce sync.Once + closeBoth := func() { + closeOnce.Do(func() { + common.Close(left, right) + }) + } + + done := make(chan struct{}, 2) + go func() { + io.Copy(left, right) + closeBoth() + done <- struct{}{} + }() + go func() { + io.Copy(right, left) + closeBoth() + done <- struct{}{} + }() + <-done <-done } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 82676434dc..8920b7705b 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -142,6 +142,11 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i t.Fatal("create logger: ", err) } + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal("create config manager: ", err) + } + ctx, cancel := context.WithCancel(context.Background()) inboundInstance := &Inbound{ Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), @@ -156,6 +161,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i edgeIPVersion: 0, datagramVersion: "", gracePeriod: 5 * time.Second, + configManager: configManager, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index dd0bd43985..0fcb2ac65b 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -7,9 +7,10 @@ import ( "encoding/base64" "io" "math/rand" + "net" + "net/http" "net/url" "os" - "strings" "sync" "time" @@ -30,17 +31,19 @@ func RegisterInbound(registry *inbound.Registry) { type Inbound struct { inbound.Adapter - ctx context.Context - cancel context.CancelFunc - router adapter.ConnectionRouterEx - logger log.ContextLogger - credentials Credentials - connectorID uuid.UUID - haConnections int - protocol string - edgeIPVersion int - datagramVersion string - gracePeriod time.Duration + ctx context.Context + cancel context.CancelFunc + router adapter.ConnectionRouterEx + logger log.ContextLogger + credentials Credentials + connectorID uuid.UUID + haConnections int + protocol string + region string + edgeIPVersion int + datagramVersion string + gracePeriod time.Duration + configManager *ConfigManager connectionAccess sync.Mutex connections []io.Closer @@ -50,101 +53,9 @@ type Inbound struct { datagramV2Muxers map[DatagramSender]*DatagramV2Muxer datagramV3Muxers map[DatagramSender]*DatagramV3Muxer - ingressAccess sync.RWMutex - ingressVersion int32 - ingressRules []IngressRule -} - -// IngressRule maps a hostname pattern to an origin service URL. -type IngressRule struct { - Hostname string - Service string -} - -type ingressConfig struct { - Ingress []ingressConfigRule `json:"ingress"` -} - -type ingressConfigRule struct { - Hostname string `json:"hostname,omitempty"` - Service string `json:"service"` -} - -// UpdateIngress applies a new ingress configuration from the edge. -func (i *Inbound) UpdateIngress(version int32, config []byte) { - i.ingressAccess.Lock() - defer i.ingressAccess.Unlock() - - if version <= i.ingressVersion { - return - } - - var parsed ingressConfig - err := json.Unmarshal(config, &parsed) - if err != nil { - i.logger.Error("parse ingress config: ", err) - return - } - - rules := make([]IngressRule, 0, len(parsed.Ingress)) - for _, rule := range parsed.Ingress { - rules = append(rules, IngressRule{ - Hostname: rule.Hostname, - Service: rule.Service, - }) - } - i.ingressRules = rules - i.ingressVersion = version - i.logger.Info("updated ingress configuration (version ", version, ", ", len(rules), " rules)") -} - -// ResolveOrigin finds the origin service URL for a given hostname. -// Returns the service URL if matched, or empty string if no match. -func (i *Inbound) ResolveOrigin(hostname string) string { - i.ingressAccess.RLock() - defer i.ingressAccess.RUnlock() - - for _, rule := range i.ingressRules { - if rule.Hostname == "" { - return rule.Service - } - if matchIngress(rule.Hostname, hostname) { - return rule.Service - } - } - return "" -} - -func matchIngress(pattern, hostname string) bool { - if pattern == hostname { - return true - } - if strings.HasPrefix(pattern, "*.") { - suffix := pattern[1:] - return strings.HasSuffix(hostname, suffix) - } - return false -} - -// ResolveOriginURL rewrites a request URL to point to the origin service. -// For example, https://testbox.badnet.work/path → http://127.0.0.1:8083/path -func (i *Inbound) ResolveOriginURL(requestURL string) string { - parsed, err := url.Parse(requestURL) - if err != nil { - return requestURL - } - hostname := parsed.Hostname() - origin := i.ResolveOrigin(hostname) - if origin == "" || strings.HasPrefix(origin, "http_status:") { - return requestURL - } - originURL, err := url.Parse(origin) - if err != nil { - return requestURL - } - parsed.Scheme = originURL.Scheme - parsed.Host = originURL.Host - return parsed.String() + helloWorldAccess sync.Mutex + helloWorldServer *http.Server + helloWorldURL *url.URL } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { @@ -178,23 +89,30 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod = 30 * time.Second } + configManager, err := NewConfigManager(options) + if err != nil { + return nil, E.Cause(err, "build cloudflare tunnel runtime config") + } + inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), - ctx: inboundCtx, - cancel: cancel, - router: router, - logger: logger, - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - edgeIPVersion: edgeIPVersion, - datagramVersion: datagramVersion, - gracePeriod: gracePeriod, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), + ctx: inboundCtx, + cancel: cancel, + router: router, + logger: logger, + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + region: options.Region, + edgeIPVersion: edgeIPVersion, + datagramVersion: datagramVersion, + gracePeriod: gracePeriod, + configManager: configManager, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil } @@ -238,6 +156,16 @@ func (i *Inbound) Start(stage adapter.StartStage) error { return nil } +func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { + result := i.configManager.Apply(version, config) + if result.Err != nil { + i.logger.Error("update ingress configuration: ", result.Err) + return result + } + i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")") + return result +} + func (i *Inbound) Close() error { i.cancel() i.done.Wait() @@ -247,9 +175,41 @@ func (i *Inbound) Close() error { } i.connections = nil i.connectionAccess.Unlock() + if i.helloWorldServer != nil { + i.helloWorldServer.Close() + } return nil } +func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { + i.helloWorldAccess.Lock() + defer i.helloWorldAccess.Unlock() + if i.helloWorldURL != nil { + return i.helloWorldURL, nil + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte("Hello World")) + }) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, E.Cause(err, "listen hello world server") + } + server := &http.Server{Handler: mux} + go server.Serve(listener) + + i.helloWorldServer = server + i.helloWorldURL = &url.URL{ + Scheme: "http", + Host: listener.Addr().String(), + } + return i.helloWorldURL, nil +} + const ( backoffBaseTime = time.Second backoffMaxTime = 2 * time.Minute diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 190eb5b154..03a91f0f0f 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -6,143 +6,148 @@ import ( "testing" "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" ) -func newTestIngressInbound() *Inbound { - return &Inbound{logger: log.NewNOPFactory().NewLogger("test")} -} - -func TestUpdateIngress(t *testing.T) { - inboundInstance := newTestIngressInbound() - - config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`) - inboundInstance.UpdateIngress(1, config1) - - inboundInstance.ingressAccess.RLock() - count := len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 3 { - t.Fatalf("expected 3 rules, got %d", count) +func newTestIngressInbound(t *testing.T) *Inbound { + t.Helper() + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal(err) } - - inboundInstance.UpdateIngress(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) - inboundInstance.ingressAccess.RLock() - count = len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 3 { - t.Error("version 1 re-apply should not change rules, got ", count) + return &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + configManager: configManager, } +} - inboundInstance.UpdateIngress(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) - inboundInstance.ingressAccess.RLock() - count = len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 1 { - t.Error("version 2 should update to 1 rule, got ", count) +func mustResolvedService(t *testing.T, rawService string) ResolvedService { + t.Helper() + service, err := parseResolvedService(rawService, defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) } + return service } -func TestUpdateIngressInvalidJSON(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.UpdateIngress(1, []byte("not json")) +func TestApplyConfig(t *testing.T) { + inboundInstance := newTestIngressInbound(t) - inboundInstance.ingressAccess.RLock() - count := len(inboundInstance.ingressRules) - inboundInstance.ingressAccess.RUnlock() - if count != 0 { - t.Error("invalid JSON should leave rules empty, got ", count) + config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`) + result := inboundInstance.ApplyConfig(1, config1) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 1 { + t.Fatalf("expected version 1, got %d", result.LastAppliedVersion) } -} -func TestResolveOriginExact(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "test.example.com", Service: "http://localhost:8080"}, - {Hostname: "", Service: "http_status:404"}, + service, loaded := inboundInstance.configManager.Resolve("a.com", "/") + if !loaded || service.Service != "http://localhost:80" { + t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded) } - result := inboundInstance.ResolveOrigin("test.example.com") - if result != "http://localhost:8080" { - t.Error("expected http://localhost:8080, got ", result) + result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 1 { + t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion) } -} -func TestResolveOriginWildcard(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "*.example.com", Service: "http://localhost:9090"}, + service, loaded = inboundInstance.configManager.Resolve("b.com", "/") + if !loaded || service.Service != "http://localhost:81" { + t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded) } - result := inboundInstance.ResolveOrigin("sub.example.com") - if result != "http://localhost:9090" { - t.Error("wildcard should match sub.example.com, got ", result) + result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + if result.Err != nil { + t.Fatal(result.Err) + } + if result.LastAppliedVersion != 2 { + t.Fatalf("expected version 2, got %d", result.LastAppliedVersion) } - result = inboundInstance.ResolveOrigin("example.com") - if result != "" { - t.Error("wildcard should not match bare example.com, got ", result) + service, loaded = inboundInstance.configManager.Resolve("anything.com", "/") + if !loaded || service.StatusCode != 503 { + t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded) } } -func TestResolveOriginCatchAll(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "specific.com", Service: "http://localhost:1"}, - {Hostname: "", Service: "http://localhost:2"}, +func TestApplyConfigInvalidJSON(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + result := inboundInstance.ApplyConfig(1, []byte("not json")) + if result.Err == nil { + t.Fatal("expected parse error") } - - result := inboundInstance.ResolveOrigin("anything.com") - if result != "http://localhost:2" { - t.Error("catch-all should match, got ", result) + if result.LastAppliedVersion != -1 { + t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion) } } -func TestResolveOriginNoMatch(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "specific.com", Service: "http://localhost:1"}, +func TestResolveExactAndWildcard(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")}, + {Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, } - result := inboundInstance.ResolveOrigin("other.com") - if result != "" { - t.Error("expected empty for no match, got ", result) + service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/") + if !loaded || service.Service != "http://localhost:8080" { + t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded) } -} -func TestResolveOriginURLRewrite(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "foo.com", Service: "http://127.0.0.1:8083"}, + service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/") + if !loaded || service.Service != "http://localhost:9090" { + t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded) } - result := inboundInstance.ResolveOriginURL("https://foo.com/path?q=1") - if result != "http://127.0.0.1:8083/path?q=1" { - t.Error("expected http://127.0.0.1:8083/path?q=1, got ", result) + service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/") + if !loaded || service.StatusCode != 404 { + t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded) } } -func TestResolveOriginURLNoMatch(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "other.com", Service: "http://localhost:1"}, +func TestResolveHTTPService(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, } - original := "https://unknown.com/page" - result := inboundInstance.ResolveOriginURL(original) - if result != original { - t.Error("no match should return original, got ", result) + service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1") + if err != nil { + t.Fatal(err) + } + if service.Destination.String() != "127.0.0.1:8083" { + t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination) + } + if requestURL != "http://127.0.0.1:8083/path?q=1" { + t.Fatalf("expected rewritten URL, got %s", requestURL) } } -func TestResolveOriginURLHTTPStatus(t *testing.T) { - inboundInstance := newTestIngressInbound() - inboundInstance.ingressRules = []IngressRule{ - {Hostname: "", Service: "http_status:404"}, +func TestResolveHTTPServiceStatus(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Service: mustResolvedService(t, "http_status:404")}, + }, } - original := "https://any.com/page" - result := inboundInstance.ResolveOriginURL(original) - if result != original { - t.Error("http_status service should return original, got ", result) + service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path") + if err != nil { + t.Fatal(err) + } + if service.StatusCode != 404 { + t.Fatalf("expected status 404, got %#v", service) + } + if requestURL != "https://any.com/path" { + t.Fatalf("status service should keep request URL, got %s", requestURL) } } diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go new file mode 100644 index 0000000000..99fe73c818 --- /dev/null +++ b/protocol/cloudflare/runtime_config.go @@ -0,0 +1,803 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "encoding/json" + "net" + "net/url" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "golang.org/x/net/idna" +) + +const ( + defaultHTTPConnectTimeout = 30 * time.Second + defaultTLSTimeout = 10 * time.Second + defaultTCPKeepAlive = 30 * time.Second + defaultKeepAliveTimeout = 90 * time.Second + defaultKeepAliveConnections = 100 + defaultProxyAddress = "127.0.0.1" + defaultWarpRoutingConnectTime = 5 * time.Second + defaultWarpRoutingTCPKeepAlive = 30 * time.Second +) + +type ResolvedServiceKind int + +const ( + ResolvedServiceHTTP ResolvedServiceKind = iota + ResolvedServiceStream + ResolvedServiceStatus + ResolvedServiceHelloWorld + ResolvedServiceUnix + ResolvedServiceUnixTLS +) + +type ResolvedService struct { + Kind ResolvedServiceKind + Service string + Destination M.Socksaddr + BaseURL *url.URL + UnixPath string + StatusCode int + OriginRequest OriginRequestConfig +} + +func (s ResolvedService) RouterControlled() bool { + return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream +} + +func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) { + switch s.Kind { + case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS: + requestParsed, err := url.Parse(requestURL) + if err != nil { + return "", err + } + originURL := *s.BaseURL + originURL.Path = requestParsed.Path + originURL.RawPath = requestParsed.RawPath + originURL.RawQuery = requestParsed.RawQuery + originURL.Fragment = requestParsed.Fragment + return originURL.String(), nil + case ResolvedServiceHelloWorld: + if s.BaseURL == nil { + return "", E.New("hello world service is unavailable") + } + requestParsed, err := url.Parse(requestURL) + if err != nil { + return "", err + } + originURL := *s.BaseURL + originURL.Path = requestParsed.Path + originURL.RawPath = requestParsed.RawPath + originURL.RawQuery = requestParsed.RawQuery + originURL.Fragment = requestParsed.Fragment + return originURL.String(), nil + default: + return requestURL, nil + } +} + +type compiledIngressRule struct { + Hostname string + PunycodeHostname string + Path *regexp.Regexp + Service ResolvedService +} + +type RuntimeConfig struct { + Ingress []compiledIngressRule + OriginRequest OriginRequestConfig + WarpRouting WarpRoutingConfig +} + +type OriginRequestConfig struct { + ConnectTimeout time.Duration + TLSTimeout time.Duration + TCPKeepAlive time.Duration + NoHappyEyeballs bool + KeepAliveTimeout time.Duration + KeepAliveConnections int + HTTPHostHeader string + OriginServerName string + MatchSNIToHost bool + CAPool string + NoTLSVerify bool + DisableChunkedEncoding bool + BastionMode bool + ProxyAddress string + ProxyPort uint + ProxyType string + IPRules []IPRule + HTTP2Origin bool + Access AccessConfig +} + +type AccessConfig struct { + Required bool + TeamName string + AudTag []string + Environment string +} + +type IPRule struct { + Prefix string + Ports []int + Allow bool +} + +type WarpRoutingConfig struct { + ConnectTimeout time.Duration + MaxActiveFlows uint64 + TCPKeepAlive time.Duration +} + +type ConfigUpdateResult struct { + LastAppliedVersion int32 + Err error +} + +type ConfigManager struct { + access sync.RWMutex + currentVersion int32 + activeConfig RuntimeConfig +} + +func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) { + config, err := buildLocalRuntimeConfig(options) + if err != nil { + return nil, err + } + return &ConfigManager{ + currentVersion: -1, + activeConfig: config, + }, nil +} + +func (m *ConfigManager) Snapshot() RuntimeConfig { + m.access.RLock() + defer m.access.RUnlock() + return m.activeConfig +} + +func (m *ConfigManager) CurrentVersion() int32 { + m.access.RLock() + defer m.access.RUnlock() + return m.currentVersion +} + +func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult { + m.access.Lock() + defer m.access.Unlock() + + if version <= m.currentVersion { + return ConfigUpdateResult{LastAppliedVersion: m.currentVersion} + } + + config, err := buildRemoteRuntimeConfig(raw) + if err != nil { + return ConfigUpdateResult{ + LastAppliedVersion: m.currentVersion, + Err: err, + } + } + + m.activeConfig = config + m.currentVersion = version + return ConfigUpdateResult{LastAppliedVersion: m.currentVersion} +} + +func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) { + m.access.RLock() + defer m.access.RUnlock() + return m.activeConfig.Resolve(hostname, path) +} + +func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) { + host := stripPort(hostname) + for _, rule := range c.Ingress { + if !matchIngressRule(rule, host, path) { + continue + } + return rule.Service, true + } + return ResolvedService{}, false +} + +func matchIngressRule(rule compiledIngressRule, hostname, path string) bool { + hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname) + if !hostMatch && rule.PunycodeHostname != "" { + hostMatch = matchIngressHost(rule.PunycodeHostname, hostname) + } + if !hostMatch { + return false + } + return rule.Path == nil || rule.Path.MatchString(path) +} + +func matchIngressHost(pattern, hostname string) bool { + if pattern == hostname { + return true + } + if strings.HasPrefix(pattern, "*.") { + return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*")) + } + return false +} + +func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) { + defaultOriginRequest := originRequestFromOption(options.OriginRequest) + warpRouting := warpRoutingFromOption(options.WarpRouting) + var ingressRules []localIngressRule + for _, rule := range options.Ingress { + ingressRules = append(ingressRules, localIngressRule{ + Hostname: rule.Hostname, + Path: rule.Path, + Service: rule.Service, + OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest), + }) + } + compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) + if err != nil { + return RuntimeConfig{}, err + } + return RuntimeConfig{ + Ingress: compiledRules, + OriginRequest: defaultOriginRequest, + WarpRouting: warpRouting, + }, nil +} + +func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) { + var remote remoteConfigJSON + if err := json.Unmarshal(raw, &remote); err != nil { + return RuntimeConfig{}, E.Cause(err, "decode remote config") + } + defaultOriginRequest := originRequestFromRemote(remote.OriginRequest) + warpRouting := warpRoutingFromRemote(remote.WarpRouting) + var ingressRules []localIngressRule + for _, rule := range remote.Ingress { + ingressRules = append(ingressRules, localIngressRule{ + Hostname: rule.Hostname, + Path: rule.Path, + Service: rule.Service, + OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest), + }) + } + compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) + if err != nil { + return RuntimeConfig{}, err + } + return RuntimeConfig{ + Ingress: compiledRules, + OriginRequest: defaultOriginRequest, + WarpRouting: warpRouting, + }, nil +} + +type localIngressRule struct { + Hostname string + Path string + Service string + OriginRequest OriginRequestConfig +} + +type remoteConfigJSON struct { + OriginRequest remoteOriginRequestJSON `json:"originRequest"` + Ingress []remoteIngressRuleJSON `json:"ingress"` + WarpRouting remoteWarpRoutingJSON `json:"warp-routing"` +} + +type remoteIngressRuleJSON struct { + Hostname string `json:"hostname,omitempty"` + Path string `json:"path,omitempty"` + Service string `json:"service"` + OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"` +} + +type remoteOriginRequestJSON struct { + ConnectTimeout int64 `json:"connectTimeout,omitempty"` + TLSTimeout int64 `json:"tlsTimeout,omitempty"` + TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"` + NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"` + KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"` + KeepAliveConnections *int `json:"keepAliveConnections,omitempty"` + HTTPHostHeader string `json:"httpHostHeader,omitempty"` + OriginServerName string `json:"originServerName,omitempty"` + MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"` + CAPool string `json:"caPool,omitempty"` + NoTLSVerify *bool `json:"noTLSVerify,omitempty"` + DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"` + BastionMode *bool `json:"bastionMode,omitempty"` + ProxyAddress string `json:"proxyAddress,omitempty"` + ProxyPort *uint `json:"proxyPort,omitempty"` + ProxyType string `json:"proxyType,omitempty"` + IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"` + HTTP2Origin *bool `json:"http2Origin,omitempty"` + Access *remoteAccessJSON `json:"access,omitempty"` +} + +type remoteAccessJSON struct { + Required bool `json:"required,omitempty"` + TeamName string `json:"teamName,omitempty"` + AudTag []string `json:"audTag,omitempty"` + Environment string `json:"environment,omitempty"` +} + +type remoteIPRuleJSON struct { + Prefix string `json:"prefix,omitempty"` + Ports []int `json:"ports,omitempty"` + Allow bool `json:"allow,omitempty"` +} + +type remoteWarpRoutingJSON struct { + ConnectTimeout int64 `json:"connectTimeout,omitempty"` + MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"` + TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"` +} + +func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) { + if len(rawRules) == 0 { + rawRules = []localIngressRule{{ + Service: "http_status:503", + OriginRequest: defaultOriginRequest, + }} + } + if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) { + return nil, E.New("the last ingress rule must be a catch-all rule") + } + + compiled := make([]compiledIngressRule, 0, len(rawRules)) + for index, rule := range rawRules { + if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil { + return nil, err + } + service, err := parseResolvedService(rule.Service, rule.OriginRequest) + if err != nil { + return nil, err + } + var pathPattern *regexp.Regexp + if rule.Path != "" { + pathPattern, err = regexp.Compile(rule.Path) + if err != nil { + return nil, E.Cause(err, "compile ingress path regex") + } + } + punycode := "" + if rule.Hostname != "" && rule.Hostname != "*" { + punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname) + if err == nil && punycodeValue != rule.Hostname { + punycode = punycodeValue + } + } + compiled = append(compiled, compiledIngressRule{ + Hostname: rule.Hostname, + PunycodeHostname: punycode, + Path: pathPattern, + Service: service, + }) + } + return compiled, nil +} + +func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) { + switch { + case rawService == "": + return ResolvedService{}, E.New("missing ingress service") + case strings.HasPrefix(rawService, "http_status:"): + statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:")) + if err != nil { + return ResolvedService{}, E.Cause(err, "parse http_status service") + } + if statusCode < 100 || statusCode > 999 { + return ResolvedService{}, E.New("invalid http_status code: ", statusCode) + } + return ResolvedService{ + Kind: ResolvedServiceStatus, + Service: rawService, + StatusCode: statusCode, + OriginRequest: originRequest, + }, nil + case rawService == "hello_world" || rawService == "hello-world": + return ResolvedService{ + Kind: ResolvedServiceHelloWorld, + Service: rawService, + OriginRequest: originRequest, + }, nil + case strings.HasPrefix(rawService, "unix:"): + return ResolvedService{ + Kind: ResolvedServiceUnix, + Service: rawService, + UnixPath: strings.TrimPrefix(rawService, "unix:"), + BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, + OriginRequest: originRequest, + }, nil + case strings.HasPrefix(rawService, "unix+tls:"): + return ResolvedService{ + Kind: ResolvedServiceUnixTLS, + Service: rawService, + UnixPath: strings.TrimPrefix(rawService, "unix+tls:"), + BaseURL: &url.URL{Scheme: "https", Host: "localhost"}, + OriginRequest: originRequest, + }, nil + } + + parsedURL, err := url.Parse(rawService) + if err != nil { + return ResolvedService{}, E.Cause(err, "parse ingress service URL") + } + if parsedURL.Scheme == "" || parsedURL.Hostname() == "" { + return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService) + } + if parsedURL.Path != "" { + return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService) + } + + switch parsedURL.Scheme { + case "http", "https", "ws", "wss": + return ResolvedService{ + Kind: ResolvedServiceHTTP, + Service: rawService, + Destination: parseServiceDestination(parsedURL), + BaseURL: parsedURL, + OriginRequest: originRequest, + }, nil + case "tcp", "ssh", "rdp", "smb": + return ResolvedService{ + Kind: ResolvedServiceStream, + Service: rawService, + Destination: parseServiceDestination(parsedURL), + BaseURL: parsedURL, + OriginRequest: originRequest, + }, nil + default: + return ResolvedService{}, E.New("unsupported ingress service scheme: ", parsedURL.Scheme) + } +} + +func parseServiceDestination(parsedURL *url.URL) M.Socksaddr { + host := parsedURL.Hostname() + port := parsedURL.Port() + if port == "" { + switch parsedURL.Scheme { + case "https", "wss": + port = "443" + case "ssh": + port = "22" + case "rdp": + port = "3389" + case "smb": + port = "445" + case "tcp": + port = "7864" + default: + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(host, port)) +} + +func validateHostname(hostname string, isLast bool) error { + if hostname == "" || hostname == "*" { + if !isLast { + return E.New("only the last ingress rule may be a catch-all rule") + } + return nil + } + if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) { + return E.New("hostname wildcard must be in the form *.example.com") + } + if stripPort(hostname) != hostname { + return E.New("ingress hostname cannot contain a port") + } + return nil +} + +func isCatchAllRule(hostname, path string) bool { + return (hostname == "" || hostname == "*") && path == "" +} + +func stripPort(hostname string) string { + if host, _, err := net.SplitHostPort(hostname); err == nil { + return host + } + return hostname +} + +func defaultOriginRequestConfig() OriginRequestConfig { + return OriginRequestConfig{ + ConnectTimeout: defaultHTTPConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveTimeout: defaultKeepAliveTimeout, + KeepAliveConnections: defaultKeepAliveConnections, + ProxyAddress: defaultProxyAddress, + } +} + +func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { + config := defaultOriginRequestConfig() + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) + } + if input.TLSTimeout != 0 { + config.TLSTimeout = time.Duration(input.TLSTimeout) + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) + } + if input.KeepAliveTimeout != 0 { + config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) + } + if input.KeepAliveConnections != 0 { + config.KeepAliveConnections = input.KeepAliveConnections + } + config.NoHappyEyeballs = input.NoHappyEyeballs + config.HTTPHostHeader = input.HTTPHostHeader + config.OriginServerName = input.OriginServerName + config.MatchSNIToHost = input.MatchSNIToHost + config.CAPool = input.CAPool + config.NoTLSVerify = input.NoTLSVerify + config.DisableChunkedEncoding = input.DisableChunkedEncoding + config.BastionMode = input.BastionMode + if input.ProxyAddress != "" { + config.ProxyAddress = input.ProxyAddress + } + if input.ProxyPort != 0 { + config.ProxyPort = input.ProxyPort + } + config.ProxyType = input.ProxyType + config.HTTP2Origin = input.HTTP2Origin + config.Access = AccessConfig{ + Required: input.Access.Required, + TeamName: input.Access.TeamName, + AudTag: append([]string(nil), input.Access.AudTag...), + Environment: input.Access.Environment, + } + for _, rule := range input.IPRules { + config.IPRules = append(config.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + return config +} + +func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { + result := base + if override.ConnectTimeout != 0 { + result.ConnectTimeout = time.Duration(override.ConnectTimeout) + } + if override.TLSTimeout != 0 { + result.TLSTimeout = time.Duration(override.TLSTimeout) + } + if override.TCPKeepAlive != 0 { + result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) + } + if override.KeepAliveTimeout != 0 { + result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) + } + if override.KeepAliveConnections != 0 { + result.KeepAliveConnections = override.KeepAliveConnections + } + result.NoHappyEyeballs = override.NoHappyEyeballs + if override.HTTPHostHeader != "" { + result.HTTPHostHeader = override.HTTPHostHeader + } + if override.OriginServerName != "" { + result.OriginServerName = override.OriginServerName + } + result.MatchSNIToHost = override.MatchSNIToHost + if override.CAPool != "" { + result.CAPool = override.CAPool + } + result.NoTLSVerify = override.NoTLSVerify + result.DisableChunkedEncoding = override.DisableChunkedEncoding + result.BastionMode = override.BastionMode + if override.ProxyAddress != "" { + result.ProxyAddress = override.ProxyAddress + } + if override.ProxyPort != 0 { + result.ProxyPort = override.ProxyPort + } + if override.ProxyType != "" { + result.ProxyType = override.ProxyType + } + if len(override.IPRules) > 0 { + result.IPRules = nil + for _, rule := range override.IPRules { + result.IPRules = append(result.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + } + result.HTTP2Origin = override.HTTP2Origin + if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" { + result.Access = AccessConfig{ + Required: override.Access.Required, + TeamName: override.Access.TeamName, + AudTag: append([]string(nil), override.Access.AudTag...), + Environment: override.Access.Environment, + } + } + return result +} + +func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig { + config := defaultOriginRequestConfig() + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second + } + if input.TLSTimeout != 0 { + config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second + } + if input.KeepAliveTimeout != 0 { + config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second + } + if input.KeepAliveConnections != nil { + config.KeepAliveConnections = *input.KeepAliveConnections + } + if input.NoHappyEyeballs != nil { + config.NoHappyEyeballs = *input.NoHappyEyeballs + } + config.HTTPHostHeader = input.HTTPHostHeader + config.OriginServerName = input.OriginServerName + if input.MatchSNIToHost != nil { + config.MatchSNIToHost = *input.MatchSNIToHost + } + config.CAPool = input.CAPool + if input.NoTLSVerify != nil { + config.NoTLSVerify = *input.NoTLSVerify + } + if input.DisableChunkedEncoding != nil { + config.DisableChunkedEncoding = *input.DisableChunkedEncoding + } + if input.BastionMode != nil { + config.BastionMode = *input.BastionMode + } + if input.ProxyAddress != "" { + config.ProxyAddress = input.ProxyAddress + } + if input.ProxyPort != nil { + config.ProxyPort = *input.ProxyPort + } + config.ProxyType = input.ProxyType + if input.HTTP2Origin != nil { + config.HTTP2Origin = *input.HTTP2Origin + } + if input.Access != nil { + config.Access = AccessConfig{ + Required: input.Access.Required, + TeamName: input.Access.TeamName, + AudTag: append([]string(nil), input.Access.AudTag...), + Environment: input.Access.Environment, + } + } + for _, rule := range input.IPRules { + config.IPRules = append(config.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + return config +} + +func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig { + result := base + if override.ConnectTimeout != 0 { + result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second + } + if override.TLSTimeout != 0 { + result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second + } + if override.TCPKeepAlive != 0 { + result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second + } + if override.NoHappyEyeballs != nil { + result.NoHappyEyeballs = *override.NoHappyEyeballs + } + if override.KeepAliveTimeout != 0 { + result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second + } + if override.KeepAliveConnections != nil { + result.KeepAliveConnections = *override.KeepAliveConnections + } + if override.HTTPHostHeader != "" { + result.HTTPHostHeader = override.HTTPHostHeader + } + if override.OriginServerName != "" { + result.OriginServerName = override.OriginServerName + } + if override.MatchSNIToHost != nil { + result.MatchSNIToHost = *override.MatchSNIToHost + } + if override.CAPool != "" { + result.CAPool = override.CAPool + } + if override.NoTLSVerify != nil { + result.NoTLSVerify = *override.NoTLSVerify + } + if override.DisableChunkedEncoding != nil { + result.DisableChunkedEncoding = *override.DisableChunkedEncoding + } + if override.BastionMode != nil { + result.BastionMode = *override.BastionMode + } + if override.ProxyAddress != "" { + result.ProxyAddress = override.ProxyAddress + } + if override.ProxyPort != nil { + result.ProxyPort = *override.ProxyPort + } + if override.ProxyType != "" { + result.ProxyType = override.ProxyType + } + if len(override.IPRules) > 0 { + result.IPRules = nil + for _, rule := range override.IPRules { + result.IPRules = append(result.IPRules, IPRule{ + Prefix: rule.Prefix, + Ports: append([]int(nil), rule.Ports...), + Allow: rule.Allow, + }) + } + } + if override.HTTP2Origin != nil { + result.HTTP2Origin = *override.HTTP2Origin + } + if override.Access != nil { + result.Access = AccessConfig{ + Required: override.Access.Required, + TeamName: override.Access.TeamName, + AudTag: append([]string(nil), override.Access.AudTag...), + Environment: override.Access.Environment, + } + } + return result +} + +func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig { + config := WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + MaxActiveFlows: input.MaxActiveFlows, + } + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) + } + return config +} + +func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig { + config := WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + MaxActiveFlows: input.MaxActiveFlows, + } + if input.ConnectTimeout != 0 { + config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second + } + if input.TCPKeepAlive != 0 { + config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second + } + return config +} From 124379fc1d3c57fb6138e7f98cebddc488130bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 11:18:43 +0800 Subject: [PATCH 12/41] Support regional cloudflare edge selection --- protocol/cloudflare/edge_discovery.go | 21 ++++++++++++++------- protocol/cloudflare/edge_discovery_test.go | 11 ++++++++++- protocol/cloudflare/inbound.go | 12 ++++++++++-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go index 6ca9403d0d..0c08bcbf86 100644 --- a/protocol/cloudflare/edge_discovery.go +++ b/protocol/cloudflare/edge_discovery.go @@ -21,6 +21,13 @@ const ( dotTimeout = 15 * time.Second ) +func getRegionalServiceName(region string) string { + if region == "" { + return edgeSRVService + } + return region + "-" + edgeSRVService +} + // EdgeAddr represents a Cloudflare edge server address. type EdgeAddr struct { TCP *net.TCPAddr @@ -30,10 +37,10 @@ type EdgeAddr struct { // DiscoverEdge performs SRV-based edge discovery and returns addresses // partitioned into regions (typically 2). -func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) { - regions, err := lookupEdgeSRV() +func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) { + regions, err := lookupEdgeSRV(region) if err != nil { - regions, err = lookupEdgeSRVWithDoT(ctx) + regions, err = lookupEdgeSRVWithDoT(ctx, region) if err != nil { return nil, E.Cause(err, "edge discovery") } @@ -44,15 +51,15 @@ func DiscoverEdge(ctx context.Context) ([][]*EdgeAddr, error) { return regions, nil } -func lookupEdgeSRV() ([][]*EdgeAddr, error) { - _, addrs, err := net.LookupSRV(edgeSRVService, edgeSRVProto, edgeSRVName) +func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) { + _, addrs, err := net.LookupSRV(getRegionalServiceName(region), edgeSRVProto, edgeSRVName) if err != nil { return nil, err } return resolveSRVRecords(addrs) } -func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) { +func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { @@ -66,7 +73,7 @@ func lookupEdgeSRVWithDoT(ctx context.Context) ([][]*EdgeAddr, error) { } lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout) defer cancel() - _, addrs, err := resolver.LookupSRV(lookupCtx, edgeSRVService, edgeSRVProto, edgeSRVName) + _, addrs, err := resolver.LookupSRV(lookupCtx, getRegionalServiceName(region), edgeSRVProto, edgeSRVName) if err != nil { return nil, err } diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index 6d602cfa60..c282009d0d 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -9,7 +9,7 @@ import ( ) func TestDiscoverEdge(t *testing.T) { - regions, err := DiscoverEdge(context.Background()) + regions, err := DiscoverEdge(context.Background(), "") if err != nil { t.Fatal("DiscoverEdge: ", err) } @@ -86,3 +86,12 @@ func TestFilterByIPVersion(t *testing.T) { } }) } + +func TestGetRegionalServiceName(t *testing.T) { + if got := getRegionalServiceName(""); got != edgeSRVService { + t.Fatalf("expected global service %s, got %s", edgeSRVService, got) + } + if got := getRegionalServiceName("us"); got != "us-"+edgeSRVService { + t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got) + } +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 0fcb2ac65b..3b349d3e32 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -94,6 +94,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return nil, E.Cause(err, "build cloudflare tunnel runtime config") } + region := options.Region + if region != "" && credentials.Endpoint != "" { + return nil, E.New("region cannot be specified when credentials already include an endpoint") + } + if region == "" { + region = credentials.Endpoint + } + inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ @@ -106,7 +114,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo connectorID: uuid.New(), haConnections: haConnections, protocol: protocol, - region: options.Region, + region: region, edgeIPVersion: edgeIPVersion, datagramVersion: datagramVersion, gracePeriod: gracePeriod, @@ -123,7 +131,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections") - regions, err := DiscoverEdge(i.ctx) + regions, err := DiscoverEdge(i.ctx, i.region) if err != nil { return E.Cause(err, "discover edge") } From b3cad021b8e9af38a54c9abb67a05d7d2fcf1c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 11:20:26 +0800 Subject: [PATCH 13/41] Apply origin request SNI selection --- protocol/cloudflare/dispatch.go | 33 +++++++++++++++------- protocol/cloudflare/origin_request_test.go | 33 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 protocol/cloudflare/origin_request_test.go diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 9263e16c30..761e4a978f 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -231,7 +231,7 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -240,7 +240,7 @@ func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWrite metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest) + transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -249,7 +249,7 @@ func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWrit metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest) - transport, cleanup, err := i.newDirectOriginTransport(service) + transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost]) if err != nil { i.logger.ErrorContext(ctx, "build direct origin transport: ", err) respWriter.WriteResponse(err, nil) @@ -263,7 +263,7 @@ func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.Rea metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest) - transport, cleanup, err := i.newDirectOriginTransport(service) + transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost]) if err != nil { i.logger.ErrorContext(ctx, "build direct origin transport: ", err) respWriter.WriteResponse(err, nil) @@ -329,7 +329,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, } } -func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig) (*http.Transport, func()) { +func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) { input, output := pipe.Pipe() done := make(chan struct{}) go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { @@ -344,7 +344,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter IdleConnTimeout: originRequest.KeepAliveTimeout, MaxIdleConns: originRequest.KeepAliveConnections, MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(originRequest), + TLSClientConfig: buildOriginTLSConfig(originRequest, requestHost), DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil }, @@ -358,7 +358,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter } } -func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Transport, func(), error) { +func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { transport := &http.Transport{ DisableCompression: true, ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, @@ -366,7 +366,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, MaxIdleConns: service.OriginRequest.KeepAliveConnections, MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(service.OriginRequest), + TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost), } switch service.Kind { case ResolvedServiceUnix, ResolvedServiceUnixTLS: @@ -386,10 +386,10 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService) (*http.Trans return transport, func() {}, nil } -func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { +func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) *tls.Config { tlsConfig := &tls.Config{ InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec - ServerName: originRequest.OriginServerName, + ServerName: originTLSServerName(originRequest, requestHost), } if originRequest.CAPool == "" { return tlsConfig @@ -405,6 +405,19 @@ func buildOriginTLSConfig(originRequest OriginRequestConfig) *tls.Config { return tlsConfig } +func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string { + if originRequest.OriginServerName != "" { + return originRequest.OriginServerName + } + if !originRequest.MatchSNIToHost { + return "" + } + if host, _, err := net.SplitHostPort(requestHost); err == nil { + return host + } + return requestHost +} + func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request { request = request.Clone(request.Context()) if originRequest.HTTPHostHeader != "" { diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go new file mode 100644 index 0000000000..b56a0a52f2 --- /dev/null +++ b/protocol/cloudflare/origin_request_test.go @@ -0,0 +1,33 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import "testing" + +func TestOriginTLSServerName(t *testing.T) { + t.Run("origin server name overrides host", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + OriginServerName: "origin.example.com", + MatchSNIToHost: true, + }, "request.example.com") + if serverName != "origin.example.com" { + t.Fatalf("expected origin.example.com, got %s", serverName) + } + }) + + t.Run("match sni to host strips port", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + MatchSNIToHost: true, + }, "request.example.com:443") + if serverName != "request.example.com" { + t.Fatalf("expected request.example.com, got %s", serverName) + } + }) + + t.Run("disabled match keeps empty server name", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com") + if serverName != "" { + t.Fatalf("expected empty server name, got %s", serverName) + } + }) +} From 71c7a585efa5c48dbb640d556d2ae962c0b8c5b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 11:47:06 +0800 Subject: [PATCH 14/41] Route cloudflare tunnel ICMP through sing-box router --- protocol/cloudflare/datagram_v2.go | 11 +- protocol/cloudflare/datagram_v3.go | 7 +- protocol/cloudflare/helpers_test.go | 30 ++- protocol/cloudflare/icmp.go | 356 ++++++++++++++++++++++++++++ protocol/cloudflare/icmp_test.go | 242 +++++++++++++++++++ protocol/cloudflare/inbound.go | 2 +- 6 files changed, 641 insertions(+), 7 deletions(-) create mode 100644 protocol/cloudflare/icmp.go create mode 100644 protocol/cloudflare/icmp_test.go diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 2071d86e96..49fbb4d8f6 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -42,6 +42,7 @@ type DatagramV2Muxer struct { inbound *Inbound logger log.ContextLogger sender DatagramSender + icmp *ICMPBridge sessionAccess sync.RWMutex sessions map[uuid.UUID]*udpSession @@ -53,6 +54,7 @@ func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.Cont inbound: inbound, logger: logger, sender: sender, + icmp: NewICMPBridge(inbound, sender, icmpWireV2), sessions: make(map[uuid.UUID]*udpSession), } } @@ -70,10 +72,13 @@ func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) { case DatagramV2TypeUDP: m.handleUDPDatagram(ctx, payload) case DatagramV2TypeIP: - // TODO: ICMP handling - m.logger.Debug("received V2 IP datagram (ICMP not yet implemented)") + if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil { + m.logger.Debug("drop V2 ICMP datagram: ", err) + } case DatagramV2TypeIPWithTrace: - m.logger.Debug("received V2 IP+trace datagram") + if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil { + m.logger.Debug("drop V2 traced ICMP datagram: ", err) + } case DatagramV2TypeTracingSpan: // Tracing spans, ignore } diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index a47f1c2ed5..1f8f1aacd6 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -61,6 +61,7 @@ type DatagramV3Muxer struct { inbound *Inbound logger log.ContextLogger sender DatagramSender + icmp *ICMPBridge sessionAccess sync.RWMutex sessions map[RequestID]*v3Session @@ -72,6 +73,7 @@ func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.Cont inbound: inbound, logger: logger, sender: sender, + icmp: NewICMPBridge(inbound, sender, icmpWireV3), sessions: make(map[RequestID]*v3Session), } } @@ -91,8 +93,9 @@ func (m *DatagramV3Muxer) HandleDatagram(ctx context.Context, data []byte) { case DatagramV3TypePayload: m.handlePayload(payload) case DatagramV3TypeICMP: - // TODO: ICMP handling - m.logger.Debug("received V3 ICMP datagram (not yet implemented)") + if err := m.icmp.HandleV3(ctx, payload); err != nil { + m.logger.Debug("drop V3 ICMP datagram: ", err) + } case DatagramV3TypeRegistrationResponse: // Unexpected - we never send registrations m.logger.Debug("received unexpected V3 registration response") diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 8920b7705b..f06a5fa2a8 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -21,6 +21,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" "github.com/google/uuid" @@ -80,7 +81,13 @@ func startOriginServer(t *testing.T) { }) } -type testRouter struct{} +type testRouter struct { + preMatch func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) +} + +func (r *testRouter) Start(stage adapter.StartStage) error { return nil } + +func (r *testRouter) Close() error { return nil } func (r *testRouter) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { destination := metadata.Destination.String() @@ -130,6 +137,27 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC onClose(nil) } +func (r *testRouter) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + if r.preMatch != nil { + return r.preMatch(metadata, routeContext, timeout, supportBypass) + } + return nil, nil +} + +func (r *testRouter) RuleSet(tag string) (adapter.RuleSet, bool) { return nil, false } + +func (r *testRouter) Rules() []adapter.Rule { return nil } + +func (r *testRouter) NeedFindProcess() bool { return false } + +func (r *testRouter) NeedFindNeighbor() bool { return false } + +func (r *testRouter) NeighborResolver() adapter.NeighborResolver { return nil } + +func (r *testRouter) AppendTracker(tracker adapter.ConnectionTracker) {} + +func (r *testRouter) ResetNetwork() {} + func newTestInbound(t *testing.T, token string, protocol string, haConnections int) *Inbound { t.Helper() credentials, err := parseToken(token) diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go new file mode 100644 index 0000000000..8e000c0db8 --- /dev/null +++ b/protocol/cloudflare/icmp.go @@ -0,0 +1,356 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/binary" + "net/netip" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const ( + icmpFlowTimeout = 30 * time.Second + icmpTraceIdentityLength = 16 + 8 + 1 +) + +type ICMPTraceContext struct { + Traced bool + Identity []byte +} + +type ICMPFlowKey struct { + IPVersion uint8 + SourceIP netip.Addr + Destination netip.Addr +} + +type ICMPRequestKey struct { + Flow ICMPFlowKey + Identifier uint16 + Sequence uint16 +} + +type ICMPPacketInfo struct { + IPVersion uint8 + Protocol uint8 + SourceIP netip.Addr + Destination netip.Addr + ICMPType uint8 + ICMPCode uint8 + Identifier uint16 + Sequence uint16 + RawPacket []byte +} + +func (i ICMPPacketInfo) FlowKey() ICMPFlowKey { + return ICMPFlowKey{ + IPVersion: i.IPVersion, + SourceIP: i.SourceIP, + Destination: i.Destination, + } +} + +func (i ICMPPacketInfo) RequestKey() ICMPRequestKey { + return ICMPRequestKey{ + Flow: i.FlowKey(), + Identifier: i.Identifier, + Sequence: i.Sequence, + } +} + +func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey { + return ICMPRequestKey{ + Flow: ICMPFlowKey{ + IPVersion: i.IPVersion, + SourceIP: i.Destination, + Destination: i.SourceIP, + }, + Identifier: i.Identifier, + Sequence: i.Sequence, + } +} + +func (i ICMPPacketInfo) IsEchoRequest() bool { + switch i.IPVersion { + case 4: + return i.ICMPType == 8 && i.ICMPCode == 0 + case 6: + return i.ICMPType == 128 && i.ICMPCode == 0 + default: + return false + } +} + +func (i ICMPPacketInfo) IsEchoReply() bool { + switch i.IPVersion { + case 4: + return i.ICMPType == 0 && i.ICMPCode == 0 + case 6: + return i.ICMPType == 129 && i.ICMPCode == 0 + default: + return false + } +} + +type icmpWireVersion uint8 + +const ( + icmpWireV2 icmpWireVersion = iota + 1 + icmpWireV3 +) + +type icmpFlowState struct { + writer *ICMPReplyWriter +} + +type ICMPReplyWriter struct { + sender DatagramSender + wireVersion icmpWireVersion + + access sync.Mutex + traces map[ICMPRequestKey]ICMPTraceContext +} + +func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter { + return &ICMPReplyWriter{ + sender: sender, + wireVersion: wireVersion, + traces: make(map[ICMPRequestKey]ICMPTraceContext), + } +} + +func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceContext ICMPTraceContext) { + if !traceContext.Traced { + return + } + w.access.Lock() + w.traces[packetInfo.RequestKey()] = traceContext + w.access.Unlock() +} + +func (w *ICMPReplyWriter) WritePacket(packet []byte) error { + packetInfo, err := ParseICMPPacket(packet) + if err != nil { + return err + } + if !packetInfo.IsEchoReply() { + return nil + } + + requestKey := packetInfo.ReplyRequestKey() + w.access.Lock() + traceContext, loaded := w.traces[requestKey] + if loaded { + delete(w.traces, requestKey) + } + w.access.Unlock() + + var datagram []byte + switch w.wireVersion { + case icmpWireV2: + datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext) + case icmpWireV3: + datagram = encodeV3ICMPDatagram(packetInfo.RawPacket) + default: + err = E.New("unsupported icmp wire version: ", w.wireVersion) + } + if err != nil { + return err + } + return w.sender.SendDatagram(datagram) +} + +type ICMPBridge struct { + inbound *Inbound + sender DatagramSender + wireVersion icmpWireVersion + routeMapping *tun.DirectRouteMapping + + flowAccess sync.Mutex + flows map[ICMPFlowKey]*icmpFlowState +} + +func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge { + return &ICMPBridge{ + inbound: inbound, + sender: sender, + wireVersion: wireVersion, + routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout), + flows: make(map[ICMPFlowKey]*icmpFlowState), + } +} + +func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error { + traceContext := ICMPTraceContext{} + switch datagramType { + case DatagramV2TypeIP: + case DatagramV2TypeIPWithTrace: + if len(payload) < icmpTraceIdentityLength { + return E.New("icmp trace payload is too short") + } + traceContext.Traced = true + traceContext.Identity = append([]byte(nil), payload[len(payload)-icmpTraceIdentityLength:]...) + payload = payload[:len(payload)-icmpTraceIdentityLength] + default: + return E.New("unsupported v2 icmp datagram type: ", datagramType) + } + return b.handlePacket(ctx, payload, traceContext) +} + +func (b *ICMPBridge) HandleV3(ctx context.Context, payload []byte) error { + return b.handlePacket(ctx, payload, ICMPTraceContext{}) +} + +func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceContext ICMPTraceContext) error { + packetInfo, err := ParseICMPPacket(payload) + if err != nil { + return err + } + if !packetInfo.IsEchoRequest() { + return nil + } + + state := b.getFlowState(packetInfo.FlowKey()) + if traceContext.Traced { + state.writer.RegisterRequestTrace(packetInfo, traceContext) + } + + action, err := b.routeMapping.Lookup(tun.DirectRouteSession{ + Source: packetInfo.SourceIP, + Destination: packetInfo.Destination, + }, func(timeout time.Duration) (tun.DirectRouteDestination, error) { + metadata := adapter.InboundContext{ + Inbound: b.inbound.Tag(), + InboundType: b.inbound.Type(), + IPVersion: packetInfo.IPVersion, + Network: N.NetworkICMP, + Source: M.SocksaddrFrom(packetInfo.SourceIP, 0), + Destination: M.SocksaddrFrom(packetInfo.Destination, 0), + OriginDestination: M.SocksaddrFrom(packetInfo.Destination, 0), + } + return b.inbound.router.PreMatch(metadata, state.writer, timeout, false) + }) + if err != nil { + return nil + } + return action.WritePacket(buf.As(packetInfo.RawPacket).ToOwned()) +} + +func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState { + b.flowAccess.Lock() + defer b.flowAccess.Unlock() + state, loaded := b.flows[key] + if loaded { + return state + } + state = &icmpFlowState{ + writer: NewICMPReplyWriter(b.sender, b.wireVersion), + } + b.flows[key] = state + return state +} + +func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) { + if len(packet) < 1 { + return ICMPPacketInfo{}, E.New("empty IP packet") + } + version := packet[0] >> 4 + switch version { + case 4: + return parseIPv4ICMPPacket(packet) + case 6: + return parseIPv6ICMPPacket(packet) + default: + return ICMPPacketInfo{}, E.New("unsupported IP version: ", version) + } +} + +func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) { + if len(packet) < 20 { + return ICMPPacketInfo{}, E.New("IPv4 packet too short") + } + headerLen := int(packet[0]&0x0F) * 4 + if headerLen < 20 || len(packet) < headerLen+8 { + return ICMPPacketInfo{}, E.New("invalid IPv4 header length") + } + if packet[9] != 1 { + return ICMPPacketInfo{}, E.New("IPv4 packet is not ICMP") + } + sourceIP, ok := netip.AddrFromSlice(packet[12:16]) + if !ok { + return ICMPPacketInfo{}, E.New("invalid IPv4 source address") + } + destinationIP, ok := netip.AddrFromSlice(packet[16:20]) + if !ok { + return ICMPPacketInfo{}, E.New("invalid IPv4 destination address") + } + return ICMPPacketInfo{ + IPVersion: 4, + Protocol: 1, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[headerLen], + ICMPCode: packet[headerLen+1], + Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), + Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), + RawPacket: append([]byte(nil), packet...), + }, nil +} + +func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) { + if len(packet) < 48 { + return ICMPPacketInfo{}, E.New("IPv6 packet too short") + } + if packet[6] != 58 { + return ICMPPacketInfo{}, E.New("IPv6 packet is not ICMP") + } + sourceIP, ok := netip.AddrFromSlice(packet[8:24]) + if !ok { + return ICMPPacketInfo{}, E.New("invalid IPv6 source address") + } + destinationIP, ok := netip.AddrFromSlice(packet[24:40]) + if !ok { + return ICMPPacketInfo{}, E.New("invalid IPv6 destination address") + } + return ICMPPacketInfo{ + IPVersion: 6, + Protocol: 58, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[40], + ICMPCode: packet[41], + Identifier: binary.BigEndian.Uint16(packet[44:46]), + Sequence: binary.BigEndian.Uint16(packet[46:48]), + RawPacket: append([]byte(nil), packet...), + }, nil +} + +func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) { + if traceContext.Traced { + data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1) + data = append(data, packet...) + data = append(data, traceContext.Identity...) + data = append(data, byte(DatagramV2TypeIPWithTrace)) + return data, nil + } + data := make([]byte, 0, len(packet)+1) + data = append(data, packet...) + data = append(data, byte(DatagramV2TypeIP)) + return data, nil +} + +func encodeV3ICMPDatagram(packet []byte) []byte { + data := make([]byte, 0, len(packet)+1) + data = append(data, byte(DatagramV3TypeICMP)) + data = append(data, packet...) + return data +} diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go new file mode 100644 index 0000000000..6f985050f2 --- /dev/null +++ b/protocol/cloudflare/icmp_test.go @@ -0,0 +1,242 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "bytes" + "context" + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +type captureDatagramSender struct { + sent [][]byte +} + +func (s *captureDatagramSender) SendDatagram(data []byte) error { + s.sent = append(s.sent, append([]byte(nil), data...)) + return nil +} + +type fakeDirectRouteDestination struct { + routeContext tun.DirectRouteContext + packets [][]byte + reply func(packet []byte) []byte + closed bool +} + +func (d *fakeDirectRouteDestination) WritePacket(packet *buf.Buffer) error { + data := append([]byte(nil), packet.Bytes()...) + packet.Release() + d.packets = append(d.packets, data) + if d.reply != nil { + reply := d.reply(data) + if reply != nil { + return d.routeContext.WritePacket(reply) + } + } + return nil +} + +func (d *fakeDirectRouteDestination) Close() error { + d.closed = true + return nil +} + +func (d *fakeDirectRouteDestination) IsClosed() bool { + return d.closed +} + +func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) { + var ( + preMatchCalls int + captured adapter.InboundContext + destination *fakeDirectRouteDestination + ) + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + captured = metadata + destination = &fakeDirectRouteDestination{routeContext: routeContext} + return destination, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: router, + } + sender := &captureDatagramSender{} + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) + + source := netip.MustParseAddr("198.18.0.2") + target := netip.MustParseAddr("1.1.1.1") + packet1 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 1) + packet2 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 2) + + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet1); err != nil { + t.Fatal(err) + } + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet2); err != nil { + t.Fatal(err) + } + if preMatchCalls != 1 { + t.Fatalf("expected one direct-route lookup, got %d", preMatchCalls) + } + if captured.Network != N.NetworkICMP { + t.Fatalf("expected NetworkICMP, got %s", captured.Network) + } + if captured.Source.Addr != source || captured.Destination.Addr != target { + t.Fatalf("unexpected metadata source/destination: %#v", captured) + } + if len(destination.packets) != 2 { + t.Fatalf("expected two packets written, got %d", len(destination.packets)) + } + if len(sender.sent) != 0 { + t.Fatalf("expected no reply datagrams, got %d", len(sender.sent)) + } +} + +func TestICMPBridgeHandleV2TracedReply(t *testing.T) { + traceIdentity := bytes.Repeat([]byte{0x7a}, icmpTraceIdentityLength) + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + return &fakeDirectRouteDestination{ + routeContext: routeContext, + reply: buildEchoReply, + }, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) + + request := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7) + request = append(request, traceIdentity...) + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, request); err != nil { + t.Fatal(err) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one reply datagram, got %d", len(sender.sent)) + } + reply := sender.sent[0] + if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) { + t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1]) + } + gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1] + if !bytes.Equal(gotIdentity, traceIdentity) { + t.Fatalf("unexpected trace identity: %x", gotIdentity) + } +} + +func TestICMPBridgeHandleV3Reply(t *testing.T) { + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + return &fakeDirectRouteDestination{ + routeContext: routeContext, + reply: buildEchoReply, + }, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3) + + request := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), 128, 0, 3, 5) + if err := bridge.HandleV3(context.Background(), request); err != nil { + t.Fatal(err) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one reply datagram, got %d", len(sender.sent)) + } + reply := sender.sent[0] + if reply[0] != byte(DatagramV3TypeICMP) { + t.Fatalf("expected v3 ICMP datagram, got %d", reply[0]) + } +} + +func TestICMPBridgeDropsNonEcho(t *testing.T) { + var preMatchCalls int + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + return nil, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: router, + } + sender := &captureDatagramSender{} + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) + + packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 3, 0, 1, 1) + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil { + t.Fatal(err) + } + if preMatchCalls != 0 { + t.Fatalf("expected no route lookup, got %d", preMatchCalls) + } + if len(sender.sent) != 0 { + t.Fatalf("expected no sender datagrams, got %d", len(sender.sent)) + } +} + +func buildEchoReply(packet []byte) []byte { + info, err := ParseICMPPacket(packet) + if err != nil { + panic(err) + } + switch info.IPVersion { + case 4: + return buildIPv4ICMPPacket(info.Destination, info.SourceIP, 0, 0, info.Identifier, info.Sequence) + case 6: + return buildIPv6ICMPPacket(info.Destination, info.SourceIP, 129, 0, info.Identifier, info.Sequence) + default: + panic("unsupported version") + } +} + +func buildIPv4ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte { + packet := make([]byte, 28) + packet[0] = 0x45 + binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet))) + packet[8] = 64 + packet[9] = 1 + copy(packet[12:16], source.AsSlice()) + copy(packet[16:20], destination.AsSlice()) + packet[20] = icmpType + packet[21] = icmpCode + binary.BigEndian.PutUint16(packet[24:26], identifier) + binary.BigEndian.PutUint16(packet[26:28], sequence) + return packet +} + +func buildIPv6ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte { + packet := make([]byte, 48) + packet[0] = 0x60 + binary.BigEndian.PutUint16(packet[4:6], 8) + packet[6] = 58 + packet[7] = 64 + copy(packet[8:24], source.AsSlice()) + copy(packet[24:40], destination.AsSlice()) + packet[40] = icmpType + packet[41] = icmpCode + binary.BigEndian.PutUint16(packet[44:46], identifier) + binary.BigEndian.PutUint16(packet[46:48], sequence) + return packet +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 3b349d3e32..4d38c82511 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -33,7 +33,7 @@ type Inbound struct { inbound.Adapter ctx context.Context cancel context.CancelFunc - router adapter.ConnectionRouterEx + router adapter.Router logger log.ContextLogger credentials Credentials connectorID uuid.UUID From 4579ca9ecc98949cd3b65f9d7c429e2acdeb583f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:23:08 +0800 Subject: [PATCH 15/41] Add cloudflare tunnel bastion and socks special services --- protocol/cloudflare/dispatch.go | 16 ++ protocol/cloudflare/runtime_config.go | 21 ++ protocol/cloudflare/special_service.go | 233 +++++++++++++++++++ protocol/cloudflare/special_service_test.go | 235 ++++++++++++++++++++ 4 files changed, 505 insertions(+) create mode 100644 protocol/cloudflare/special_service.go create mode 100644 protocol/cloudflare/special_service_test.go diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 761e4a978f..29671f0db2 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -220,6 +220,22 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos } else { i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service) } + case ResolvedServiceBastion: + if request.Type != ConnectionTypeWebsocket { + err := E.New("bastion service requires websocket request type") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + i.handleBastionStream(ctx, stream, respWriter, request, metadata) + case ResolvedServiceSocksProxy: + if request.Type != ConnectionTypeWebsocket { + err := E.New("socks-proxy service requires websocket request type") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata) default: err := E.New("unsupported service kind for HTTP/WebSocket request") i.logger.ErrorContext(ctx, err) diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 99fe73c818..276e99d412 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -39,6 +39,8 @@ const ( ResolvedServiceHelloWorld ResolvedServiceUnix ResolvedServiceUnixTLS + ResolvedServiceBastion + ResolvedServiceSocksProxy ) type ResolvedService struct { @@ -392,6 +394,13 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) { switch { case rawService == "": + if originRequest.BastionMode { + return ResolvedService{ + Kind: ResolvedServiceBastion, + Service: "bastion", + OriginRequest: originRequest, + }, nil + } return ResolvedService{}, E.New("missing ingress service") case strings.HasPrefix(rawService, "http_status:"): statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:")) @@ -413,6 +422,18 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) Service: rawService, OriginRequest: originRequest, }, nil + case rawService == "bastion": + return ResolvedService{ + Kind: ResolvedServiceBastion, + Service: rawService, + OriginRequest: originRequest, + }, nil + case rawService == "socks-proxy": + return ResolvedService{ + Kind: ResolvedServiceSocksProxy, + Service: rawService, + OriginRequest: originRequest, + }, nil case strings.HasPrefix(rawService, "unix:"): return ResolvedService{ Kind: ResolvedServiceUnix, diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go new file mode 100644 index 0000000000..61f6cb5749 --- /dev/null +++ b/protocol/cloudflare/special_service.go @@ -0,0 +1,233 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "crypto/sha1" + "encoding/base64" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/transport/v2raywebsocket" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/ws" +) + +var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + destination, err := resolveBastionDestination(request) + if err != nil { + respWriter.WriteResponse(err, nil) + return + } + + targetConn, cleanup, err := i.dialRouterTCP(ctx, M.ParseSocksaddr(destination)) + if err != nil { + respWriter.WriteResponse(err, nil) + return + } + defer cleanup() + + err = respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request))) + if err != nil { + i.logger.ErrorContext(ctx, "write bastion websocket response: ", err) + return + } + + wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) + defer wsConn.Close() + _ = bufio.CopyConn(ctx, wsConn, targetConn) +} + +func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { + err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request))) + if err != nil { + i.logger.ErrorContext(ctx, "write socks-proxy websocket response: ", err) + return + } + + wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) + defer wsConn.Close() + if err := i.serveSocksProxy(ctx, wsConn); err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "socks-proxy stream closed: ", err) + } +} + +func resolveBastionDestination(request *ConnectRequest) (string, error) { + headerValue := requestHeaderValue(request, "Cf-Access-Jump-Destination") + if headerValue == "" { + return "", E.New("missing Cf-Access-Jump-Destination header") + } + if parsed, err := url.Parse(headerValue); err == nil && parsed.Host != "" { + headerValue = parsed.Host + } + return strings.SplitN(headerValue, "/", 2)[0], nil +} + +func websocketResponseHeaders(request *ConnectRequest) http.Header { + header := http.Header{} + header.Set("Connection", "Upgrade") + header.Set("Upgrade", "websocket") + secKey := requestHeaderValue(request, "Sec-WebSocket-Key") + if secKey != "" { + sum := sha1.Sum(append([]byte(secKey), wsAcceptGUID...)) + header.Set("Sec-WebSocket-Accept", base64.StdEncoding.EncodeToString(sum[:])) + } + return header +} + +func requestHeaderValue(request *ConnectRequest, headerName string) string { + for _, entry := range request.Metadata { + if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") { + continue + } + name := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":") + if strings.EqualFold(name, headerName) { + return entry.Val + } + } + return "" +} + +func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) { + input, output := pipe.Pipe() + done := make(chan struct{}) + metadata := adapter.InboundContext{ + Inbound: i.Tag(), + InboundType: i.Type(), + Network: N.NetworkTCP, + Destination: destination, + } + go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { + common.Close(input, output) + close(done) + })) + return input, func() { + common.Close(input, output) + select { + case <-done: + case <-time.After(time.Second): + } + }, nil +} + +func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn) error { + version := make([]byte, 1) + if _, err := io.ReadFull(conn, version); err != nil { + return err + } + if version[0] != 5 { + return E.New("unsupported SOCKS version: ", version[0]) + } + + methodCount := make([]byte, 1) + if _, err := io.ReadFull(conn, methodCount); err != nil { + return err + } + methods := make([]byte, int(methodCount[0])) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + if _, err := conn.Write([]byte{5, 0}); err != nil { + return err + } + + requestHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, requestHeader); err != nil { + return err + } + if requestHeader[0] != 5 { + return E.New("unsupported SOCKS request version: ", requestHeader[0]) + } + if requestHeader[1] != 1 { + _, _ = conn.Write([]byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}) + return E.New("unsupported SOCKS command: ", requestHeader[1]) + } + + destination, err := readSocksDestination(conn, requestHeader[3]) + if err != nil { + return err + } + targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) + if err != nil { + _, _ = conn.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) + return err + } + defer cleanup() + + if _, err := conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + return err + } + return bufio.CopyConn(ctx, conn, targetConn) +} + +func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) { + switch addressType { + case 1: + addr := make([]byte, 4) + if _, err := io.ReadFull(conn, addr); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + ipAddr, ok := netip.AddrFromSlice(addr) + if !ok { + return M.Socksaddr{}, E.New("invalid IPv4 SOCKS destination") + } + return M.SocksaddrFrom(ipAddr, port), nil + case 3: + length := make([]byte, 1) + if _, err := io.ReadFull(conn, length); err != nil { + return M.Socksaddr{}, err + } + host := make([]byte, int(length[0])) + if _, err := io.ReadFull(conn, host); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + return M.ParseSocksaddr(net.JoinHostPort(string(host), strconv.Itoa(int(port)))), nil + case 4: + addr := make([]byte, 16) + if _, err := io.ReadFull(conn, addr); err != nil { + return M.Socksaddr{}, err + } + port, err := readSocksPort(conn) + if err != nil { + return M.Socksaddr{}, err + } + ipAddr, ok := netip.AddrFromSlice(addr) + if !ok { + return M.Socksaddr{}, E.New("invalid IPv6 SOCKS destination") + } + return M.SocksaddrFrom(ipAddr, port), nil + default: + return M.Socksaddr{}, E.New("unsupported SOCKS address type: ", addressType) + } +} + +func readSocksPort(conn net.Conn) (uint16, error) { + port := make([]byte, 2) + if _, err := io.ReadFull(conn, port); err != nil { + return 0, err + } + return uint16(port[0])<<8 | uint16(port[1]), nil +} diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go new file mode 100644 index 0000000000..5c29d40fa2 --- /dev/null +++ b/protocol/cloudflare/special_service_test.go @@ -0,0 +1,235 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/ws" + "github.com/sagernet/ws/wsutil" +) + +type fakeConnectResponseWriter struct { + status int + headers http.Header + err error + done chan struct{} +} + +func (w *fakeConnectResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + w.err = responseError + w.headers = make(http.Header) + for _, entry := range metadata { + switch { + case entry.Key == metadataHTTPStatus: + status, _ := strconv.Atoi(entry.Val) + w.status = status + case len(entry.Key) > len(metadataHTTPHeader)+1 && entry.Key[:len(metadataHTTPHeader)+1] == metadataHTTPHeader+":": + w.headers.Add(entry.Key[len(metadataHTTPHeader)+1:], entry.Val) + } + } + if w.done != nil { + close(w.done) + w.done = nil + } + return nil +} + +func newSpecialServiceInbound(t *testing.T) *Inbound { + t.Helper() + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal(err) + } + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal(err) + } + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + configManager: configManager, + } +} + +func TestHandleBastionStream(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(conn) + } + }() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + {Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bastion connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + if respWriter.headers.Get("Sec-WebSocket-Accept") == "" { + t.Fatal("expected websocket accept header") + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, opCode, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if opCode != ws.OpBinary { + t.Fatalf("expected binary frame, got %v", opCode) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bastion stream did not exit") + } +} + +func TestHandleSocksProxyStream(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(conn) + } + }() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for socks-proxy connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte{5, 1, 0}); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != string([]byte{5, 0}) { + t.Fatalf("unexpected auth response: %v", data) + } + + host, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + requestBytes := []byte{5, 1, 0, 1} + requestBytes = append(requestBytes, net.ParseIP(host).To4()...) + requestBytes = append(requestBytes, byte(port>>8), byte(port)) + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, requestBytes); err != nil { + t.Fatal(err) + } + data, _, err = wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if len(data) != 10 || data[1] != 0 { + t.Fatalf("unexpected connect response: %v", data) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, _, err = wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } +} From 25a94ac5b666337fe5bfc3c7a4972210c42a4436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:27:43 +0800 Subject: [PATCH 16/41] Support router-backed cloudflare stream services --- protocol/cloudflare/dispatch.go | 8 +++ protocol/cloudflare/special_service.go | 9 ++- protocol/cloudflare/special_service_test.go | 72 +++++++++++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 29671f0db2..f2b9f8d2c3 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -214,6 +214,14 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos } else { i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service) } + case ResolvedServiceStream: + if request.Type != ConnectionTypeWebsocket { + err := E.New("stream service requires websocket request type") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + i.handleStreamService(ctx, stream, respWriter, request, metadata, service.Destination) case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: if request.Type == ConnectionTypeHTTP { i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index 61f6cb5749..c7214ec9a6 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -34,8 +34,15 @@ func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCl respWriter.WriteResponse(err, nil) return } + i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination)) +} + +func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, destination M.Socksaddr) { + i.handleRouterBackedStream(ctx, stream, respWriter, request, destination) +} - targetConn, cleanup, err := i.dialRouterTCP(ctx, M.ParseSocksaddr(destination)) +func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr) { + targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) if err != nil { respWriter.WriteResponse(err, nil) return diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 5c29d40fa2..0e4b2083fb 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -16,6 +16,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/ws" "github.com/sagernet/ws/wsutil" ) @@ -233,3 +234,74 @@ func TestHandleSocksProxyStream(t *testing.T) { t.Fatal("socks-proxy stream did not exit") } } + +func TestHandleStreamService(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(conn) + } + }() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, M.ParseSocksaddr(listener.Addr().String())) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for stream service connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, opCode, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if opCode != ws.OpBinary { + t.Fatalf("expected binary frame, got %v", opCode) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("stream service did not exit") + } +} From 854718992fdf73340c133d7f0e04c3d7f8a9569c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:33:44 +0800 Subject: [PATCH 17/41] Honor cloudflare warp active flow limits --- protocol/cloudflare/datagram_v2.go | 10 ++- protocol/cloudflare/datagram_v3.go | 11 ++- protocol/cloudflare/dispatch.go | 8 ++ protocol/cloudflare/flow_limiter.go | 34 ++++++++ protocol/cloudflare/flow_limiter_test.go | 100 +++++++++++++++++++++++ protocol/cloudflare/inbound.go | 6 ++ 6 files changed, 165 insertions(+), 4 deletions(-) create mode 100644 protocol/cloudflare/flow_limiter.go create mode 100644 protocol/cloudflare/flow_limiter_test.go diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 49fbb4d8f6..a17e7c5cc3 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -133,6 +133,11 @@ func (m *DatagramV2Muxer) RegisterSession( m.sessionAccess.Unlock() return nil } + limit := m.inbound.maxActiveFlows() + if !m.inbound.flowLimiter.Acquire(limit) { + m.sessionAccess.Unlock() + return E.New("too many active flows") + } session := newUDPSession(sessionID, destination, closeAfterIdle, m) m.sessions[sessionID] = session @@ -140,7 +145,7 @@ func (m *DatagramV2Muxer) RegisterSession( m.logger.Info("registered V2 UDP session ", sessionID, " to ", destination) - go m.serveSession(ctx, session) + go m.serveSession(ctx, session, limit) return nil } @@ -159,8 +164,9 @@ func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID) { } } -func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession) { +func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession, limit uint64) { defer m.UnregisterSession(session.id) + defer m.inbound.flowLimiter.Release(limit) metadata := adapter.InboundContext{ Inbound: m.inbound.Tag(), diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index 1f8f1aacd6..b8c9796d5c 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -154,6 +154,12 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { } return } + limit := m.inbound.maxActiveFlows() + if !m.inbound.flowLimiter.Acquire(limit) { + m.sessionAccess.Unlock() + m.sendRegistrationResponse(requestID, v3ResponseTooManyActiveFlows, "") + return + } session := newV3Session(requestID, destination, closeAfterIdle, m) m.sessions[requestID] = session @@ -167,7 +173,7 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { session.writeToOrigin(data[offset:]) } - go m.serveV3Session(ctx, session) + go m.serveV3Session(ctx, session, limit) } func (m *DatagramV3Muxer) handlePayload(data []byte) { @@ -222,8 +228,9 @@ func (m *DatagramV3Muxer) unregisterSession(requestID RequestID) { } } -func (m *DatagramV3Muxer) serveV3Session(ctx context.Context, session *v3Session) { +func (m *DatagramV3Muxer) serveV3Session(ctx context.Context, session *v3Session, limit uint64) { defer m.unregisterSession(session.id) + defer m.inbound.flowLimiter.Release(limit) metadata := adapter.InboundContext{ Inbound: m.inbound.Tag(), diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index f2b9f8d2c3..2dc2c73e70 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -185,6 +185,14 @@ func parseHTTPDestination(dest string) M.Socksaddr { func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, metadata adapter.InboundContext) { metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound TCP connection to ", metadata.Destination) + limit := i.maxActiveFlows() + if !i.flowLimiter.Acquire(limit) { + err := E.New("too many active flows") + i.logger.ErrorContext(ctx, err) + respWriter.WriteResponse(err, nil) + return + } + defer i.flowLimiter.Release(limit) err := respWriter.WriteResponse(nil, nil) if err != nil { diff --git a/protocol/cloudflare/flow_limiter.go b/protocol/cloudflare/flow_limiter.go new file mode 100644 index 0000000000..cfe753f6b7 --- /dev/null +++ b/protocol/cloudflare/flow_limiter.go @@ -0,0 +1,34 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import "sync" + +type FlowLimiter struct { + access sync.Mutex + active uint64 +} + +func (l *FlowLimiter) Acquire(limit uint64) bool { + if limit == 0 { + return true + } + l.access.Lock() + defer l.access.Unlock() + if l.active >= limit { + return false + } + l.active++ + return true +} + +func (l *FlowLimiter) Release(limit uint64) { + if limit == 0 { + return + } + l.access.Lock() + defer l.access.Unlock() + if l.active > 0 { + l.active-- + } +} diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go new file mode 100644 index 0000000000..ad27c534b8 --- /dev/null +++ b/protocol/cloudflare/flow_limiter_test.go @@ -0,0 +1,100 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/binary" + "net" + "testing" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + + "github.com/google/uuid" +) + +func newLimitedInbound(t *testing.T, limit uint64) *Inbound { + t.Helper() + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal(err) + } + configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + if err != nil { + t.Fatal(err) + } + config := configManager.Snapshot() + config.WarpRouting.MaxActiveFlows = limit + configManager.activeConfig = config + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + configManager: configManager, + flowLimiter: &FlowLimiter{}, + } +} + +func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) { + inboundInstance := newLimitedInbound(t, 1) + if !inboundInstance.flowLimiter.Acquire(1) { + t.Fatal("failed to pre-acquire limiter") + } + + stream, peer := net.Pipe() + defer stream.Close() + defer peer.Close() + respWriter := &fakeConnectResponseWriter{} + inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{}) + if respWriter.err == nil { + t.Fatal("expected too many active flows error") + } +} + +func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) { + inboundInstance := newLimitedInbound(t, 1) + if !inboundInstance.flowLimiter.Acquire(1) { + t.Fatal("failed to pre-acquire limiter") + } + muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger) + err := muxer.RegisterSession(context.Background(), uuidTest(1), net.IPv4(1, 1, 1, 1), 53, 0) + if err == nil { + t.Fatal("expected too many active flows error") + } +} + +func TestDatagramV3RegistrationTooManyActiveFlows(t *testing.T) { + inboundInstance := newLimitedInbound(t, 1) + if !inboundInstance.flowLimiter.Acquire(1) { + t.Fatal("failed to pre-acquire limiter") + } + sender := &captureDatagramSender{} + muxer := NewDatagramV3Muxer(inboundInstance, sender, inboundInstance.logger) + + requestID := RequestID{} + requestID[15] = 1 + payload := make([]byte, 1+1+2+2+16+4) + payload[0] = 0 + binary.BigEndian.PutUint16(payload[1:3], 53) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + copy(payload[21:25], []byte{1, 1, 1, 1}) + + muxer.handleRegistration(context.Background(), payload) + if len(sender.sent) != 1 { + t.Fatalf("expected one registration response, got %d", len(sender.sent)) + } + if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseTooManyActiveFlows { + t.Fatalf("unexpected v3 response: %v", sender.sent[0]) + } +} + +func uuidTest(last byte) uuid.UUID { + var value uuid.UUID + value[15] = last + return value +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 4d38c82511..6d31eaa057 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -44,6 +44,7 @@ type Inbound struct { datagramVersion string gracePeriod time.Duration configManager *ConfigManager + flowLimiter *FlowLimiter connectionAccess sync.Mutex connections []io.Closer @@ -119,6 +120,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo datagramVersion: datagramVersion, gracePeriod: gracePeriod, configManager: configManager, + flowLimiter: &FlowLimiter{}, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil @@ -174,6 +176,10 @@ func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { return result } +func (i *Inbound) maxActiveFlows() uint64 { + return i.configManager.Snapshot().WarpRouting.MaxActiveFlows +} + func (i *Inbound) Close() error { i.cancel() i.done.Wait() From ed6be9b0785ec9f8554b9e0582b369c0dcbb23c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:50:45 +0800 Subject: [PATCH 18/41] Validate cloudflare access protected origins --- go.mod | 2 + go.sum | 2 + protocol/cloudflare/access.go | 104 ++++++++++++++++++++++++++ protocol/cloudflare/access_test.go | 92 +++++++++++++++++++++++ protocol/cloudflare/dispatch.go | 12 +++ protocol/cloudflare/inbound.go | 2 + protocol/cloudflare/runtime_config.go | 3 + 7 files changed, 217 insertions(+) create mode 100644 protocol/cloudflare/access.go create mode 100644 protocol/cloudflare/access_test.go diff --git a/go.mod b/go.mod index 9709176e4e..fa945dfce2 100644 --- a/go.mod +++ b/go.mod @@ -75,6 +75,7 @@ require ( github.com/andybalholm/brotli v1.1.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect + github.com/coreos/go-oidc/v3 v3.12.0 // indirect github.com/database64128/netx-go v0.1.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect @@ -84,6 +85,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gaissmai/bart v0.18.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect diff --git a/go.sum b/go.sum index c5d7315e8d..fe44e4b2bd 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9 github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo= +github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo= github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI= diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go new file mode 100644 index 0000000000..9407d03122 --- /dev/null +++ b/protocol/cloudflare/access.go @@ -0,0 +1,104 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" + E "github.com/sagernet/sing/common/exceptions" +) + +const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion" + +var newAccessValidator = func(access AccessConfig) (accessValidator, error) { + issuerURL := accessIssuerURL(access.TeamName, access.Environment) + keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs") + verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ + SkipClientIDCheck: true, + }) + return &oidcAccessValidator{ + verifier: verifier, + audTags: append([]string(nil), access.AudTag...), + }, nil +} + +type accessValidator interface { + Validate(ctx context.Context, request *http.Request) error +} + +type oidcAccessValidator struct { + verifier *oidc.IDTokenVerifier + audTags []string +} + +func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error { + accessJWT := request.Header.Get(accessJWTAssertionHeader) + if accessJWT == "" { + return E.New("missing access jwt assertion") + } + token, err := v.verifier.Verify(ctx, accessJWT) + if err != nil { + return err + } + if len(v.audTags) == 0 { + return nil + } + for _, jwtAudTag := range token.Audience { + for _, acceptedAudTag := range v.audTags { + if acceptedAudTag == jwtAudTag { + return nil + } + } + } + return E.New("access token audience does not match configured aud_tag") +} + +func accessIssuerURL(teamName string, environment string) string { + if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") { + return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName) + } + return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName) +} + +func validateAccessConfiguration(access AccessConfig) error { + if !access.Required { + return nil + } + if access.TeamName == "" && len(access.AudTag) > 0 { + return E.New("access.team_name cannot be blank when access.aud_tag is present") + } + return nil +} + +func accessValidatorKey(access AccessConfig) string { + return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",") +} + +type accessValidatorCache struct { + access sync.RWMutex + values map[string]accessValidator +} + +func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) { + key := accessValidatorKey(accessConfig) + c.access.RLock() + validator, loaded := c.values[key] + c.access.RUnlock() + if loaded { + return validator, nil + } + + validator, err := newAccessValidator(accessConfig) + if err != nil { + return nil, err + } + c.access.Lock() + c.values[key] = validator + c.access.Unlock() + return validator, nil +} diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go new file mode 100644 index 0000000000..594f94d77c --- /dev/null +++ b/protocol/cloudflare/access_test.go @@ -0,0 +1,92 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "io" + "net" + "net/http" + "testing" + + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +type fakeAccessValidator struct { + err error +} + +func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error { + return v.err +} + +func newAccessTestInbound(t *testing.T) *Inbound { + t.Helper() + logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) + if err != nil { + t.Fatal(err) + } + return &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + logger: logFactory.NewLogger("test"), + accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, + router: &testRouter{}, + } +} + +func TestValidateAccessConfiguration(t *testing.T) { + err := validateAccessConfiguration(AccessConfig{ + Required: true, + AudTag: []string{"aud"}, + }) + if err == nil { + t.Fatal("expected access config validation error") + } +} + +func TestRoundTripHTTPAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory + }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } + + inboundInstance := newAccessTestInbound(t) + service := ResolvedService{ + Kind: ResolvedServiceHTTP, + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + } + serverSide, clientSide := net.Pipe() + defer serverSide.Close() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeHTTP, + Dest: "http://127.0.0.1:8083", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + go func() { + defer clientSide.Close() + _, _ = io.Copy(io.Discard, clientSide) + }() + + inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{}) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 2dc2c73e70..7cf664aca1 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -321,6 +321,18 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, defer cancel() httpRequest = httpRequest.WithContext(requestCtx) } + if service.OriginRequest.Access.Required { + validator, err := i.accessCache.Get(service.OriginRequest.Access) + if err != nil { + i.logger.ErrorContext(ctx, "create access validator: ", err) + respWriter.WriteResponse(err, nil) + return + } + if err := validator.Validate(requestCtx, httpRequest); err != nil { + respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) + return + } + } httpClient := &http.Client{ Transport: transport, diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 6d31eaa057..e7cbdb4b62 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -45,6 +45,7 @@ type Inbound struct { gracePeriod time.Duration configManager *ConfigManager flowLimiter *FlowLimiter + accessCache *accessValidatorCache connectionAccess sync.Mutex connections []io.Closer @@ -121,6 +122,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod: gracePeriod, configManager: configManager, flowLimiter: &FlowLimiter{}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 276e99d412..c35c5505c9 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -363,6 +363,9 @@ func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []lo if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil { return nil, err } + if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil { + return nil, err + } service, err := parseResolvedService(rule.Service, rule.OriginRequest) if err != nil { return nil, err From 1ea083cd6ffae9e9b47936e31c1da9b543390e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:52:41 +0800 Subject: [PATCH 19/41] Apply cloudflare origin proxy transport options --- protocol/cloudflare/dispatch.go | 25 ++++++++++- protocol/cloudflare/origin_request_test.go | 49 +++++++++++++++++++++- 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 7cf664aca1..6d049d66c9 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -393,6 +393,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter return input, nil }, } + applyHTTPTransportProxy(transport, originRequest) return transport, func() { common.Close(input, output) select { @@ -403,6 +404,13 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter } func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { + dialer := &net.Dialer{ + Timeout: service.OriginRequest.ConnectTimeout, + KeepAlive: service.OriginRequest.TCPKeepAlive, + } + if service.OriginRequest.NoHappyEyeballs { + dialer.FallbackDelay = -1 + } transport := &http.Transport{ DisableCompression: true, ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, @@ -412,14 +420,13 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost), } + applyHTTPTransportProxy(transport, service.OriginRequest) switch service.Kind { case ResolvedServiceUnix, ResolvedServiceUnixTLS: - dialer := &net.Dialer{} transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { return dialer.DialContext(ctx, "unix", service.UnixPath) } case ResolvedServiceHelloWorld: - dialer := &net.Dialer{} target := service.BaseURL.Host transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", target) @@ -449,6 +456,20 @@ func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) return tlsConfig } +func applyHTTPTransportProxy(transport *http.Transport, originRequest OriginRequestConfig) { + if originRequest.ProxyAddress == "" || originRequest.ProxyPort == 0 { + return + } + switch strings.ToLower(originRequest.ProxyType) { + case "", "http": + proxyURL := &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(originRequest.ProxyAddress, strconv.Itoa(int(originRequest.ProxyPort))), + } + transport.Proxy = http.ProxyURL(proxyURL) + } +} + func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string { if originRequest.OriginServerName != "" { return originRequest.OriginServerName diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index b56a0a52f2..d8a6716ab4 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -2,7 +2,11 @@ package cloudflare -import "testing" +import ( + "net/http" + "net/url" + "testing" +) func TestOriginTLSServerName(t *testing.T) { t.Run("origin server name overrides host", func(t *testing.T) { @@ -31,3 +35,46 @@ func TestOriginTLSServerName(t *testing.T) { } }) } + +func TestApplyHTTPTransportProxy(t *testing.T) { + transport := &http.Transport{} + applyHTTPTransportProxy(transport, OriginRequestConfig{ + ProxyAddress: "127.0.0.1", + ProxyPort: 8080, + ProxyType: "http", + }) + if transport.Proxy == nil { + t.Fatal("expected proxy function to be configured") + } + proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}) + if err != nil { + t.Fatal(err) + } + if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:8080" { + t.Fatalf("unexpected proxy URL: %#v", proxyURL) + } +} + +func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) { + inbound := &Inbound{} + transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{ + Kind: ResolvedServiceHelloWorld, + BaseURL: &url.URL{ + Scheme: "http", + Host: "127.0.0.1:8080", + }, + OriginRequest: OriginRequestConfig{ + NoHappyEyeballs: true, + }, + }, "") + if err != nil { + t.Fatal(err) + } + defer cleanup() + if transport.Proxy != nil { + t.Fatal("expected no proxy when proxy fields are empty") + } + if transport.DialContext == nil { + t.Fatal("expected custom direct dial context") + } +} From 289101fc566cc7154b2983b19d6096c494aea935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 12:56:05 +0800 Subject: [PATCH 20/41] Enforce cloudflare access on all ingress services --- protocol/cloudflare/access_test.go | 84 +++++++++++++++++++++++++----- protocol/cloudflare/dispatch.go | 42 ++++++++++----- 2 files changed, 101 insertions(+), 25 deletions(-) diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 594f94d77c..357cd9f431 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -4,16 +4,16 @@ package cloudflare import ( "context" - "io" - "net" "net/http" "testing" + "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" ) type fakeAccessValidator struct { @@ -58,34 +58,94 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) { } inboundInstance := newAccessTestInbound(t) - service := ResolvedService{ - Kind: ResolvedServiceHTTP, + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeHTTP, + Dest: "http://127.0.0.1:8083/test", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceHTTP, + Destination: M.ParseSocksaddr("127.0.0.1:8083"), OriginRequest: OriginRequestConfig{ Access: AccessConfig{ Required: true, TeamName: "team", }, }, + }) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) } - serverSide, clientSide := net.Pipe() - defer serverSide.Close() - defer clientSide.Close() +} +func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory + }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } + + inboundInstance := newAccessTestInbound(t) respWriter := &fakeConnectResponseWriter{} request := &ConnectRequest{ Type: ConnectionTypeHTTP, - Dest: "http://127.0.0.1:8083", + Dest: "https://example.com/status", Metadata: []Metadata{ {Key: metadataHTTPMethod, Val: http.MethodGet}, {Key: metadataHTTPHost, Val: "example.com"}, }, } - go func() { - defer clientSide.Close() - _, _ = io.Copy(io.Discard, clientSide) + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStatus, + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + StatusCode: 404, + }) + if respWriter.status != http.StatusForbidden { + t.Fatalf("expected 403, got %d", respWriter.status) + } +} + +func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) { + originalFactory := newAccessValidator + defer func() { + newAccessValidator = originalFactory }() + newAccessValidator = func(access AccessConfig) (accessValidator, error) { + return &fakeAccessValidator{err: E.New("forbidden")}, nil + } - inboundInstance.roundTripHTTP(context.Background(), serverSide, respWriter, request, service, &http.Transport{}) + inboundInstance := newAccessTestInbound(t) + respWriter := &fakeConnectResponseWriter{} + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Dest: "https://example.com/ws", + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr("127.0.0.1:8080"), + OriginRequest: OriginRequestConfig{ + Access: AccessConfig{ + Required: true, + TeamName: "team", + }, + }, + }) if respWriter.status != http.StatusForbidden { t.Fatalf("expected 403, got %d", respWriter.status) } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 6d049d66c9..b33dea97c1 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -208,9 +208,29 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser } func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request) + if err != nil { + i.logger.ErrorContext(ctx, "build request for access validation: ", err) + respWriter.WriteResponse(err, nil) + return + } + validationRequest = applyOriginRequest(validationRequest, service.OriginRequest) + if service.OriginRequest.Access.Required { + validator, err := i.accessCache.Get(service.OriginRequest.Access) + if err != nil { + i.logger.ErrorContext(ctx, "create access validator: ", err) + respWriter.WriteResponse(err, nil) + return + } + if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil { + respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) + return + } + } + switch service.Kind { case ResolvedServiceStatus: - err := respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) + err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{})) if err != nil { i.logger.ErrorContext(ctx, "write status service response: ", err) } @@ -321,18 +341,6 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, defer cancel() httpRequest = httpRequest.WithContext(requestCtx) } - if service.OriginRequest.Access.Required { - validator, err := i.accessCache.Get(service.OriginRequest.Access) - if err != nil { - i.logger.ErrorContext(ctx, "create access validator: ", err) - respWriter.WriteResponse(err, nil) - return - } - if err := validator.Validate(requestCtx, httpRequest); err != nil { - respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{})) - return - } - } httpClient := &http.Client{ Transport: transport, @@ -498,6 +506,14 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig return request } +func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) { + return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{ + Dest: connectRequest.Dest, + Type: connectRequest.Type, + Metadata: append([]Metadata(nil), connectRequest.Metadata...), + }, http.NoBody) +} + func bidirectionalCopy(left, right io.ReadWriteCloser) { var closeOnce sync.Once closeBoth := func() { From d017cbe0087df16b6f640b50a40c1511b5850006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:09:09 +0800 Subject: [PATCH 21/41] Stabilize cloudflare edge transport fallback --- protocol/cloudflare/connection_http2.go | 5 +- protocol/cloudflare/connection_quic.go | 79 ++++++++++++++++++++++--- protocol/cloudflare/inbound.go | 7 ++- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 42d9bffeaf..2de68d4648 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -67,9 +67,8 @@ func NewHTTP2Connection( } tlsConfig := &tls.Config{ - RootCAs: rootCAs, - ServerName: h2EdgeSNI, - CurvePreferences: []tls.CurveID{tls.CurveP256}, + RootCAs: rootCAs, + ServerName: h2EdgeSNI, } dialer := &net.Dialer{} diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 549a6eef47..f06ad714ff 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -5,8 +5,10 @@ package cloudflare import ( "context" "crypto/tls" + "fmt" "io" "net" + "runtime" "sync" "time" @@ -36,7 +38,7 @@ func quicInitialPacketSize(ipVersion int) uint16 { // QUICConnection manages a single QUIC connection to the Cloudflare edge. type QUICConnection struct { - conn *quic.Conn + conn quicConnection logger log.ContextLogger edgeAddr *EdgeAddr connIndex uint8 @@ -51,6 +53,31 @@ type QUICConnection struct { closeOnce sync.Once } +type quicConnection interface { + OpenStream() (*quic.Stream, error) + AcceptStream(ctx context.Context) (*quic.Stream, error) + ReceiveDatagram(ctx context.Context) ([]byte, error) + SendDatagram(data []byte) error + LocalAddr() net.Addr + CloseWithError(code quic.ApplicationErrorCode, reason string) error +} + +type closeableQUICConn struct { + *quic.Conn + udpConn *net.UDPConn +} + +func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reason string) error { + err := c.Conn.CloseWithError(code, reason) + _ = c.udpConn.Close() + return err +} + +var ( + quicPortByConnIndex = make(map[uint8]int) + quicPortAccess sync.Mutex +) + // NewQUICConnection dials the edge and establishes a QUIC connection. func NewQUICConnection( ctx context.Context, @@ -69,10 +96,9 @@ func NewQUICConnection( } tlsConfig := &tls.Config{ - RootCAs: rootCAs, - ServerName: quicEdgeSNI, - NextProtos: []string{quicEdgeALPN}, - CurvePreferences: []tls.CurveID{tls.CurveP256}, + RootCAs: rootCAs, + ServerName: quicEdgeSNI, + NextProtos: []string{quicEdgeALPN}, } quicConfig := &quic.Config{ @@ -85,13 +111,19 @@ func NewQUICConnection( InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } - conn, err := quic.DialAddr(ctx, edgeAddr.UDP.String(), tlsConfig, quicConfig) + udpConn, err := createUDPConnForConnIndex(connIndex, edgeAddr) + if err != nil { + return nil, E.Cause(err, "listen UDP for QUIC edge") + } + + conn, err := quic.Dial(ctx, udpConn, edgeAddr.UDP, tlsConfig, quicConfig) if err != nil { + udpConn.Close() return nil, E.Cause(err, "dial QUIC edge") } return &QUICConnection{ - conn: conn, + conn: &closeableQUICConn{Conn: conn, udpConn: udpConn}, logger: logger, edgeAddr: edgeAddr, connIndex: connIndex, @@ -103,6 +135,39 @@ func NewQUICConnection( }, nil } +func createUDPConnForConnIndex(connIndex uint8, edgeAddr *EdgeAddr) (*net.UDPConn, error) { + quicPortAccess.Lock() + defer quicPortAccess.Unlock() + + network := "udp" + if runtime.GOOS == "darwin" { + if edgeAddr.IPVersion == 4 { + network = "udp4" + } else { + network = "udp6" + } + } + + if port, loaded := quicPortByConnIndex[connIndex]; loaded { + udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err == nil { + return udpConn, nil + } + } + + udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: 0}) + if err != nil { + return nil, err + } + udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr) + if !ok { + udpConn.Close() + return nil, fmt.Errorf("unexpected local UDP address type %T", udpConn.LocalAddr()) + } + quicPortByConnIndex[connIndex] = udpAddr.Port + return udpConn, nil +} + // Serve runs the QUIC connection: registers, accepts streams, handles datagrams. // Blocks until the context is cancelled or a fatal error occurs. func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index e7cbdb4b62..ae48ba40ed 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -268,7 +268,12 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features switch protocol { case "quic": - return i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts) + err := i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts) + if err == nil || i.ctx.Err() != nil { + return err + } + i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err) + return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) case "http2": return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) default: From 2321e941e0d0824a3df0c3e8cd9df097cd5dfd01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:52:55 +0800 Subject: [PATCH 22/41] Route cloudflare control plane through configurable dialer --- option/cloudflare_tunnel.go | 1 + protocol/cloudflare/access.go | 17 +++++++++--- protocol/cloudflare/access_test.go | 16 ++++++----- protocol/cloudflare/connection_http2.go | 19 +++++++++---- protocol/cloudflare/connection_quic.go | 31 ++++++++-------------- protocol/cloudflare/edge_discovery.go | 11 ++++---- protocol/cloudflare/edge_discovery_test.go | 4 ++- protocol/cloudflare/helpers_test.go | 2 ++ protocol/cloudflare/inbound.go | 18 ++++++++++--- 9 files changed, 75 insertions(+), 44 deletions(-) diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index c0fdbfa879..bf388a044e 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -7,6 +7,7 @@ type CloudflareTunnelInboundOptions struct { CredentialPath string `json:"credential_path,omitempty"` HAConnections int `json:"ha_connections,omitempty"` Protocol string `json:"protocol,omitempty"` + ControlDialer DialerOptions `json:"control_dialer,omitempty"` EdgeIPVersion int `json:"edge_ip_version,omitempty"` DatagramVersion string `json:"datagram_version,omitempty"` GracePeriod badoption.Duration `json:"grace_period,omitempty"` diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go index 9407d03122..75c1e8ada4 100644 --- a/protocol/cloudflare/access.go +++ b/protocol/cloudflare/access.go @@ -5,19 +5,29 @@ package cloudflare import ( "context" "fmt" + "net" "net/http" "strings" "sync" "github.com/coreos/go-oidc/v3/oidc" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion" -var newAccessValidator = func(access AccessConfig) (accessValidator, error) { +var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { issuerURL := accessIssuerURL(access.TeamName, access.Environment) - keySet := oidc.NewRemoteKeySet(context.Background(), issuerURL+"/cdn-cgi/access/certs") + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(address)) + }, + }, + } + keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs") verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ SkipClientIDCheck: true, }) @@ -82,6 +92,7 @@ func accessValidatorKey(access AccessConfig) string { type accessValidatorCache struct { access sync.RWMutex values map[string]accessValidator + dialer N.Dialer } func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) { @@ -93,7 +104,7 @@ func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, return validator, nil } - validator, err := newAccessValidator(accessConfig) + validator, err := newAccessValidator(accessConfig, c.dialer) if err != nil { return nil, err } diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 357cd9f431..8c7d2b9e10 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -14,6 +14,7 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type fakeAccessValidator struct { @@ -31,10 +32,11 @@ func newAccessTestInbound(t *testing.T) *Inbound { t.Fatal(err) } return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), - logger: logFactory.NewLogger("test"), - accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, - router: &testRouter{}, + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + logger: logFactory.NewLogger("test"), + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, + router: &testRouter{}, + controlDialer: N.SystemDialer, } } @@ -53,7 +55,7 @@ func TestRoundTripHTTPAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } @@ -87,7 +89,7 @@ func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } @@ -121,7 +123,7 @@ func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) { defer func() { newAccessValidator = originalFactory }() - newAccessValidator = func(access AccessConfig) (accessValidator, error) { + newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) { return &fakeAccessValidator{err: E.New("forbidden")}, nil } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 2de68d4648..9fed72cf80 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -17,6 +17,7 @@ import ( "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + M "github.com/sagernet/sing/common/metadata" "github.com/google/uuid" "golang.org/x/net/http2" @@ -71,8 +72,7 @@ func NewHTTP2Connection( ServerName: h2EdgeSNI, } - dialer := &net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", edgeAddr.TCP.String()) + tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) if err != nil { return nil, E.Cause(err, "dial edge TCP") } @@ -113,10 +113,13 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { Handler: c, }) - if c.registrationResult != nil { - return nil + if ctx.Err() != nil { + return ctx.Err() + } + if c.registrationResult == nil { + return E.New("edge connection closed before registration") } - return E.New("edge connection closed before registration") + return E.New("edge connection closed") } func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -167,6 +170,12 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque " (connection ", result.ConnectionID, ")") <-ctx.Done() + unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) + defer cancel() + err = c.registrationClient.Unregister(unregisterCtx) + if err != nil { + c.logger.Debug("failed to unregister: ", err) + } c.registrationClient.Close() } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index f06ad714ff..fd2e56a3bb 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -8,13 +8,14 @@ import ( "fmt" "io" "net" - "runtime" "sync" "time" "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/google/uuid" ) @@ -88,6 +89,7 @@ func NewQUICConnection( features []string, numPreviousAttempts uint8, gracePeriod time.Duration, + controlDialer N.Dialer, logger log.ContextLogger, ) (*QUICConnection, error) { rootCAs, err := cloudflareRootCertPool() @@ -111,7 +113,7 @@ func NewQUICConnection( InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } - udpConn, err := createUDPConnForConnIndex(connIndex, edgeAddr) + udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer) if err != nil { return nil, E.Cause(err, "listen UDP for QUIC edge") } @@ -135,30 +137,19 @@ func NewQUICConnection( }, nil } -func createUDPConnForConnIndex(connIndex uint8, edgeAddr *EdgeAddr) (*net.UDPConn, error) { +func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) { quicPortAccess.Lock() defer quicPortAccess.Unlock() - network := "udp" - if runtime.GOOS == "darwin" { - if edgeAddr.IPVersion == 4 { - network = "udp4" - } else { - network = "udp6" - } - } - - if port, loaded := quicPortByConnIndex[connIndex]; loaded { - udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err == nil { - return udpConn, nil - } - } - - udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: 0}) + packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) if err != nil { return nil, err } + udpConn, ok := packetConn.(*net.UDPConn) + if !ok { + packetConn.Close() + return nil, fmt.Errorf("unexpected packet conn type %T", packetConn) + } udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr) if !ok { udpConn.Close() diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go index 0c08bcbf86..922063ce43 100644 --- a/protocol/cloudflare/edge_discovery.go +++ b/protocol/cloudflare/edge_discovery.go @@ -9,6 +9,8 @@ import ( "time" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) const ( @@ -37,10 +39,10 @@ type EdgeAddr struct { // DiscoverEdge performs SRV-based edge discovery and returns addresses // partitioned into regions (typically 2). -func DiscoverEdge(ctx context.Context, region string) ([][]*EdgeAddr, error) { +func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) { regions, err := lookupEdgeSRV(region) if err != nil { - regions, err = lookupEdgeSRVWithDoT(ctx, region) + regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer) if err != nil { return nil, E.Cause(err, "edge discovery") } @@ -59,12 +61,11 @@ func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) { return resolveSRVRecords(addrs) } -func lookupEdgeSRVWithDoT(ctx context.Context, region string) ([][]*EdgeAddr, error) { +func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) { resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { - var dialer net.Dialer - conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr) + conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr)) if err != nil { return nil, err } diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index c282009d0d..930fd46be2 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -6,10 +6,12 @@ import ( "context" "net" "testing" + + N "github.com/sagernet/sing/common/network" ) func TestDiscoverEdge(t *testing.T) { - regions, err := DiscoverEdge(context.Background(), "") + regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer) if err != nil { t.Fatal("DiscoverEdge: ", err) } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index f06a5fa2a8..d873a40b11 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -192,6 +192,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i configManager: configManager, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + controlDialer: N.SystemDialer, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, } t.Cleanup(func() { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index ae48ba40ed..5e30b89d18 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -16,11 +16,13 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" + boxDialer "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + N "github.com/sagernet/sing/common/network" "github.com/google/uuid" ) @@ -46,6 +48,7 @@ type Inbound struct { configManager *ConfigManager flowLimiter *FlowLimiter accessCache *accessValidatorCache + controlDialer N.Dialer connectionAccess sync.Mutex connections []io.Closer @@ -95,6 +98,14 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if err != nil { return nil, E.Cause(err, "build cloudflare tunnel runtime config") } + controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{ + Context: ctx, + Options: options.ControlDialer, + RemoteIsDomain: true, + }) + if err != nil { + return nil, E.Cause(err, "build cloudflare tunnel control dialer") + } region := options.Region if region != "" && credentials.Endpoint != "" { @@ -122,7 +133,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod: gracePeriod, configManager: configManager, flowLimiter: &FlowLimiter{}, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator)}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, + controlDialer: controlDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), }, nil @@ -135,7 +147,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections") - regions, err := DiscoverEdge(i.ctx, i.region) + regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer) if err != nil { return E.Cause(err, "discover edge") } @@ -287,7 +299,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, numPreviousAttempts, i.gracePeriod, i.logger, + features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger, ) if err != nil { return E.Cause(err, "create QUIC connection") From d7b8689b264ecd6a33561550ab202a107205dc83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:56:33 +0800 Subject: [PATCH 23/41] Serve cloudflare hello world over TLS --- protocol/cloudflare/dispatch.go | 1 + protocol/cloudflare/inbound.go | 14 ++++++++++++-- protocol/cloudflare/ingress_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index b33dea97c1..8572e437e8 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -156,6 +156,7 @@ func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string return ResolvedService{}, "", err } service.BaseURL = helloURL + service.OriginRequest.NoTLSVerify = true } originURL, err := service.BuildRequestURL(requestURL) if err != nil { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 5e30b89d18..fd6a0f0b58 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -4,6 +4,7 @@ package cloudflare import ( "context" + stdTLS "crypto/tls" "encoding/base64" "io" "math/rand" @@ -17,6 +18,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" boxDialer "github.com/sagernet/sing-box/common/dialer" + boxTLS "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -227,12 +229,20 @@ func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { if err != nil { return nil, E.Cause(err, "listen hello world server") } + certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost") + if err != nil { + _ = listener.Close() + return nil, E.Cause(err, "generate hello world certificate") + } + tlsListener := stdTLS.NewListener(listener, &stdTLS.Config{ + Certificates: []stdTLS.Certificate{*certificate}, + }) server := &http.Server{Handler: mux} - go server.Serve(listener) + go server.Serve(tlsListener) i.helloWorldServer = server i.helloWorldURL = &url.URL{ - Scheme: "http", + Scheme: "https", Host: listener.Addr().String(), } return i.helloWorldURL, nil diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 03a91f0f0f..e4432cf321 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -151,3 +151,29 @@ func TestResolveHTTPServiceStatus(t *testing.T) { t.Fatalf("status service should keep request URL, got %s", requestURL) } } + +func TestResolveHTTPServiceHelloWorld(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Service: mustResolvedService(t, "hello_world")}, + }, + } + + service, requestURL, err := inboundInstance.resolveHTTPService("https://hello.example.com/path") + if err != nil { + t.Fatal(err) + } + if service.Kind != ResolvedServiceHelloWorld { + t.Fatalf("expected hello world service, got %#v", service) + } + if service.BaseURL == nil || service.BaseURL.Scheme != "https" { + t.Fatalf("expected hello world base URL to be https, got %#v", service.BaseURL) + } + if !service.OriginRequest.NoTLSVerify { + t.Fatal("expected hello world to force no_tls_verify") + } + if requestURL == "" || requestURL[:8] != "https://" { + t.Fatalf("expected https request URL, got %s", requestURL) + } +} From e6a7efc49abffbc6b9dab7702ea00237f0f6a0a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:58:17 +0800 Subject: [PATCH 24/41] Cover direct cloudflare origin services --- protocol/cloudflare/direct_origin_test.go | 120 ++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 protocol/cloudflare/direct_origin_test.go diff --git a/protocol/cloudflare/direct_origin_test.go b/protocol/cloudflare/direct_origin_test.go new file mode 100644 index 0000000000..f38c96e226 --- /dev/null +++ b/protocol/cloudflare/direct_origin_test.go @@ -0,0 +1,120 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + stdTLS "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "testing" + "time" + + boxTLS "github.com/sagernet/sing-box/common/tls" +) + +func TestNewDirectOriginTransportUnix(t *testing.T) { + socketPath := fmt.Sprintf("/tmp/cf-origin-%d.sock", time.Now().UnixNano()) + _ = os.Remove(socketPath) + t.Cleanup(func() { _ = os.Remove(socketPath) }) + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go serveTestHTTPOverListener(listener, func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte("unix-ok")) + }) + + inboundInstance := &Inbound{} + transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{ + Kind: ResolvedServiceUnix, + UnixPath: socketPath, + BaseURL: &url.URL{ + Scheme: "http", + Host: "localhost", + }, + }, "") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + client := &http.Client{Transport: transport} + resp, err := client.Get("http://localhost/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "unix-ok" { + t.Fatalf("unexpected response body: %q", string(body)) + } +} + +func TestNewDirectOriginTransportUnixTLS(t *testing.T) { + socketPath := fmt.Sprintf("/tmp/cf-origin-tls-%d.sock", time.Now().UnixNano()) + _ = os.Remove(socketPath) + t.Cleanup(func() { _ = os.Remove(socketPath) }) + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost") + if err != nil { + t.Fatal(err) + } + tlsListener := stdTLS.NewListener(listener, &stdTLS.Config{ + Certificates: []stdTLS.Certificate{*certificate}, + }) + defer tlsListener.Close() + + go serveTestHTTPOverListener(tlsListener, func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte("unix-tls-ok")) + }) + + inboundInstance := &Inbound{} + transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{ + Kind: ResolvedServiceUnixTLS, + OriginRequest: OriginRequestConfig{ + NoTLSVerify: true, + }, + UnixPath: socketPath, + BaseURL: &url.URL{ + Scheme: "https", + Host: "localhost", + }, + }, "") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + client := &http.Client{Transport: transport} + resp, err := client.Get("https://localhost/") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "unix-tls-ok" { + t.Fatalf("unexpected response body: %q", string(body)) + } +} + +func serveTestHTTPOverListener(listener net.Listener, handler func(http.ResponseWriter, *http.Request)) { + server := &http.Server{Handler: http.HandlerFunc(handler)} + _ = server.Serve(listener) +} From 2340db6fcf66045dbec2f277ab38a8576b3e56a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 13:59:13 +0800 Subject: [PATCH 25/41] Report unreachable cloudflare v3 registrations --- protocol/cloudflare/datagram_v3.go | 4 +++ protocol/cloudflare/datagram_v3_test.go | 38 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 protocol/cloudflare/datagram_v3_test.go diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index b8c9796d5c..c40459117c 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -142,6 +142,10 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { if closeAfterIdle == 0 { closeAfterIdle = 210 * time.Second } + if !destination.Addr().IsValid() || destination.Addr().IsUnspecified() || destination.Port() == 0 { + m.sendRegistrationResponse(requestID, v3ResponseDestinationUnreachable, "") + return + } m.sessionAccess.Lock() if existing, exists := m.sessions[requestID]; exists { diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go new file mode 100644 index 0000000000..ae41d8ca9d --- /dev/null +++ b/protocol/cloudflare/datagram_v3_test.go @@ -0,0 +1,38 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "encoding/binary" + "testing" + + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" +) + +func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { + sender := &captureDatagramSender{} + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + flowLimiter: &FlowLimiter{}, + } + muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) + + requestID := RequestID{} + requestID[15] = 1 + payload := make([]byte, 1+2+2+16+4) + payload[0] = 0 + binary.BigEndian.PutUint16(payload[1:3], 0) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + copy(payload[21:25], []byte{0, 0, 0, 0}) + + muxer.handleRegistration(context.Background(), payload) + if len(sender.sent) != 1 { + t.Fatalf("expected one registration response, got %d", len(sender.sent)) + } + if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseDestinationUnreachable { + t.Fatalf("unexpected datagram response: %v", sender.sent[0]) + } +} From e54707cfe9ce5ca6216367f3f6454fd86900c7fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 14:01:28 +0800 Subject: [PATCH 26/41] Return v3 registration protocol errors --- protocol/cloudflare/datagram_v3.go | 4 ++-- protocol/cloudflare/datagram_v3_test.go | 26 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index c40459117c..ea23ed21c8 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -120,7 +120,7 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { if flags&v3FlagIPv6 != 0 { if len(data) < offset+v3IPv6AddrLen { - m.logger.Debug("V3 registration too short for IPv6") + m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv6") return } var addr [16]byte @@ -129,7 +129,7 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { offset += v3IPv6AddrLen } else { if len(data) < offset+v3IPv4AddrLen { - m.logger.Debug("V3 registration too short for IPv4") + m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv4") return } var addr [4]byte diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index ae41d8ca9d..5703310c0a 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -36,3 +36,29 @@ func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { t.Fatalf("unexpected datagram response: %v", sender.sent[0]) } } + +func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) { + sender := &captureDatagramSender{} + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + flowLimiter: &FlowLimiter{}, + } + muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) + + requestID := RequestID{} + requestID[15] = 2 + payload := make([]byte, 1+2+2+16+1) + payload[0] = 1 + binary.BigEndian.PutUint16(payload[1:3], 53) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + payload[21] = 0xaa + + muxer.handleRegistration(context.Background(), payload) + if len(sender.sent) != 1 { + t.Fatalf("expected one registration response, got %d", len(sender.sent)) + } + if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseErrorWithMsg { + t.Fatalf("unexpected datagram response: %v", sender.sent[0]) + } +} From a95f56cdea95cf2cdf534f53f5f5af4a3c2ae1bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 14:51:45 +0800 Subject: [PATCH 27/41] cloudflare: enforce socks-proxy ip_rules --- protocol/cloudflare/dispatch.go | 2 +- protocol/cloudflare/ip_rule_policy.go | 96 ++++++ protocol/cloudflare/runtime_config.go | 6 + protocol/cloudflare/special_service.go | 33 +- protocol/cloudflare/special_service_test.go | 323 +++++++++++++++----- 5 files changed, 376 insertions(+), 84 deletions(-) create mode 100644 protocol/cloudflare/ip_rule_policy.go diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 8572e437e8..e8089858c6 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -272,7 +272,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos respWriter.WriteResponse(err, nil) return } - i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata) + i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata, service) default: err := E.New("unsupported service kind for HTTP/WebSocket request") i.logger.ErrorContext(ctx, err) diff --git a/protocol/cloudflare/ip_rule_policy.go b/protocol/cloudflare/ip_rule_policy.go new file mode 100644 index 0000000000..d2526306dd --- /dev/null +++ b/protocol/cloudflare/ip_rule_policy.go @@ -0,0 +1,96 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "net" + "net/netip" + "sort" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +type compiledIPRule struct { + prefix netip.Prefix + ports []int + allow bool +} + +type ipRulePolicy struct { + rules []compiledIPRule +} + +func newIPRulePolicy(rawRules []IPRule) (*ipRulePolicy, error) { + policy := &ipRulePolicy{ + rules: make([]compiledIPRule, 0, len(rawRules)), + } + for _, rawRule := range rawRules { + if rawRule.Prefix == "" { + return nil, E.New("ip_rule prefix cannot be blank") + } + prefix, err := netip.ParsePrefix(rawRule.Prefix) + if err != nil { + return nil, E.Cause(err, "parse ip_rule prefix") + } + ports := append([]int(nil), rawRule.Ports...) + sort.Ints(ports) + for _, port := range ports { + if port < 1 || port > 65535 { + return nil, E.New("invalid ip_rule port: ", port) + } + } + policy.rules = append(policy.rules, compiledIPRule{ + prefix: prefix, + ports: ports, + allow: rawRule.Allow, + }) + } + return policy, nil +} + +func (p *ipRulePolicy) Allow(ctx context.Context, destination M.Socksaddr) (bool, error) { + if p == nil { + return false, nil + } + ipAddr, err := resolvePolicyDestination(ctx, destination) + if err != nil { + return false, err + } + port := int(destination.Port) + for _, rule := range p.rules { + if !rule.prefix.Contains(ipAddr) { + continue + } + if len(rule.ports) == 0 { + return rule.allow, nil + } + portIndex := sort.SearchInts(rule.ports, port) + if portIndex < len(rule.ports) && rule.ports[portIndex] == port { + return rule.allow, nil + } + } + return false, nil +} + +func resolvePolicyDestination(ctx context.Context, destination M.Socksaddr) (netip.Addr, error) { + if destination.IsIP() { + return destination.Unwrap().Addr, nil + } + if !destination.IsFqdn() { + return netip.Addr{}, E.New("destination is neither IP nor FQDN") + } + ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, destination.Fqdn) + if err != nil { + return netip.Addr{}, E.Cause(err, "resolve destination") + } + if len(ipAddrs) == 0 { + return netip.Addr{}, E.New("resolved destination is empty") + } + resolvedAddr, ok := netip.AddrFromSlice(ipAddrs[0].IP) + if !ok { + return netip.Addr{}, E.New("resolved destination is invalid") + } + return resolvedAddr.Unmap(), nil +} diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index c35c5505c9..52583fa3ea 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -50,6 +50,7 @@ type ResolvedService struct { BaseURL *url.URL UnixPath string StatusCode int + SocksPolicy *ipRulePolicy OriginRequest OriginRequestConfig } @@ -432,9 +433,14 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) OriginRequest: originRequest, }, nil case rawService == "socks-proxy": + policy, err := newIPRulePolicy(originRequest.IPRules) + if err != nil { + return ResolvedService{}, E.Cause(err, "compile socks-proxy ip rules") + } return ResolvedService{ Kind: ResolvedServiceSocksProxy, Service: rawService, + SocksPolicy: policy, OriginRequest: originRequest, }, nil case strings.HasPrefix(rawService, "unix:"): diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index c7214ec9a6..bf4833b9df 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -28,6 +28,13 @@ import ( var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") +const ( + socksReplySuccess = 0 + socksReplyRuleFailure = 2 + socksReplyHostUnreachable = 4 + socksReplyCommandNotSupported = 7 +) + func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { destination, err := resolveBastionDestination(request) if err != nil { @@ -60,7 +67,7 @@ func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWr _ = bufio.CopyConn(ctx, wsConn, targetConn) } -func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { +func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request))) if err != nil { i.logger.ErrorContext(ctx, "write socks-proxy websocket response: ", err) @@ -69,7 +76,7 @@ func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWrit wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) defer wsConn.Close() - if err := i.serveSocksProxy(ctx, wsConn); err != nil && !E.IsClosedOrCanceled(err) { + if err := i.serveSocksProxy(ctx, wsConn, service.SocksPolicy); err != nil && !E.IsClosedOrCanceled(err) { i.logger.DebugContext(ctx, "socks-proxy stream closed: ", err) } } @@ -132,7 +139,7 @@ func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (n }, nil } -func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn) error { +func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error { version := make([]byte, 1) if _, err := io.ReadFull(conn, version); err != nil { return err @@ -161,7 +168,7 @@ func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn) error { return E.New("unsupported SOCKS request version: ", requestHeader[0]) } if requestHeader[1] != 1 { - _, _ = conn.Write([]byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}) + _ = writeSocksReply(conn, socksReplyCommandNotSupported) return E.New("unsupported SOCKS command: ", requestHeader[1]) } @@ -169,19 +176,33 @@ func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn) error { if err != nil { return err } + allowed, err := policy.Allow(ctx, destination) + if err != nil { + _ = writeSocksReply(conn, socksReplyRuleFailure) + return err + } + if !allowed { + _ = writeSocksReply(conn, socksReplyRuleFailure) + return E.New("connect to ", destination, " denied by ip_rules") + } targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) if err != nil { - _, _ = conn.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) + _ = writeSocksReply(conn, socksReplyHostUnreachable) return err } defer cleanup() - if _, err := conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + if err := writeSocksReply(conn, socksReplySuccess); err != nil { return err } return bufio.CopyConn(ctx, conn, targetConn) } +func writeSocksReply(conn net.Conn, reply byte) error { + _, err := conn.Write([]byte{5, reply, 0, 1, 0, 0, 0, 0, 0, 0}) + return err +} + func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) { switch addressType { case 1: diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 0e4b2083fb..a2f4455538 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strconv" + "sync/atomic" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/ws" "github.com/sagernet/ws/wsutil" ) @@ -48,6 +50,10 @@ func (w *fakeConnectResponseWriter) WriteResponse(responseError error, metadata } func newSpecialServiceInbound(t *testing.T) *Inbound { + return newSpecialServiceInboundWithRouter(t, &testRouter{}) +} + +func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *Inbound { t.Helper() logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) if err != nil { @@ -59,19 +65,28 @@ func newSpecialServiceInbound(t *testing.T) *Inbound { } return &Inbound{ Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), - router: &testRouter{}, + router: router, logger: logFactory.NewLogger("test"), configManager: configManager, } } -func TestHandleBastionStream(t *testing.T) { +type countingRouter struct { + testRouter + count atomic.Int32 +} + +func (r *countingRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + r.count.Add(1) + r.testRouter.RouteConnectionEx(ctx, conn, metadata, onClose) +} + +func startEchoListener(t *testing.T) net.Listener { + t.Helper() listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } - defer listener.Close() - go func() { for { conn, err := listener.Accept() @@ -84,30 +99,42 @@ func TestHandleBastionStream(t *testing.T) { }(conn) } }() + return listener +} - serverSide, clientSide := net.Pipe() - defer clientSide.Close() +func newSocksProxyService(t *testing.T, rules []option.CloudflareTunnelIPRule) ResolvedService { + t.Helper() + service, err := parseResolvedService("socks-proxy", originRequestFromOption(option.CloudflareTunnelOriginRequestOptions{ + IPRules: rules, + })) + if err != nil { + t.Fatal(err) + } + return service +} - inboundInstance := newSpecialServiceInbound(t) - request := &ConnectRequest{ +func newSocksProxyConnectRequest() *ConnectRequest { + return &ConnectRequest{ Type: ConnectionTypeWebsocket, Metadata: []Metadata{ {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, - {Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()}, }, } - respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} +} +func startSocksProxyStream(t *testing.T, inboundInstance *Inbound, service ResolvedService) (net.Conn, <-chan struct{}) { + t.Helper() + serverSide, clientSide := net.Pipe() + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, newSocksProxyConnectRequest(), adapter.InboundContext{}, service) }() - select { case <-respWriter.done: case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for bastion connect response") + t.Fatal("timed out waiting for socks-proxy connect response") } if respWriter.err != nil { t.Fatal(respWriter.err) @@ -115,50 +142,49 @@ func TestHandleBastionStream(t *testing.T) { if respWriter.status != http.StatusSwitchingProtocols { t.Fatalf("expected 101 response, got %d", respWriter.status) } - if respWriter.headers.Get("Sec-WebSocket-Accept") == "" { - t.Fatal("expected websocket accept header") - } + return clientSide, done +} - if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { +func writeSocksAuth(t *testing.T, conn net.Conn) { + t.Helper() + if err := wsutil.WriteClientMessage(conn, ws.OpBinary, []byte{5, 1, 0}); err != nil { t.Fatal(err) } - data, opCode, err := wsutil.ReadServerData(clientSide) + data, _, err := wsutil.ReadServerData(conn) if err != nil { t.Fatal(err) } - if opCode != ws.OpBinary { - t.Fatalf("expected binary frame, got %v", opCode) - } - if string(data) != "hello" { - t.Fatalf("expected echoed payload, got %q", string(data)) - } - _ = clientSide.Close() - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("bastion stream did not exit") + if string(data) != string([]byte{5, 0}) { + t.Fatalf("unexpected auth response: %v", data) } } -func TestHandleSocksProxyStream(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") +func writeSocksConnectIPv4(t *testing.T, conn net.Conn, address string) []byte { + t.Helper() + host, portText, err := net.SplitHostPort(address) if err != nil { t.Fatal(err) } - defer listener.Close() + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatal(err) + } + requestBytes := []byte{5, 1, 0, 1} + requestBytes = append(requestBytes, net.ParseIP(host).To4()...) + requestBytes = append(requestBytes, byte(port>>8), byte(port)) + if err := wsutil.WriteClientMessage(conn, ws.OpBinary, requestBytes); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(conn) + if err != nil { + t.Fatal(err) + } + return data +} - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - _, _ = io.Copy(conn, conn) - }(conn) - } - }() +func TestHandleBastionStream(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() serverSide, clientSide := net.Pipe() defer clientSide.Close() @@ -168,6 +194,7 @@ func TestHandleSocksProxyStream(t *testing.T) { Type: ConnectionTypeWebsocket, Metadata: []Metadata{ {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + {Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()}, }, } respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} @@ -175,13 +202,13 @@ func TestHandleSocksProxyStream(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) }() select { case <-respWriter.done: case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for socks-proxy connect response") + t.Fatal("timed out waiting for bastion connect response") } if respWriter.err != nil { t.Fatal(respWriter.err) @@ -189,30 +216,48 @@ func TestHandleSocksProxyStream(t *testing.T) { if respWriter.status != http.StatusSwitchingProtocols { t.Fatalf("expected 101 response, got %d", respWriter.status) } + if respWriter.headers.Get("Sec-WebSocket-Accept") == "" { + t.Fatal("expected websocket accept header") + } - if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte{5, 1, 0}); err != nil { + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { t.Fatal(err) } - data, _, err := wsutil.ReadServerData(clientSide) + data, opCode, err := wsutil.ReadServerData(clientSide) if err != nil { t.Fatal(err) } - if string(data) != string([]byte{5, 0}) { - t.Fatalf("unexpected auth response: %v", data) + if opCode != ws.OpBinary { + t.Fatalf("expected binary frame, got %v", opCode) } - - host, portText, _ := net.SplitHostPort(listener.Addr().String()) - port, _ := strconv.Atoi(portText) - requestBytes := []byte{5, 1, 0, 1} - requestBytes = append(requestBytes, net.ParseIP(host).To4()...) - requestBytes = append(requestBytes, byte(port>>8), byte(port)) - if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, requestBytes); err != nil { - t.Fatal(err) + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) } - data, _, err = wsutil.ReadServerData(clientSide) - if err != nil { - t.Fatal(err) + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bastion stream did not exit") } +} + +func TestHandleSocksProxyStream(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + _, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + Prefix: "127.0.0.0/8", + Ports: []int{port}, + Allow: true, + }}) + + clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), service) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) if len(data) != 10 || data[1] != 0 { t.Fatalf("unexpected connect response: %v", data) } @@ -220,7 +265,7 @@ func TestHandleSocksProxyStream(t *testing.T) { if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { t.Fatal(err) } - data, _, err = wsutil.ReadServerData(clientSide) + data, _, err := wsutil.ReadServerData(clientSide) if err != nil { t.Fatal(err) } @@ -235,25 +280,149 @@ func TestHandleSocksProxyStream(t *testing.T) { } } -func TestHandleStreamService(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) +func TestHandleSocksProxyStreamDenyRule(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + _, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + Prefix: "127.0.0.0/8", + Ports: []int{port}, + Allow: false, + }}) + router := &countingRouter{} + clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplyRuleFailure { + t.Fatalf("unexpected deny response: %v", data) + } + if router.count.Load() != 0 { + t.Fatalf("expected no router dial, got %d", router.count.Load()) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") } +} + +func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) { + listener := startEchoListener(t) defer listener.Close() - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - _, _ = io.Copy(conn, conn) - }(conn) + _, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + Prefix: "127.0.0.0/8", + Ports: []int{port + 1}, + Allow: true, + }}) + router := &countingRouter{} + clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplyRuleFailure { + t.Fatalf("unexpected port mismatch response: %v", data) + } + if router.count.Load() != 0 { + t.Fatalf("expected no router dial, got %d", router.count.Load()) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } +} + +func TestHandleSocksProxyStreamEmptyRulesDefaultDeny(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + router := &countingRouter{} + clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), newSocksProxyService(t, nil)) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplyRuleFailure { + t.Fatalf("unexpected empty-rule response: %v", data) + } + if router.count.Load() != 0 { + t.Fatalf("expected no router dial, got %d", router.count.Load()) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } +} + +func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + _, portText, _ := net.SplitHostPort(listener.Addr().String()) + port, _ := strconv.Atoi(portText) + allowFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, + {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, + }) + denyFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, + {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, + }) + + t.Run("allow-first", func(t *testing.T) { + clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), allowFirst) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplySuccess { + t.Fatalf("unexpected allow-first response: %v", data) } - }() + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } + }) + + t.Run("deny-first", func(t *testing.T) { + router := &countingRouter{} + clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), denyFirst) + defer clientSide.Close() + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplyRuleFailure { + t.Fatalf("unexpected deny-first response: %v", data) + } + if router.count.Load() != 0 { + t.Fatalf("expected no router dial, got %d", router.count.Load()) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks-proxy stream did not exit") + } + }) +} + +func TestHandleStreamService(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() serverSide, clientSide := net.Pipe() defer clientSide.Close() From af2afc529b70514bec0a74a8e821e8cd23d303fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 16:26:50 +0800 Subject: [PATCH 28/41] cloudflare: require remote-managed tunnels --- option/cloudflare_tunnel.go | 68 ++------- protocol/cloudflare/config_decode_test.go | 28 ++++ protocol/cloudflare/connection_http2.go | 17 +++ protocol/cloudflare/connection_quic.go | 10 ++ protocol/cloudflare/control.go | 7 + protocol/cloudflare/credentials_test.go | 51 ------- protocol/cloudflare/flow_limiter_test.go | 2 +- protocol/cloudflare/helpers_test.go | 2 +- protocol/cloudflare/inbound.go | 99 +++++++------ protocol/cloudflare/ingress_test.go | 15 +- protocol/cloudflare/runtime_config.go | 152 ++------------------ protocol/cloudflare/special_service_test.go | 18 ++- 12 files changed, 154 insertions(+), 315 deletions(-) create mode 100644 protocol/cloudflare/config_decode_test.go diff --git a/option/cloudflare_tunnel.go b/option/cloudflare_tunnel.go index bf388a044e..74b511eefe 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflare_tunnel.go @@ -3,64 +3,12 @@ package option import "github.com/sagernet/sing/common/json/badoption" type CloudflareTunnelInboundOptions struct { - Token string `json:"token,omitempty"` - CredentialPath string `json:"credential_path,omitempty"` - HAConnections int `json:"ha_connections,omitempty"` - Protocol string `json:"protocol,omitempty"` - ControlDialer DialerOptions `json:"control_dialer,omitempty"` - EdgeIPVersion int `json:"edge_ip_version,omitempty"` - DatagramVersion string `json:"datagram_version,omitempty"` - GracePeriod badoption.Duration `json:"grace_period,omitempty"` - Region string `json:"region,omitempty"` - Ingress []CloudflareTunnelIngressRule `json:"ingress,omitempty"` - OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` - WarpRouting CloudflareTunnelWarpRoutingOptions `json:"warp_routing,omitempty"` -} - -type CloudflareTunnelIngressRule struct { - Hostname string `json:"hostname,omitempty"` - Path string `json:"path,omitempty"` - Service string `json:"service,omitempty"` - OriginRequest CloudflareTunnelOriginRequestOptions `json:"origin_request,omitempty"` -} - -type CloudflareTunnelOriginRequestOptions struct { - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - TLSTimeout badoption.Duration `json:"tls_timeout,omitempty"` - TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` - NoHappyEyeballs bool `json:"no_happy_eyeballs,omitempty"` - KeepAliveTimeout badoption.Duration `json:"keep_alive_timeout,omitempty"` - KeepAliveConnections int `json:"keep_alive_connections,omitempty"` - HTTPHostHeader string `json:"http_host_header,omitempty"` - OriginServerName string `json:"origin_server_name,omitempty"` - MatchSNIToHost bool `json:"match_sni_to_host,omitempty"` - CAPool string `json:"ca_pool,omitempty"` - NoTLSVerify bool `json:"no_tls_verify,omitempty"` - DisableChunkedEncoding bool `json:"disable_chunked_encoding,omitempty"` - BastionMode bool `json:"bastion_mode,omitempty"` - ProxyAddress string `json:"proxy_address,omitempty"` - ProxyPort uint `json:"proxy_port,omitempty"` - ProxyType string `json:"proxy_type,omitempty"` - IPRules []CloudflareTunnelIPRule `json:"ip_rules,omitempty"` - HTTP2Origin bool `json:"http2_origin,omitempty"` - Access CloudflareTunnelAccessRule `json:"access,omitempty"` -} - -type CloudflareTunnelAccessRule struct { - Required bool `json:"required,omitempty"` - TeamName string `json:"team_name,omitempty"` - AudTag []string `json:"aud_tag,omitempty"` - Environment string `json:"environment,omitempty"` -} - -type CloudflareTunnelIPRule struct { - Prefix string `json:"prefix,omitempty"` - Ports []int `json:"ports,omitempty"` - Allow bool `json:"allow,omitempty"` -} - -type CloudflareTunnelWarpRoutingOptions struct { - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - MaxActiveFlows uint64 `json:"max_active_flows,omitempty"` - TCPKeepAlive badoption.Duration `json:"tcp_keep_alive,omitempty"` + Token string `json:"token,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + ControlDialer DialerOptions `json:"control_dialer,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod badoption.Duration `json:"grace_period,omitempty"` + Region string `json:"region,omitempty"` } diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go new file mode 100644 index 0000000000..588e0355ef --- /dev/null +++ b/protocol/cloudflare/config_decode_test.go @@ -0,0 +1,28 @@ +//go:build with_cloudflare_tunnel + +package cloudflare + +import ( + "context" + "testing" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +func TestNewInboundRequiresToken(t *testing.T) { + _, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflareTunnelInboundOptions{}) + if err == nil { + t.Fatal("expected missing token error") + } +} + +func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) { + err := validateRegistrationResult(&RegistrationResult{TunnelIsRemotelyManaged: false}) + if err == nil { + t.Fatal("expected unsupported tunnel error") + } + if err != ErrNonRemoteManagedTunnelUnsupported { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 9fed72cf80..24ddadd6c1 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -44,6 +44,7 @@ type HTTP2Connection struct { numPreviousAttempts uint8 registrationClient *RegistrationClient registrationResult *RegistrationResult + controlStreamErr error activeRequests sync.WaitGroup closeOnce sync.Once @@ -113,6 +114,9 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { Handler: c, }) + if c.controlStreamErr != nil { + return c.controlStreamErr + } if ctx.Err() != nil { return ctx.Err() } @@ -161,10 +165,23 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options, ) if err != nil { + c.controlStreamErr = err + c.logger.Error("register connection: ", err) + if c.registrationClient != nil { + c.registrationClient.Close() + } + go c.close() + return + } + if err := validateRegistrationResult(result); err != nil { + c.controlStreamErr = err c.logger.Error("register connection: ", err) + c.registrationClient.Close() + go c.close() return } c.registrationResult = result + c.inbound.notifyConnected(c.connIndex) c.logger.Info("connected to ", result.Location, " (connection ", result.ConnectionID, ")") diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index fd2e56a3bb..2a02a06d0e 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -50,6 +50,7 @@ type QUICConnection struct { gracePeriod time.Duration registrationClient *RegistrationClient registrationResult *RegistrationResult + onConnected func() closeOnce sync.Once } @@ -90,6 +91,7 @@ func NewQUICConnection( numPreviousAttempts uint8, gracePeriod time.Duration, controlDialer N.Dialer, + onConnected func(), logger log.ContextLogger, ) (*QUICConnection, error) { rootCAs, err := cloudflareRootCertPool() @@ -134,6 +136,7 @@ func NewQUICConnection( features: features, numPreviousAttempts: numPreviousAttempts, gracePeriod: gracePeriod, + onConnected: onConnected, }, nil } @@ -170,6 +173,7 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error err = q.register(ctx, controlStream) if err != nil { controlStream.Close() + q.Close() return err } @@ -208,7 +212,13 @@ func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) erro if err != nil { return E.Cause(err, "register connection") } + if err := validateRegistrationResult(result); err != nil { + return err + } q.registrationResult = result + if q.onConnected != nil { + q.onConnected() + } return nil } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index 6e98811142..f72d627f02 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -150,6 +150,13 @@ func (c *RegistrationClient) Close() error { ) } +func validateRegistrationResult(result *RegistrationResult) error { + if result == nil || result.TunnelIsRemotelyManaged { + return nil + } + return ErrNonRemoteManagedTunnelUnsupported +} + // BuildConnectionOptions creates the ConnectionOptions to send during registration. func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions { return &RegistrationConnectionOptions{ diff --git a/protocol/cloudflare/credentials_test.go b/protocol/cloudflare/credentials_test.go index 31759aa34c..506d8601a5 100644 --- a/protocol/cloudflare/credentials_test.go +++ b/protocol/cloudflare/credentials_test.go @@ -4,8 +4,6 @@ package cloudflare import ( "encoding/base64" - "os" - "path/filepath" "testing" "github.com/google/uuid" @@ -43,52 +41,3 @@ func TestParseTokenInvalidJSON(t *testing.T) { t.Fatal("expected error for invalid JSON") } } - -func TestParseCredentialFile(t *testing.T) { - tunnelID := uuid.New() - content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"` + tunnelID.String() + `"}` - path := filepath.Join(t.TempDir(), "creds.json") - err := os.WriteFile(path, []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - credentials, err := parseCredentialFile(path) - if err != nil { - t.Fatal("parseCredentialFile: ", err) - } - if credentials.AccountTag != "acct" { - t.Error("expected AccountTag acct, got ", credentials.AccountTag) - } - if credentials.TunnelID != tunnelID { - t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID) - } -} - -func TestParseCredentialFileMissingTunnelID(t *testing.T) { - content := `{"AccountTag":"acct","TunnelSecret":"c2VjcmV0","TunnelID":"00000000-0000-0000-0000-000000000000"}` - path := filepath.Join(t.TempDir(), "creds.json") - err := os.WriteFile(path, []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - _, err = parseCredentialFile(path) - if err == nil { - t.Fatal("expected error for missing tunnel ID") - } -} - -func TestParseCredentialsBothSpecified(t *testing.T) { - _, err := parseCredentials("sometoken", "/some/path") - if err == nil { - t.Fatal("expected error when both specified") - } -} - -func TestParseCredentialsNoneSpecified(t *testing.T) { - _, err := parseCredentials("", "") - if err == nil { - t.Fatal("expected error when none specified") - } -} diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index ad27c534b8..b8e69aeeb7 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -23,7 +23,7 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound { if err != nil { t.Fatal(err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index d873a40b11..8dbec9c7ad 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -170,7 +170,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i t.Fatal("create logger: ", err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal("create config manager: ", err) } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index fd6a0f0b58..038a6a9092 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -6,12 +6,12 @@ import ( "context" stdTLS "crypto/tls" "encoding/base64" + "errors" "io" "math/rand" "net" "net/http" "net/url" - "os" "sync" "time" @@ -33,6 +33,8 @@ func RegisterInbound(registry *inbound.Registry) { inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound) } +var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflare tunnel only supports remote-managed tunnels") + type Inbound struct { inbound.Adapter ctx context.Context @@ -63,12 +65,19 @@ type Inbound struct { helloWorldAccess sync.Mutex helloWorldServer *http.Server helloWorldURL *url.URL + + connectedAccess sync.Mutex + connectedIndices map[uint8]struct{} + connectedNotify chan uint8 } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { - credentials, err := parseCredentials(options.Token, options.CredentialPath) + if options.Token == "" { + return nil, E.New("missing token") + } + credentials, err := parseToken(options.Token) if err != nil { - return nil, E.Cause(err, "parse credentials") + return nil, E.Cause(err, "parse token") } haConnections := options.HAConnections @@ -96,7 +105,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo gracePeriod = 30 * time.Second } - configManager, err := NewConfigManager(options) + configManager, err := NewConfigManager() if err != nil { return nil, E.Cause(err, "build cloudflare tunnel runtime config") } @@ -139,6 +148,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo controlDialer: controlDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), }, nil } @@ -164,24 +175,36 @@ func (i *Inbound) Start(stage adapter.StartStage) error { for connIndex := 0; connIndex < i.haConnections; connIndex++ { i.done.Add(1) go i.superviseConnection(uint8(connIndex), edgeAddrs, features) - if connIndex == 0 { - // Wait a bit for the first connection before starting others - select { - case <-time.After(time.Second): - case <-i.ctx.Done(): - return i.ctx.Err() + select { + case readyConnIndex := <-i.connectedNotify: + if readyConnIndex != uint8(connIndex) { + i.logger.Debug("received unexpected ready notification for connection ", readyConnIndex) } - } else { - select { - case <-time.After(time.Second): - case <-i.ctx.Done(): - return nil + case <-time.After(firstConnectionReadyTimeout): + case <-i.ctx.Done(): + if connIndex == 0 { + return i.ctx.Err() } + return nil } } return nil } +func (i *Inbound) notifyConnected(connIndex uint8) { + if i.connectedNotify == nil { + return + } + i.connectedAccess.Lock() + if _, loaded := i.connectedIndices[connIndex]; loaded { + i.connectedAccess.Unlock() + return + } + i.connectedIndices[connIndex] = struct{}{} + i.connectedAccess.Unlock() + i.connectedNotify <- connIndex +} + func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { result := i.configManager.Apply(version, config) if result.Err != nil { @@ -249,8 +272,9 @@ func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { } const ( - backoffBaseTime = time.Second - backoffMaxTime = 2 * time.Minute + backoffBaseTime = time.Second + backoffMaxTime = 2 * time.Minute + firstConnectionReadyTimeout = 15 * time.Second ) func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { @@ -269,6 +293,11 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe if err == nil || i.ctx.Err() != nil { return } + if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { + i.logger.Error("connection ", connIndex, " failed permanently: ", err) + i.cancel() + return + } retries++ backoff := backoffDuration(retries) @@ -294,6 +323,9 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features if err == nil || i.ctx.Err() != nil { return err } + if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { + return err + } i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err) return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) case "http2": @@ -309,7 +341,9 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, - features, numPreviousAttempts, i.gracePeriod, i.controlDialer, i.logger, + features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() { + i.notifyConnected(connIndex) + }, i.logger, ) if err != nil { return E.Cause(err, "create QUIC connection") @@ -377,19 +411,6 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { return result } -func parseCredentials(token string, credentialPath string) (Credentials, error) { - if token == "" && credentialPath == "" { - return Credentials{}, E.New("either token or credential_path must be specified") - } - if token != "" && credentialPath != "" { - return Credentials{}, E.New("token and credential_path are mutually exclusive") - } - if token != "" { - return parseToken(token) - } - return parseCredentialFile(credentialPath) -} - func parseToken(token string) (Credentials, error) { data, err := base64.StdEncoding.DecodeString(token) if err != nil { @@ -402,19 +423,3 @@ func parseToken(token string) (Credentials, error) { } return tunnelToken.ToCredentials(), nil } - -func parseCredentialFile(path string) (Credentials, error) { - data, err := os.ReadFile(path) - if err != nil { - return Credentials{}, E.Cause(err, "read credential file") - } - var credentials Credentials - err = json.Unmarshal(data, &credentials) - if err != nil { - return Credentials{}, E.Cause(err, "unmarshal credential file") - } - if credentials.TunnelID == (uuid.UUID{}) { - return Credentials{}, E.New("credential file missing tunnel ID") - } - return credentials, nil -} diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index e4432cf321..f73db11a19 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -6,12 +6,11 @@ import ( "testing" "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" ) func newTestIngressInbound(t *testing.T) *Inbound { t.Helper() - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } @@ -85,6 +84,18 @@ func TestApplyConfigInvalidJSON(t *testing.T) { } } +func TestDefaultConfigIsCatchAll503(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + + service, loaded := inboundInstance.configManager.Resolve("any.example.com", "/") + if !loaded { + t.Fatal("expected default config to resolve catch-all rule") + } + if service.StatusCode != 503 { + t.Fatalf("expected catch-all 503, got %#v", service) + } +} + func TestResolveExactAndWildcard(t *testing.T) { inboundInstance := newTestIngressInbound(t) inboundInstance.configManager.activeConfig = RuntimeConfig{ diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 52583fa3ea..d46e1062fc 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -155,8 +154,8 @@ type ConfigManager struct { activeConfig RuntimeConfig } -func NewConfigManager(options option.CloudflareTunnelInboundOptions) (*ConfigManager, error) { - config, err := buildLocalRuntimeConfig(options) +func NewConfigManager() (*ConfigManager, error) { + config, err := defaultRuntimeConfig() if err != nil { return nil, err } @@ -237,26 +236,19 @@ func matchIngressHost(pattern, hostname string) bool { return false } -func buildLocalRuntimeConfig(options option.CloudflareTunnelInboundOptions) (RuntimeConfig, error) { - defaultOriginRequest := originRequestFromOption(options.OriginRequest) - warpRouting := warpRoutingFromOption(options.WarpRouting) - var ingressRules []localIngressRule - for _, rule := range options.Ingress { - ingressRules = append(ingressRules, localIngressRule{ - Hostname: rule.Hostname, - Path: rule.Path, - Service: rule.Service, - OriginRequest: mergeOptionOriginRequest(defaultOriginRequest, rule.OriginRequest), - }) - } - compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules) +func defaultRuntimeConfig() (RuntimeConfig, error) { + defaultOriginRequest := defaultOriginRequestConfig() + compiledRules, err := compileIngressRules(defaultOriginRequest, nil) if err != nil { return RuntimeConfig{}, err } return RuntimeConfig{ Ingress: compiledRules, OriginRequest: defaultOriginRequest, - WarpRouting: warpRouting, + WarpRouting: WarpRoutingConfig{ + ConnectTimeout: defaultWarpRoutingConnectTime, + TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, + }, }, nil } @@ -554,117 +546,6 @@ func defaultOriginRequestConfig() OriginRequestConfig { } } -func originRequestFromOption(input option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { - config := defaultOriginRequestConfig() - if input.ConnectTimeout != 0 { - config.ConnectTimeout = time.Duration(input.ConnectTimeout) - } - if input.TLSTimeout != 0 { - config.TLSTimeout = time.Duration(input.TLSTimeout) - } - if input.TCPKeepAlive != 0 { - config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) - } - if input.KeepAliveTimeout != 0 { - config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) - } - if input.KeepAliveConnections != 0 { - config.KeepAliveConnections = input.KeepAliveConnections - } - config.NoHappyEyeballs = input.NoHappyEyeballs - config.HTTPHostHeader = input.HTTPHostHeader - config.OriginServerName = input.OriginServerName - config.MatchSNIToHost = input.MatchSNIToHost - config.CAPool = input.CAPool - config.NoTLSVerify = input.NoTLSVerify - config.DisableChunkedEncoding = input.DisableChunkedEncoding - config.BastionMode = input.BastionMode - if input.ProxyAddress != "" { - config.ProxyAddress = input.ProxyAddress - } - if input.ProxyPort != 0 { - config.ProxyPort = input.ProxyPort - } - config.ProxyType = input.ProxyType - config.HTTP2Origin = input.HTTP2Origin - config.Access = AccessConfig{ - Required: input.Access.Required, - TeamName: input.Access.TeamName, - AudTag: append([]string(nil), input.Access.AudTag...), - Environment: input.Access.Environment, - } - for _, rule := range input.IPRules { - config.IPRules = append(config.IPRules, IPRule{ - Prefix: rule.Prefix, - Ports: append([]int(nil), rule.Ports...), - Allow: rule.Allow, - }) - } - return config -} - -func mergeOptionOriginRequest(base OriginRequestConfig, override option.CloudflareTunnelOriginRequestOptions) OriginRequestConfig { - result := base - if override.ConnectTimeout != 0 { - result.ConnectTimeout = time.Duration(override.ConnectTimeout) - } - if override.TLSTimeout != 0 { - result.TLSTimeout = time.Duration(override.TLSTimeout) - } - if override.TCPKeepAlive != 0 { - result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) - } - if override.KeepAliveTimeout != 0 { - result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) - } - if override.KeepAliveConnections != 0 { - result.KeepAliveConnections = override.KeepAliveConnections - } - result.NoHappyEyeballs = override.NoHappyEyeballs - if override.HTTPHostHeader != "" { - result.HTTPHostHeader = override.HTTPHostHeader - } - if override.OriginServerName != "" { - result.OriginServerName = override.OriginServerName - } - result.MatchSNIToHost = override.MatchSNIToHost - if override.CAPool != "" { - result.CAPool = override.CAPool - } - result.NoTLSVerify = override.NoTLSVerify - result.DisableChunkedEncoding = override.DisableChunkedEncoding - result.BastionMode = override.BastionMode - if override.ProxyAddress != "" { - result.ProxyAddress = override.ProxyAddress - } - if override.ProxyPort != 0 { - result.ProxyPort = override.ProxyPort - } - if override.ProxyType != "" { - result.ProxyType = override.ProxyType - } - if len(override.IPRules) > 0 { - result.IPRules = nil - for _, rule := range override.IPRules { - result.IPRules = append(result.IPRules, IPRule{ - Prefix: rule.Prefix, - Ports: append([]int(nil), rule.Ports...), - Allow: rule.Allow, - }) - } - } - result.HTTP2Origin = override.HTTP2Origin - if override.Access.Required || override.Access.TeamName != "" || len(override.Access.AudTag) > 0 || override.Access.Environment != "" { - result.Access = AccessConfig{ - Required: override.Access.Required, - TeamName: override.Access.TeamName, - AudTag: append([]string(nil), override.Access.AudTag...), - Environment: override.Access.Environment, - } - } - return result -} - func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig { config := defaultOriginRequestConfig() if input.ConnectTimeout != 0 { @@ -802,21 +683,6 @@ func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginReq return result } -func warpRoutingFromOption(input option.CloudflareTunnelWarpRoutingOptions) WarpRoutingConfig { - config := WarpRoutingConfig{ - ConnectTimeout: defaultWarpRoutingConnectTime, - TCPKeepAlive: defaultWarpRoutingTCPKeepAlive, - MaxActiveFlows: input.MaxActiveFlows, - } - if input.ConnectTimeout != 0 { - config.ConnectTimeout = time.Duration(input.ConnectTimeout) - } - if input.TCPKeepAlive != 0 { - config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) - } - return config -} - func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig { config := WarpRoutingConfig{ ConnectTimeout: defaultWarpRoutingConnectTime, diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index a2f4455538..9b29c0a0e6 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -59,7 +59,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In if err != nil { t.Fatal(err) } - configManager, err := NewConfigManager(option.CloudflareTunnelInboundOptions{}) + configManager, err := NewConfigManager() if err != nil { t.Fatal(err) } @@ -102,11 +102,9 @@ func startEchoListener(t *testing.T) net.Listener { return listener } -func newSocksProxyService(t *testing.T, rules []option.CloudflareTunnelIPRule) ResolvedService { +func newSocksProxyService(t *testing.T, rules []IPRule) ResolvedService { t.Helper() - service, err := parseResolvedService("socks-proxy", originRequestFromOption(option.CloudflareTunnelOriginRequestOptions{ - IPRules: rules, - })) + service, err := parseResolvedService("socks-proxy", OriginRequestConfig{IPRules: rules}) if err != nil { t.Fatal(err) } @@ -247,7 +245,7 @@ func TestHandleSocksProxyStream(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true, @@ -286,7 +284,7 @@ func TestHandleSocksProxyStreamDenyRule(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: false, @@ -317,7 +315,7 @@ func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - service := newSocksProxyService(t, []option.CloudflareTunnelIPRule{{ + service := newSocksProxyService(t, []IPRule{{ Prefix: "127.0.0.0/8", Ports: []int{port + 1}, Allow: true, @@ -372,11 +370,11 @@ func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) { _, portText, _ := net.SplitHostPort(listener.Addr().String()) port, _ := strconv.Atoi(portText) - allowFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + allowFirst := newSocksProxyService(t, []IPRule{ {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, }) - denyFirst := newSocksProxyService(t, []option.CloudflareTunnelIPRule{ + denyFirst := newSocksProxyService(t, []IPRule{ {Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false}, {Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true}, }) From 2cf2ff3f330ffd119c6a4bc9c0a89cc6ba05f894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 16:40:21 +0800 Subject: [PATCH 29/41] Rename cloudflare-tunnel type to cloudflared --- constant/proxy.go | 6 +++--- include/cloudflare_tunnel_stub.go | 20 ------------------- .../{cloudflare_tunnel.go => cloudflared.go} | 4 ++-- include/cloudflared_stub.go | 20 +++++++++++++++++++ include/registry.go | 2 +- .../{cloudflare_tunnel.go => cloudflared.go} | 2 +- protocol/cloudflare/access.go | 5 +++-- protocol/cloudflare/access_test.go | 4 ++-- protocol/cloudflare/config_decode_test.go | 4 ++-- protocol/cloudflare/connection_http2.go | 2 +- protocol/cloudflare/connection_quic.go | 2 +- protocol/cloudflare/connection_quic_test.go | 2 +- protocol/cloudflare/control.go | 2 +- protocol/cloudflare/credentials.go | 2 +- protocol/cloudflare/credentials_test.go | 2 +- protocol/cloudflare/datagram_v2.go | 2 +- protocol/cloudflare/datagram_v3.go | 2 +- protocol/cloudflare/datagram_v3_test.go | 6 +++--- protocol/cloudflare/direct_origin_test.go | 2 +- protocol/cloudflare/dispatch.go | 2 +- protocol/cloudflare/dispatch_test.go | 2 +- protocol/cloudflare/edge_discovery.go | 2 +- protocol/cloudflare/edge_discovery_test.go | 2 +- protocol/cloudflare/flow_limiter.go | 2 +- protocol/cloudflare/flow_limiter_test.go | 4 ++-- protocol/cloudflare/header.go | 2 +- protocol/cloudflare/helpers_test.go | 4 ++-- protocol/cloudflare/icmp.go | 2 +- protocol/cloudflare/icmp_test.go | 10 +++++----- protocol/cloudflare/inbound.go | 14 ++++++------- protocol/cloudflare/ingress_test.go | 2 +- protocol/cloudflare/integration_test.go | 2 +- protocol/cloudflare/ip_rule_policy.go | 2 +- protocol/cloudflare/origin_request_test.go | 2 +- protocol/cloudflare/root_ca.go | 2 +- protocol/cloudflare/runtime_config.go | 2 +- protocol/cloudflare/special_service.go | 2 +- protocol/cloudflare/special_service_test.go | 4 ++-- protocol/cloudflare/stream.go | 2 +- protocol/cloudflare/stream_test.go | 2 +- release/DEFAULT_BUILD_TAGS | 2 +- release/DEFAULT_BUILD_TAGS_OTHERS | 2 +- release/DEFAULT_BUILD_TAGS_WINDOWS | 2 +- 43 files changed, 83 insertions(+), 82 deletions(-) delete mode 100644 include/cloudflare_tunnel_stub.go rename include/{cloudflare_tunnel.go => cloudflared.go} (62%) create mode 100644 include/cloudflared_stub.go rename option/{cloudflare_tunnel.go => cloudflared.go} (93%) diff --git a/constant/proxy.go b/constant/proxy.go index 91b3bc98e9..ffec80250b 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -25,7 +25,7 @@ const ( TypeTUIC = "tuic" TypeHysteria2 = "hysteria2" TypeTailscale = "tailscale" - TypeCloudflareTunnel = "cloudflare-tunnel" + TypeCloudflared = "cloudflared" TypeDERP = "derp" TypeResolved = "resolved" TypeSSMAPI = "ssm-api" @@ -91,8 +91,8 @@ func ProxyDisplayName(proxyType string) string { return "AnyTLS" case TypeTailscale: return "Tailscale" - case TypeCloudflareTunnel: - return "Cloudflare Tunnel" + case TypeCloudflared: + return "Cloudflared" case TypeSelector: return "Selector" case TypeURLTest: diff --git a/include/cloudflare_tunnel_stub.go b/include/cloudflare_tunnel_stub.go deleted file mode 100644 index 65c676ab0c..0000000000 --- a/include/cloudflare_tunnel_stub.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !with_cloudflare_tunnel - -package include - -import ( - "context" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/adapter/inbound" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" -) - -func registerCloudflareTunnelInbound(registry *inbound.Registry) { - inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { - return nil, E.New(`Cloudflare Tunnel is not included in this build, rebuild with -tags with_cloudflare_tunnel`) - }) -} diff --git a/include/cloudflare_tunnel.go b/include/cloudflared.go similarity index 62% rename from include/cloudflare_tunnel.go rename to include/cloudflared.go index 80273a313a..6320010825 100644 --- a/include/cloudflare_tunnel.go +++ b/include/cloudflared.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package include @@ -7,6 +7,6 @@ import ( "github.com/sagernet/sing-box/protocol/cloudflare" ) -func registerCloudflareTunnelInbound(registry *inbound.Registry) { +func registerCloudflaredInbound(registry *inbound.Registry) { cloudflare.RegisterInbound(registry) } diff --git a/include/cloudflared_stub.go b/include/cloudflared_stub.go new file mode 100644 index 0000000000..8f49aecc69 --- /dev/null +++ b/include/cloudflared_stub.go @@ -0,0 +1,20 @@ +//go:build !with_cloudflared + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerCloudflaredInbound(registry *inbound.Registry) { + inbound.Register[option.CloudflaredInboundOptions](registry, C.TypeCloudflared, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) { + return nil, E.New(`Cloudflared is not included in this build, rebuild with -tags with_cloudflared`) + }) +} diff --git a/include/registry.go b/include/registry.go index 4cecfcda0b..5a1a2f973a 100644 --- a/include/registry.go +++ b/include/registry.go @@ -66,7 +66,7 @@ func InboundRegistry() *inbound.Registry { anytls.RegisterInbound(registry) registerQUICInbounds(registry) - registerCloudflareTunnelInbound(registry) + registerCloudflaredInbound(registry) registerStubForRemovedInbounds(registry) return registry diff --git a/option/cloudflare_tunnel.go b/option/cloudflared.go similarity index 93% rename from option/cloudflare_tunnel.go rename to option/cloudflared.go index 74b511eefe..e597ebb77e 100644 --- a/option/cloudflare_tunnel.go +++ b/option/cloudflared.go @@ -2,7 +2,7 @@ package option import "github.com/sagernet/sing/common/json/badoption" -type CloudflareTunnelInboundOptions struct { +type CloudflaredInboundOptions struct { Token string `json:"token,omitempty"` HAConnections int `json:"ha_connections,omitempty"` Protocol string `json:"protocol,omitempty"` diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go index 75c1e8ada4..fc40e72331 100644 --- a/protocol/cloudflare/access.go +++ b/protocol/cloudflare/access.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -10,10 +10,11 @@ import ( "strings" "sync" - "github.com/coreos/go-oidc/v3/oidc" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + + "github.com/coreos/go-oidc/v3/oidc" ) const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion" diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 8c7d2b9e10..3cceb155e7 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -32,7 +32,7 @@ func newAccessTestInbound(t *testing.T) *Inbound { t.Fatal(err) } return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), logger: logFactory.NewLogger("test"), accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, router: &testRouter{}, diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go index 588e0355ef..4addd1f99b 100644 --- a/protocol/cloudflare/config_decode_test.go +++ b/protocol/cloudflare/config_decode_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -11,7 +11,7 @@ import ( ) func TestNewInboundRequiresToken(t *testing.T) { - _, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflareTunnelInboundOptions{}) + _, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflaredInboundOptions{}) if err == nil { t.Fatal("expected missing token error") } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 24ddadd6c1..ef7f460359 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 2a02a06d0e..c674828e30 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/connection_quic_test.go b/protocol/cloudflare/connection_quic_test.go index 7ea4a86906..ac7f58aba6 100644 --- a/protocol/cloudflare/connection_quic_test.go +++ b/protocol/cloudflare/connection_quic_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index f72d627f02..80d2cbc636 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/credentials.go b/protocol/cloudflare/credentials.go index 443da061c0..0b11d5ee73 100644 --- a/protocol/cloudflare/credentials.go +++ b/protocol/cloudflare/credentials.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/credentials_test.go b/protocol/cloudflare/credentials_test.go index 506d8601a5..1c3d7fd6be 100644 --- a/protocol/cloudflare/credentials_test.go +++ b/protocol/cloudflare/credentials_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index a17e7c5cc3..318ce5b0bc 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index ea23ed21c8..42e758584f 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index 5703310c0a..08d8704ecc 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -14,7 +14,7 @@ import ( func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { sender := &captureDatagramSender{} inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), flowLimiter: &FlowLimiter{}, } muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) @@ -40,7 +40,7 @@ func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) { sender := &captureDatagramSender{} inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), flowLimiter: &FlowLimiter{}, } muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) diff --git a/protocol/cloudflare/direct_origin_test.go b/protocol/cloudflare/direct_origin_test.go index f38c96e226..85e1b9d8c7 100644 --- a/protocol/cloudflare/direct_origin_test.go +++ b/protocol/cloudflare/direct_origin_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index e8089858c6..8d447f6ac0 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/dispatch_test.go b/protocol/cloudflare/dispatch_test.go index e4645cbd38..3afeb3ad97 100644 --- a/protocol/cloudflare/dispatch_test.go +++ b/protocol/cloudflare/dispatch_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/edge_discovery.go b/protocol/cloudflare/edge_discovery.go index 922063ce43..b8555fa164 100644 --- a/protocol/cloudflare/edge_discovery.go +++ b/protocol/cloudflare/edge_discovery.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index 930fd46be2..28dda352ec 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/flow_limiter.go b/protocol/cloudflare/flow_limiter.go index cfe753f6b7..b26b619b24 100644 --- a/protocol/cloudflare/flow_limiter.go +++ b/protocol/cloudflare/flow_limiter.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index b8e69aeeb7..12eaa6dba6 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -31,7 +31,7 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound { config.WarpRouting.MaxActiveFlows = limit configManager.activeConfig = config return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: &testRouter{}, logger: logFactory.NewLogger("test"), configManager: configManager, diff --git a/protocol/cloudflare/header.go b/protocol/cloudflare/header.go index 05aa3765df..3a40d95812 100644 --- a/protocol/cloudflare/header.go +++ b/protocol/cloudflare/header.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 8dbec9c7ad..4fd13ef77d 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -177,7 +177,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i ctx, cancel := context.WithCancel(context.Background()) inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), ctx: ctx, cancel: cancel, router: &testRouter{}, diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go index 8e000c0db8..1070a2d835 100644 --- a/protocol/cloudflare/icmp.go +++ b/protocol/cloudflare/icmp.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go index 6f985050f2..9557fa16f6 100644 --- a/protocol/cloudflare/icmp_test.go +++ b/protocol/cloudflare/icmp_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -71,7 +71,7 @@ func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) { }, } inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: router, } sender := &captureDatagramSender{} @@ -117,7 +117,7 @@ func TestICMPBridgeHandleV2TracedReply(t *testing.T) { }, } inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: router, } bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) @@ -151,7 +151,7 @@ func TestICMPBridgeHandleV3Reply(t *testing.T) { }, } inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: router, } bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3) @@ -178,7 +178,7 @@ func TestICMPBridgeDropsNonEcho(t *testing.T) { }, } inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: router, } sender := &captureDatagramSender{} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 038a6a9092..c794df3b51 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -30,10 +30,10 @@ import ( ) func RegisterInbound(registry *inbound.Registry) { - inbound.Register[option.CloudflareTunnelInboundOptions](registry, C.TypeCloudflareTunnel, NewInbound) + inbound.Register[option.CloudflaredInboundOptions](registry, C.TypeCloudflared, NewInbound) } -var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflare tunnel only supports remote-managed tunnels") +var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflared only supports remote-managed tunnels") type Inbound struct { inbound.Adapter @@ -71,7 +71,7 @@ type Inbound struct { connectedNotify chan uint8 } -func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflareTunnelInboundOptions) (adapter.Inbound, error) { +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) { if options.Token == "" { return nil, E.New("missing token") } @@ -107,7 +107,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo configManager, err := NewConfigManager() if err != nil { - return nil, E.Cause(err, "build cloudflare tunnel runtime config") + return nil, E.Cause(err, "build cloudflared runtime config") } controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{ Context: ctx, @@ -115,7 +115,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo RemoteIsDomain: true, }) if err != nil { - return nil, E.Cause(err, "build cloudflare tunnel control dialer") + return nil, E.Cause(err, "build cloudflared control dialer") } region := options.Region @@ -129,7 +129,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, tag), + Adapter: inbound.NewAdapter(C.TypeCloudflared, tag), ctx: inboundCtx, cancel: cancel, router: router, diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index f73db11a19..a742e794de 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/integration_test.go b/protocol/cloudflare/integration_test.go index d1ca5799a6..8d19000489 100644 --- a/protocol/cloudflare/integration_test.go +++ b/protocol/cloudflare/integration_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/ip_rule_policy.go b/protocol/cloudflare/ip_rule_policy.go index d2526306dd..191d7a1475 100644 --- a/protocol/cloudflare/ip_rule_policy.go +++ b/protocol/cloudflare/ip_rule_policy.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index d8a6716ab4..72aa3aeabf 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/root_ca.go b/protocol/cloudflare/root_ca.go index bfca9a4c54..6436514c8d 100644 --- a/protocol/cloudflare/root_ca.go +++ b/protocol/cloudflare/root_ca.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index d46e1062fc..5c5ae88243 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index bf4833b9df..e60edfa0d1 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 9b29c0a0e6..35d8245fc1 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare @@ -64,7 +64,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In t.Fatal(err) } return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflareTunnel, "test"), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), router: router, logger: logFactory.NewLogger("test"), configManager: configManager, diff --git a/protocol/cloudflare/stream.go b/protocol/cloudflare/stream.go index 0cd92d30e4..e9dcfa55f9 100644 --- a/protocol/cloudflare/stream.go +++ b/protocol/cloudflare/stream.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/protocol/cloudflare/stream_test.go b/protocol/cloudflare/stream_test.go index 56e60e85b1..78378ce86f 100644 --- a/protocol/cloudflare/stream_test.go +++ b/protocol/cloudflare/stream_test.go @@ -1,4 +1,4 @@ -//go:build with_cloudflare_tunnel +//go:build with_cloudflared package cloudflare diff --git a/release/DEFAULT_BUILD_TAGS b/release/DEFAULT_BUILD_TAGS index cc2c039d8e..e06bc120e0 100644 --- a/release/DEFAULT_BUILD_TAGS +++ b/release/DEFAULT_BUILD_TAGS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflared,with_naive_outbound,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/DEFAULT_BUILD_TAGS_OTHERS b/release/DEFAULT_BUILD_TAGS_OTHERS index 7100c5ad58..a28e900e9d 100644 --- a/release/DEFAULT_BUILD_TAGS_OTHERS +++ b/release/DEFAULT_BUILD_TAGS_OTHERS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflared,badlinkname,tfogo_checklinkname0 \ No newline at end of file diff --git a/release/DEFAULT_BUILD_TAGS_WINDOWS b/release/DEFAULT_BUILD_TAGS_WINDOWS index 7d5dd55ad8..af4fe41620 100644 --- a/release/DEFAULT_BUILD_TAGS_WINDOWS +++ b/release/DEFAULT_BUILD_TAGS_WINDOWS @@ -1 +1 @@ -with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflare_tunnel,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file +with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_ccm,with_ocm,with_cloudflared,with_naive_outbound,with_purego,badlinkname,tfogo_checklinkname0 \ No newline at end of file From 012335e2f5857782cdf61d2576075f8fc051740e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 18:12:18 +0800 Subject: [PATCH 30/41] fix cloudflared warp datagram behavior --- protocol/cloudflare/connection_quic.go | 13 + protocol/cloudflare/control.go | 2 + .../cloudflare/datagram_lifecycle_test.go | 107 ++++++ protocol/cloudflare/datagram_v2.go | 217 ++++++++++-- protocol/cloudflare/datagram_v3.go | 312 ++++++++++++------ protocol/cloudflare/datagram_v3_test.go | 10 +- protocol/cloudflare/dispatch.go | 20 +- protocol/cloudflare/features.go | 59 ++++ protocol/cloudflare/features_test.go | 30 ++ protocol/cloudflare/flow_limiter_test.go | 11 +- protocol/cloudflare/helpers_test.go | 48 ++- protocol/cloudflare/inbound.go | 49 +-- protocol/cloudflare/origin_dial.go | 86 +++++ protocol/cloudflare/stream.go | 5 + route/dial.go | 149 +++++++++ 15 files changed, 934 insertions(+), 184 deletions(-) create mode 100644 protocol/cloudflare/datagram_lifecycle_test.go create mode 100644 protocol/cloudflare/features.go create mode 100644 protocol/cloudflare/features_test.go create mode 100644 protocol/cloudflare/origin_dial.go create mode 100644 route/dial.go diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index c674828e30..f654ee4cb8 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -271,6 +271,19 @@ func (q *QUICConnection) SendDatagram(data []byte) error { return q.conn.SendDatagram(data) } +func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error) { + stream, err := q.conn.OpenStream() + if err != nil { + return nil, E.Cause(err, "open rpc stream") + } + rwc := newStreamReadWriteCloser(stream) + if err := WriteRPCStreamSignature(rwc); err != nil { + rwc.Close() + return nil, E.Cause(err, "write rpc stream signature") + } + return rwc, nil +} + func (q *QUICConnection) gracefulShutdown() { q.closeOnce.Do(func() { if q.registrationClient != nil { diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index 80d2cbc636..dd8b99da6e 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -166,6 +166,8 @@ func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviou Version: clientVersion, Arch: runtime.GOOS + "_" + runtime.GOARCH, }, + ReplaceExisting: false, + CompressionQuality: 0, OriginLocalIP: originLocalIP, NumPreviousAttempts: numPreviousAttempts, } diff --git a/protocol/cloudflare/datagram_lifecycle_test.go b/protocol/cloudflare/datagram_lifecycle_test.go new file mode 100644 index 0000000000..11a98b8bce --- /dev/null +++ b/protocol/cloudflare/datagram_lifecycle_test.go @@ -0,0 +1,107 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "encoding/binary" + "net" + "testing" + "time" + + "github.com/google/uuid" +) + +type v2UnregisterCall struct { + sessionID uuid.UUID + message string +} + +type captureRPCDatagramSender struct { + captureDatagramSender +} + +type captureV2SessionRPCClient struct { + unregisterCh chan v2UnregisterCall +} + +func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error { + c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message} + return nil +} + +func (c *captureV2SessionRPCClient) Close() error { return nil } + +func TestDatagramV2LocalCloseUnregistersRemote(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + sender := &captureRPCDatagramSender{} + muxer := NewDatagramV2Muxer(inboundInstance, sender, inboundInstance.logger) + unregisterCh := make(chan v2UnregisterCall, 1) + originalClientFactory := newV2SessionRPCClient + newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) { + return &captureV2SessionRPCClient{unregisterCh: unregisterCh}, nil + } + defer func() { + newV2SessionRPCClient = originalClientFactory + }() + + sessionID := uuidTest(7) + if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil { + t.Fatal(err) + } + + muxer.sessionAccess.RLock() + session := muxer.sessions[sessionID] + muxer.sessionAccess.RUnlock() + if session == nil { + t.Fatal("expected registered session") + } + + session.closeWithReason("local close") + + select { + case call := <-unregisterCh: + if call.sessionID != sessionID { + t.Fatalf("unexpected session id: %s", call.sessionID) + } + if call.message != "local close" { + t.Fatalf("unexpected message: %q", call.message) + } + case <-time.After(2 * time.Second): + t.Fatal("expected unregister rpc") + } +} + +func TestDatagramV3RegistrationMigratesSender(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + sender1 := &captureDatagramSender{} + sender2 := &captureDatagramSender{} + muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger) + muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger) + + requestID := RequestID{} + requestID[15] = 9 + payload := make([]byte, 1+2+2+16+4) + payload[0] = 0 + binary.BigEndian.PutUint16(payload[1:3], 53) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + copy(payload[21:25], []byte{127, 0, 0, 1}) + + muxer1.handleRegistration(context.Background(), payload) + session, exists := inboundInstance.datagramV3Manager.Get(requestID) + if !exists { + t.Fatal("expected v3 session after first registration") + } + + muxer2.handleRegistration(context.Background(), payload) + + session.senderAccess.RLock() + currentSender := session.sender + session.senderAccess.RUnlock() + if currentSender != sender2 { + t.Fatal("expected v3 session sender migration to second sender") + } + + session.close() +} diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 318ce5b0bc..9ba52f9731 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" "github.com/sagernet/sing/common/buf" @@ -59,6 +58,54 @@ func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.Cont } } +type rpcStreamOpener interface { + OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error) +} + +type v2SessionRPCClient interface { + UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error + Close() error +} + +var newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) { + opener, ok := sender.(rpcStreamOpener) + if !ok { + return nil, E.New("sender does not support rpc streams") + } + stream, err := opener.OpenRPCStream(ctx) + if err != nil { + return nil, err + } + transport := rpc.StreamTransport(stream) + conn := rpc.NewConn(transport) + return &capnpV2SessionRPCClient{ + client: tunnelrpc.SessionManager{Client: conn.Bootstrap(ctx)}, + rpcConn: conn, + transport: transport, + }, nil +} + +type capnpV2SessionRPCClient struct { + client tunnelrpc.SessionManager + rpcConn *rpc.Conn + transport rpc.Transport +} + +func (c *capnpV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error { + promise := c.client.UnregisterUdpSession(ctx, func(p tunnelrpc.SessionManager_unregisterUdpSession_Params) error { + if err := p.SetSessionId(sessionID[:]); err != nil { + return err + } + return p.SetMessage(message) + }) + _, err := promise.Struct() + return err +} + +func (c *capnpV2SessionRPCClient) Close() error { + return E.Errors(c.rpcConn.Close(), c.transport.Close()) +} + // HandleDatagram demuxes an incoming V2 datagram. func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) { if len(data) < typeIDLength { @@ -139,7 +186,14 @@ func (m *DatagramV2Muxer) RegisterSession( return E.New("too many active flows") } - session := newUDPSession(sessionID, destination, closeAfterIdle, m) + origin, err := m.inbound.dialWarpPacketConnection(ctx, destination) + if err != nil { + m.inbound.flowLimiter.Release(limit) + m.sessionAccess.Unlock() + return err + } + + session := newUDPSession(sessionID, destination, closeAfterIdle, origin, m) m.sessions[sessionID] = session m.sessionAccess.Unlock() @@ -159,32 +213,30 @@ func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID) { m.sessionAccess.Unlock() if exists { + session.markRemoteClosed() session.close() m.logger.Info("unregistered V2 UDP session ", sessionID) } } func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession, limit uint64) { - defer m.UnregisterSession(session.id) defer m.inbound.flowLimiter.Release(limit) - metadata := adapter.InboundContext{ - Inbound: m.inbound.Tag(), - InboundType: m.inbound.Type(), - Network: N.NetworkUDP, - } - metadata.Destination = M.SocksaddrFromNetIP(session.destination) - - done := make(chan struct{}) - m.inbound.router.RoutePacketConnectionEx( - ctx, - session, - metadata, - N.OnceClose(func(it error) { - close(done) - }), - ) - <-done + session.serve(ctx) + + m.sessionAccess.Lock() + if current, exists := m.sessions[session.id]; exists && current == session { + delete(m.sessions, session.id) + } + m.sessionAccess.Unlock() + + if !session.remoteClosed() { + unregisterCtx, cancel := context.WithTimeout(context.Background(), registrationTimeout) + defer cancel() + if err := m.unregisterRemoteSession(unregisterCtx, session.id, session.closeReason()); err != nil { + m.logger.Debug("failed to unregister V2 UDP session ", session.id, ": ", err) + } + } } // sendToEdge sends a V2 UDP datagram back to the edge. @@ -213,21 +265,31 @@ type udpSession struct { id uuid.UUID destination netip.AddrPort closeAfterIdle time.Duration + origin N.PacketConn muxer *DatagramV2Muxer writeChan chan []byte closeOnce sync.Once closeChan chan struct{} + + activeAccess sync.RWMutex + activeAt time.Time + + stateAccess sync.RWMutex + closedByRemote bool + closeReasonString string } -func newUDPSession(id uuid.UUID, destination netip.AddrPort, closeAfterIdle time.Duration, muxer *DatagramV2Muxer) *udpSession { +func newUDPSession(id uuid.UUID, destination netip.AddrPort, closeAfterIdle time.Duration, origin N.PacketConn, muxer *DatagramV2Muxer) *udpSession { return &udpSession{ id: id, destination: destination, closeAfterIdle: closeAfterIdle, + origin: origin, muxer: muxer, writeChan: make(chan []byte, 256), closeChan: make(chan struct{}), + activeAt: time.Now(), } } @@ -242,10 +304,114 @@ func (s *udpSession) writeToOrigin(payload []byte) { func (s *udpSession) close() { s.closeOnce.Do(func() { + if s.origin != nil { + _ = s.origin.Close() + } close(s.closeChan) }) } +func (s *udpSession) serve(ctx context.Context) { + go s.readLoop() + go s.writeLoop() + + tickInterval := s.closeAfterIdle / 2 + if tickInterval <= 0 || tickInterval > 10*time.Second { + tickInterval = time.Second + } + ticker := time.NewTicker(tickInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.closeWithReason("connection closed") + case <-ticker.C: + if time.Since(s.lastActive()) >= s.closeAfterIdle { + s.closeWithReason("idle timeout") + } + case <-s.closeChan: + return + } + } +} + +func (s *udpSession) readLoop() { + for { + buffer := buf.NewPacket() + _, err := s.origin.ReadPacket(buffer) + if err != nil { + buffer.Release() + s.closeWithReason(err.Error()) + return + } + s.markActive() + s.muxer.sendToEdge(s.id, append([]byte(nil), buffer.Bytes()...)) + buffer.Release() + } +} + +func (s *udpSession) writeLoop() { + for { + select { + case payload := <-s.writeChan: + err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination)) + if err != nil { + s.closeWithReason(err.Error()) + return + } + s.markActive() + case <-s.closeChan: + return + } + } +} + +func (s *udpSession) markActive() { + s.activeAccess.Lock() + s.activeAt = time.Now() + s.activeAccess.Unlock() +} + +func (s *udpSession) lastActive() time.Time { + s.activeAccess.RLock() + defer s.activeAccess.RUnlock() + return s.activeAt +} + +func (s *udpSession) closeWithReason(reason string) { + s.stateAccess.Lock() + if s.closeReasonString == "" { + s.closeReasonString = reason + } + s.stateAccess.Unlock() + s.close() +} + +func (s *udpSession) markRemoteClosed() { + s.stateAccess.Lock() + s.closedByRemote = true + if s.closeReasonString == "" { + s.closeReasonString = "unregistered by edge" + } + s.stateAccess.Unlock() +} + +func (s *udpSession) remoteClosed() bool { + s.stateAccess.RLock() + defer s.stateAccess.RUnlock() + return s.closedByRemote +} + +func (s *udpSession) closeReason() string { + s.stateAccess.RLock() + defer s.stateAccess.RUnlock() + if s.closeReasonString == "" { + return "session closed" + } + return s.closeReasonString +} + // ReadPacket implements N.PacketConn - reads packets from the edge to forward to origin. func (s *udpSession) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { select { @@ -273,6 +439,15 @@ func (s *udpSession) SetDeadline(_ time.Time) error { return nil } func (s *udpSession) SetReadDeadline(_ time.Time) error { return nil } func (s *udpSession) SetWriteDeadline(_ time.Time) error { return nil } +func (m *DatagramV2Muxer) unregisterRemoteSession(ctx context.Context, sessionID uuid.UUID, message string) error { + client, err := newV2SessionRPCClient(ctx, m.sender) + if err != nil { + return err + } + defer client.Close() + return client.UnregisterSession(ctx, sessionID, message) +} + // V2 RPC server implementation for HandleRPCStream. type cloudflaredServer struct { diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index 42e758584f..436fc5a33a 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -5,13 +5,11 @@ package cloudflare import ( "context" "encoding/binary" - "io" - "net" + "errors" "net/netip" "sync" "time" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" @@ -56,25 +54,40 @@ const ( // RequestID is a 128-bit session identifier for V3. type RequestID [v3RequestIDLength]byte +type v3RegistrationState uint8 + +const ( + v3RegistrationNew v3RegistrationState = iota + v3RegistrationExisting + v3RegistrationMigrated +) + +type DatagramV3SessionManager struct { + sessionAccess sync.RWMutex + sessions map[RequestID]*v3Session +} + +func NewDatagramV3SessionManager() *DatagramV3SessionManager { + return &DatagramV3SessionManager{ + sessions: make(map[RequestID]*v3Session), + } +} + // DatagramV3Muxer handles V3 datagram demuxing and session management. type DatagramV3Muxer struct { inbound *Inbound logger log.ContextLogger sender DatagramSender icmp *ICMPBridge - - sessionAccess sync.RWMutex - sessions map[RequestID]*v3Session } // NewDatagramV3Muxer creates a new V3 datagram muxer. func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV3Muxer { return &DatagramV3Muxer{ - inbound: inbound, - logger: logger, - sender: sender, - icmp: NewICMPBridge(inbound, sender, icmpWireV3), - sessions: make(map[RequestID]*v3Session), + inbound: inbound, + logger: logger, + sender: sender, + icmp: NewICMPBridge(inbound, sender, icmpWireV3), } } @@ -147,37 +160,25 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { return } - m.sessionAccess.Lock() - if existing, exists := m.sessions[requestID]; exists { - m.sessionAccess.Unlock() - // Session already exists - re-ack - m.sendRegistrationResponse(requestID, v3ResponseOK, "") - // Handle bundled payload - if flags&v3FlagBundle != 0 && len(data) > offset { - existing.writeToOrigin(data[offset:]) - } + session, state, err := m.inbound.datagramV3Manager.Register(m.inbound, ctx, requestID, destination, closeAfterIdle, m.sender) + if err == errTooManyActiveFlows { + m.sendRegistrationResponse(requestID, v3ResponseTooManyActiveFlows, "") return } - limit := m.inbound.maxActiveFlows() - if !m.inbound.flowLimiter.Acquire(limit) { - m.sessionAccess.Unlock() - m.sendRegistrationResponse(requestID, v3ResponseTooManyActiveFlows, "") + if err != nil { + m.sendRegistrationResponse(requestID, v3ResponseUnableToBindSocket, "") return } - session := newV3Session(requestID, destination, closeAfterIdle, m) - m.sessions[requestID] = session - m.sessionAccess.Unlock() - - m.logger.Info("registered V3 UDP session to ", destination) + if state == v3RegistrationNew { + m.logger.Info("registered V3 UDP session to ", destination) + } m.sendRegistrationResponse(requestID, v3ResponseOK, "") // Handle bundled first payload if flags&v3FlagBundle != 0 && len(data) > offset { session.writeToOrigin(data[offset:]) } - - go m.serveV3Session(ctx, session, limit) } func (m *DatagramV3Muxer) handlePayload(data []byte) { @@ -189,10 +190,7 @@ func (m *DatagramV3Muxer) handlePayload(data []byte) { copy(requestID[:], data[:v3RequestIDLength]) payload := data[v3RequestIDLength:] - m.sessionAccess.RLock() - session, exists := m.sessions[requestID] - m.sessionAccess.RUnlock() - + session, exists := m.inbound.datagramV3Manager.Get(requestID) if !exists { return } @@ -219,74 +217,162 @@ func (m *DatagramV3Muxer) sendPayload(requestID RequestID, payload []byte) { m.sender.SendDatagram(data) } -func (m *DatagramV3Muxer) unregisterSession(requestID RequestID) { - m.sessionAccess.Lock() - session, exists := m.sessions[requestID] - if exists { - delete(m.sessions, requestID) - } - m.sessionAccess.Unlock() - - if exists { - session.close() - } -} - -func (m *DatagramV3Muxer) serveV3Session(ctx context.Context, session *v3Session, limit uint64) { - defer m.unregisterSession(session.id) - defer m.inbound.flowLimiter.Release(limit) - - metadata := adapter.InboundContext{ - Inbound: m.inbound.Tag(), - InboundType: m.inbound.Type(), - Network: N.NetworkUDP, - } - metadata.Destination = M.SocksaddrFromNetIP(session.destination) - - done := make(chan struct{}) - m.inbound.router.RoutePacketConnectionEx( - ctx, - session, - metadata, - N.OnceClose(func(it error) { - close(done) - }), - ) - <-done -} - // Close closes all V3 sessions. -func (m *DatagramV3Muxer) Close() { - m.sessionAccess.Lock() - sessions := m.sessions - m.sessions = make(map[RequestID]*v3Session) - m.sessionAccess.Unlock() - - for _, session := range sessions { - session.close() - } -} +func (m *DatagramV3Muxer) Close() {} // v3Session represents a V3 UDP session. type v3Session struct { id RequestID destination netip.AddrPort closeAfterIdle time.Duration - muxer *DatagramV3Muxer + origin N.PacketConn + manager *DatagramV3SessionManager + inbound *Inbound writeChan chan []byte closeOnce sync.Once closeChan chan struct{} + + activeAccess sync.RWMutex + activeAt time.Time + + senderAccess sync.RWMutex + sender DatagramSender } -func newV3Session(id RequestID, destination netip.AddrPort, closeAfterIdle time.Duration, muxer *DatagramV3Muxer) *v3Session { - return &v3Session{ - id: id, +var errTooManyActiveFlows = errors.New("too many active flows") + +func (m *DatagramV3SessionManager) Register( + inbound *Inbound, + ctx context.Context, + requestID RequestID, + destination netip.AddrPort, + closeAfterIdle time.Duration, + sender DatagramSender, +) (*v3Session, v3RegistrationState, error) { + m.sessionAccess.Lock() + if existing, exists := m.sessions[requestID]; exists { + if existing.sender == sender { + existing.markActive() + m.sessionAccess.Unlock() + return existing, v3RegistrationExisting, nil + } + existing.setSender(sender) + existing.markActive() + m.sessionAccess.Unlock() + return existing, v3RegistrationMigrated, nil + } + + limit := inbound.maxActiveFlows() + if !inbound.flowLimiter.Acquire(limit) { + m.sessionAccess.Unlock() + return nil, 0, errTooManyActiveFlows + } + origin, err := inbound.dialWarpPacketConnection(ctx, destination) + if err != nil { + inbound.flowLimiter.Release(limit) + m.sessionAccess.Unlock() + return nil, 0, err + } + + session := &v3Session{ + id: requestID, destination: destination, closeAfterIdle: closeAfterIdle, - muxer: muxer, + origin: origin, + manager: m, + inbound: inbound, writeChan: make(chan []byte, 512), closeChan: make(chan struct{}), + activeAt: time.Now(), + sender: sender, + } + m.sessions[requestID] = session + m.sessionAccess.Unlock() + + sessionCtx := inbound.ctx + if sessionCtx == nil { + sessionCtx = context.Background() + } + go session.serve(sessionCtx, limit) + return session, v3RegistrationNew, nil +} + +func (m *DatagramV3SessionManager) Get(requestID RequestID) (*v3Session, bool) { + m.sessionAccess.RLock() + defer m.sessionAccess.RUnlock() + session, exists := m.sessions[requestID] + return session, exists +} + +func (m *DatagramV3SessionManager) remove(session *v3Session) { + m.sessionAccess.Lock() + if current, exists := m.sessions[session.id]; exists && current == session { + delete(m.sessions, session.id) + } + m.sessionAccess.Unlock() +} + +func (s *v3Session) serve(ctx context.Context, limit uint64) { + defer s.inbound.flowLimiter.Release(limit) + defer s.manager.remove(s) + + go s.readLoop() + go s.writeLoop() + + tickInterval := s.closeAfterIdle / 2 + if tickInterval <= 0 || tickInterval > 10*time.Second { + tickInterval = time.Second + } + ticker := time.NewTicker(tickInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.close() + case <-ticker.C: + if time.Since(s.lastActive()) >= s.closeAfterIdle { + s.close() + } + case <-s.closeChan: + return + } + } +} + +func (s *v3Session) readLoop() { + for { + buffer := buf.NewPacket() + _, err := s.origin.ReadPacket(buffer) + if err != nil { + buffer.Release() + s.close() + return + } + s.markActive() + if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil { + buffer.Release() + s.close() + return + } + buffer.Release() + } +} + +func (s *v3Session) writeLoop() { + for { + select { + case payload := <-s.writeChan: + err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination)) + if err != nil { + s.close() + return + } + s.markActive() + case <-s.closeChan: + return + } } } @@ -299,35 +385,41 @@ func (s *v3Session) writeToOrigin(payload []byte) { } } -func (s *v3Session) close() { - s.closeOnce.Do(func() { - close(s.closeChan) - }) +func (s *v3Session) senderDatagram(payload []byte) error { + data := make([]byte, v3PayloadHeaderLen+len(payload)) + data[0] = byte(DatagramV3TypePayload) + copy(data[1:1+v3RequestIDLength], s.id[:]) + copy(data[v3PayloadHeaderLen:], payload) + + s.senderAccess.RLock() + sender := s.sender + s.senderAccess.RUnlock() + return sender.SendDatagram(data) } -// ReadPacket implements N.PacketConn. -func (s *v3Session) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { - select { - case data := <-s.writeChan: - _, err := buffer.Write(data) - return M.SocksaddrFromNetIP(s.destination), err - case <-s.closeChan: - return M.Socksaddr{}, io.EOF - } +func (s *v3Session) setSender(sender DatagramSender) { + s.senderAccess.Lock() + s.sender = sender + s.senderAccess.Unlock() } -// WritePacket implements N.PacketConn. -func (s *v3Session) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - s.muxer.sendPayload(s.id, buffer.Bytes()) - return nil +func (s *v3Session) markActive() { + s.activeAccess.Lock() + s.activeAt = time.Now() + s.activeAccess.Unlock() } -func (s *v3Session) Close() error { - s.close() - return nil +func (s *v3Session) lastActive() time.Time { + s.activeAccess.RLock() + defer s.activeAccess.RUnlock() + return s.activeAt } -func (s *v3Session) LocalAddr() net.Addr { return nil } -func (s *v3Session) SetDeadline(_ time.Time) error { return nil } -func (s *v3Session) SetReadDeadline(_ time.Time) error { return nil } -func (s *v3Session) SetWriteDeadline(_ time.Time) error { return nil } +func (s *v3Session) close() { + s.closeOnce.Do(func() { + if s.origin != nil { + _ = s.origin.Close() + } + close(s.closeChan) + }) +} diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index 08d8704ecc..87f9148dd5 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -14,8 +14,9 @@ import ( func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { sender := &captureDatagramSender{} inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - flowLimiter: &FlowLimiter{}, + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + flowLimiter: &FlowLimiter{}, + datagramV3Manager: NewDatagramV3SessionManager(), } muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) @@ -40,8 +41,9 @@ func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) { sender := &captureDatagramSender{} inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - flowLimiter: &FlowLimiter{}, + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + flowLimiter: &FlowLimiter{}, + datagramV3Manager: NewDatagramV3SessionManager(), } muxer := NewDatagramV3Muxer(inboundInstance, sender, nil) diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 8d447f6ac0..746e83bb3d 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -195,17 +196,24 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser } defer i.flowLimiter.Release(limit) - err := respWriter.WriteResponse(nil, nil) + targetConn, err := i.dialWarpTCP(ctx, metadata.Destination) + if err != nil { + i.logger.ErrorContext(ctx, "dial tcp origin: ", err) + respWriter.WriteResponse(err, nil) + return + } + defer targetConn.Close() + + err = respWriter.WriteResponse(nil, nil) if err != nil { i.logger.ErrorContext(ctx, "write connect response: ", err) return } - done := make(chan struct{}) - i.router.RouteConnectionEx(ctx, newStreamConn(stream), metadata, N.OnceClose(func(it error) { - close(done) - })) - <-done + err = bufio.CopyConn(ctx, newStreamConn(stream), targetConn) + if err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "copy TCP stream: ", err) + } } func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { diff --git a/protocol/cloudflare/features.go b/protocol/cloudflare/features.go new file mode 100644 index 0000000000..5b26336ab5 --- /dev/null +++ b/protocol/cloudflare/features.go @@ -0,0 +1,59 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "encoding/json" + "hash/fnv" + "net" + "time" +) + +const ( + featureSelectorHostname = "cfd-features.argotunnel.com" + featureLookupTimeout = 10 * time.Second +) + +type cloudflaredFeaturesRecord struct { + DatagramV3Percentage uint32 `json:"dv3_2"` +} + +var lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) { + lookupCtx, cancel := context.WithTimeout(ctx, featureLookupTimeout) + defer cancel() + + records, err := net.DefaultResolver.LookupTXT(lookupCtx, featureSelectorHostname) + if err != nil || len(records) == 0 { + return nil, err + } + return []byte(records[0]), nil +} + +func resolveDatagramVersion(ctx context.Context, accountTag string, configured string) string { + if configured != "" { + return configured + } + record, err := lookupCloudflaredFeatures(ctx) + if err != nil { + return "v2" + } + + var features cloudflaredFeaturesRecord + if err := json.Unmarshal(record, &features); err != nil { + return "v2" + } + if accountEnabled(accountTag, features.DatagramV3Percentage) { + return "v3" + } + return "v2" +} + +func accountEnabled(accountTag string, percentage uint32) bool { + if percentage == 0 { + return false + } + hasher := fnv.New32a() + _, _ = hasher.Write([]byte(accountTag)) + return percentage > hasher.Sum32()%100 +} diff --git a/protocol/cloudflare/features_test.go b/protocol/cloudflare/features_test.go new file mode 100644 index 0000000000..82534eb47c --- /dev/null +++ b/protocol/cloudflare/features_test.go @@ -0,0 +1,30 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "testing" +) + +func TestResolveDatagramVersionConfiguredWins(t *testing.T) { + version := resolveDatagramVersion(context.Background(), "account", "v3") + if version != "v3" { + t.Fatalf("expected configured version to win, got %s", version) + } +} + +func TestResolveDatagramVersionRemoteSelection(t *testing.T) { + originalLookup := lookupCloudflaredFeatures + lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) { + return []byte(`{"dv3_2":100}`), nil + } + defer func() { + lookupCloudflaredFeatures = originalLookup + }() + + version := resolveDatagramVersion(context.Background(), "account", "") + if version != "v3" { + t.Fatalf("expected auto-selected v3, got %s", version) + } +} diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index 12eaa6dba6..f4d24ee0a3 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -31,11 +31,12 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound { config.WarpRouting.MaxActiveFlows = limit configManager.activeConfig = config return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - router: &testRouter{}, - logger: logFactory.NewLogger("test"), - configManager: configManager, - flowLimiter: &FlowLimiter{}, + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + configManager: configManager, + flowLimiter: &FlowLimiter{}, + datagramV3Manager: NewDatagramV3SessionManager(), } } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 4fd13ef77d..81f829dafb 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -22,6 +22,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/bufio" N "github.com/sagernet/sing/common/network" "github.com/google/uuid" @@ -137,6 +138,18 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC onClose(nil) } +func (r *testRouter) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { + return net.Dial("tcp", metadata.Destination.String()) +} + +func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { + conn, err := net.Dial("udp", metadata.Destination.String()) + if err != nil { + return nil, err + } + return bufio.NewUnbindPacketConn(conn), nil +} + func (r *testRouter) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { if r.preMatch != nil { return r.preMatch(metadata, routeContext, timeout, supportBypass) @@ -177,23 +190,24 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i ctx, cancel := context.WithCancel(context.Background()) inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - ctx: ctx, - cancel: cancel, - router: &testRouter{}, - logger: logFactory.NewLogger("test"), - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - edgeIPVersion: 0, - datagramVersion: "", - gracePeriod: 5 * time.Second, - configManager: configManager, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), - controlDialer: N.SystemDialer, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + ctx: ctx, + cancel: cancel, + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + edgeIPVersion: 0, + datagramVersion: "", + gracePeriod: 5 * time.Second, + configManager: configManager, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + datagramV3Manager: NewDatagramV3SessionManager(), + controlDialer: N.SystemDialer, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, } t.Cleanup(func() { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index c794df3b51..ad46458187 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -61,6 +61,7 @@ type Inbound struct { datagramMuxerAccess sync.Mutex datagramV2Muxers map[DatagramSender]*DatagramV2Muxer datagramV3Muxers map[DatagramSender]*DatagramV3Muxer + datagramV3Manager *DatagramV3SessionManager helloWorldAccess sync.Mutex helloWorldServer *http.Server @@ -129,27 +130,28 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, tag), - ctx: inboundCtx, - cancel: cancel, - router: router, - logger: logger, - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - region: region, - edgeIPVersion: edgeIPVersion, - datagramVersion: datagramVersion, - gracePeriod: gracePeriod, - configManager: configManager, - flowLimiter: &FlowLimiter{}, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, - controlDialer: controlDialer, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), - connectedIndices: make(map[uint8]struct{}), - connectedNotify: make(chan uint8, haConnections), + Adapter: inbound.NewAdapter(C.TypeCloudflared, tag), + ctx: inboundCtx, + cancel: cancel, + router: router, + logger: logger, + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + region: region, + edgeIPVersion: edgeIPVersion, + datagramVersion: datagramVersion, + gracePeriod: gracePeriod, + configManager: configManager, + flowLimiter: &FlowLimiter{}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, + controlDialer: controlDialer, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + datagramV3Manager: NewDatagramV3SessionManager(), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), }, nil } @@ -170,6 +172,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { return E.New("no edge addresses available") } + i.datagramVersion = resolveDatagramVersion(i.ctx, i.credentials.AccountTag, i.datagramVersion) features := DefaultFeatures(i.datagramVersion) for connIndex := 0; connIndex < i.haConnections; connIndex++ { @@ -301,6 +304,10 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe retries++ backoff := backoffDuration(retries) + var retryableErr *RetryableError + if errors.As(err, &retryableErr) && retryableErr.Delay > 0 { + backoff = retryableErr.Delay + } i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff) select { diff --git a/protocol/cloudflare/origin_dial.go b/protocol/cloudflare/origin_dial.go new file mode 100644 index 0000000000..5c6f80b7a9 --- /dev/null +++ b/protocol/cloudflare/origin_dial.go @@ -0,0 +1,86 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type routedOriginDialer interface { + DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) + DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) +} + +func (i *Inbound) dialWarpTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + originDialer, ok := i.router.(routedOriginDialer) + if !ok { + return nil, E.New("router does not support cloudflare routed dialing") + } + + warpRouting := i.configManager.Snapshot().WarpRouting + if warpRouting.ConnectTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout) + defer cancel() + } + + conn, err := originDialer.DialRouteConnection(ctx, adapter.InboundContext{ + Inbound: i.Tag(), + InboundType: i.Type(), + Network: N.NetworkTCP, + Destination: destination, + }) + if err != nil { + return nil, err + } + _ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive) + return conn, nil +} + +func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) { + originDialer, ok := i.router.(routedOriginDialer) + if !ok { + return nil, E.New("router does not support cloudflare routed packet dialing") + } + + warpRouting := i.configManager.Snapshot().WarpRouting + if warpRouting.ConnectTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout) + defer cancel() + } + + return originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{ + Inbound: i.Tag(), + InboundType: i.Type(), + Network: N.NetworkUDP, + Destination: M.SocksaddrFromNetIP(destination), + UDPConnect: true, + }) +} + +func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error { + if keepAlive <= 0 { + return nil + } + type keepAliveConn interface { + SetKeepAlive(bool) error + SetKeepAlivePeriod(time.Duration) error + } + tcpConn, ok := conn.(keepAliveConn) + if !ok { + return nil + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return tcpConn.SetKeepAlivePeriod(keepAlive) +} diff --git a/protocol/cloudflare/stream.go b/protocol/cloudflare/stream.go index e9dcfa55f9..62003d2334 100644 --- a/protocol/cloudflare/stream.go +++ b/protocol/cloudflare/stream.go @@ -168,6 +168,11 @@ func WriteConnectResponse(w io.Writer, responseError error, metadata ...Metadata return capnp.NewEncoder(w).Encode(msg) } +func WriteRPCStreamSignature(w io.Writer) error { + _, err := w.Write(rpcStreamSignature[:]) + return err +} + // Registration data structures for the control stream. type RegistrationTunnelAuth struct { diff --git a/route/dial.go b/route/dial.go new file mode 100644 index 0000000000..013a6350aa --- /dev/null +++ b/route/dial.go @@ -0,0 +1,149 @@ +package route + +import ( + "context" + "net" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + tf "github.com/sagernet/sing-box/common/tlsfragment" + C "github.com/sagernet/sing-box/constant" + R "github.com/sagernet/sing-box/route/rule" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +// DialRouteConnection dials a routed TCP connection for metadata without requiring an upstream accepted socket. +func (r *Router) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { + metadata.Network = N.NetworkTCP + ctx = adapter.WithContext(ctx, &metadata) + + selectedRule, selectedOutbound, err := r.selectRoutedOutbound(ctx, &metadata, N.NetworkTCP) + if err != nil { + return nil, err + } + + var conn net.Conn + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + conn, err = dialer.DialSerialNetwork( + ctx, + selectedOutbound, + N.NetworkTCP, + metadata.Destination, + metadata.DestinationAddresses, + metadata.NetworkStrategy, + metadata.NetworkType, + metadata.FallbackNetworkType, + metadata.FallbackDelay, + ) + } else { + conn, err = selectedOutbound.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + if err != nil { + return nil, err + } + + if metadata.TLSFragment || metadata.TLSRecordFragment { + conn = tf.NewConn(conn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay) + } + for _, tracker := range r.trackers { + conn = tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) + } + return conn, nil +} + +// DialRoutePacketConnection dials a routed connected UDP packet connection for metadata. +func (r *Router) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { + metadata.Network = N.NetworkUDP + metadata.UDPConnect = true + ctx = adapter.WithContext(ctx, &metadata) + + selectedRule, selectedOutbound, err := r.selectRoutedOutbound(ctx, &metadata, N.NetworkUDP) + if err != nil { + return nil, err + } + + var remoteConn net.Conn + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { + remoteConn, err = dialer.DialSerialNetwork( + ctx, + selectedOutbound, + N.NetworkUDP, + metadata.Destination, + metadata.DestinationAddresses, + metadata.NetworkStrategy, + metadata.NetworkType, + metadata.FallbackNetworkType, + metadata.FallbackDelay, + ) + } else { + remoteConn, err = selectedOutbound.DialContext(ctx, N.NetworkUDP, metadata.Destination) + } + if err != nil { + return nil, err + } + + var packetConn N.PacketConn = bufio.NewUnbindPacketConn(remoteConn) + for _, tracker := range r.trackers { + packetConn = tracker.RoutedPacketConnection(ctx, packetConn, metadata, selectedRule, selectedOutbound) + } + if metadata.FakeIP { + packetConn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(packetConn), metadata.OriginDestination, metadata.Destination) + } + return packetConn, nil +} + +func (r *Router) selectRoutedOutbound( + ctx context.Context, + metadata *adapter.InboundContext, + network string, +) (adapter.Rule, adapter.Outbound, error) { + selectedRule, _, buffers, packetBuffers, err := r.matchRule(ctx, metadata, false, false, nil, nil) + if len(buffers) > 0 { + buf.ReleaseMulti(buffers) + } + if len(packetBuffers) > 0 { + N.ReleaseMultiPacketBuffer(packetBuffers) + } + if err != nil { + return nil, nil, err + } + + var selectedOutbound adapter.Outbound + if selectedRule != nil { + switch action := selectedRule.Action().(type) { + case *R.RuleActionRoute: + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, nil, E.New("outbound not found: ", action.Outbound) + } + case *R.RuleActionBypass: + if action.Outbound != "" { + var loaded bool + selectedOutbound, loaded = r.outbound.Outbound(action.Outbound) + if !loaded { + return nil, nil, E.New("outbound not found: ", action.Outbound) + } + } + case *R.RuleActionReject: + if action.Method == C.RuleActionRejectMethodReply { + return nil, nil, E.New("reject method `reply` is not supported for dialed connections") + } + return nil, nil, action.Error(ctx) + case *R.RuleActionHijackDNS: + return nil, nil, E.New("DNS hijack is not supported for dialed connections") + } + } + + if selectedOutbound == nil { + selectedOutbound = r.outbound.Default() + } + if !common.Contains(selectedOutbound.Network(), network) { + return nil, nil, E.New(network, " is not supported by outbound: ", selectedOutbound.Tag()) + } + return selectedRule, selectedOutbound, nil +} From 6e35f4da8966137e38c0f6c853fc2c4329ebd71e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 20:03:33 +0800 Subject: [PATCH 31/41] Route cloudflare TCP through pipe --- protocol/cloudflare/dispatch.go | 28 ++-- protocol/cloudflare/helpers_test.go | 4 - protocol/cloudflare/origin_dial.go | 51 +----- protocol/cloudflare/router_pipe.go | 90 +++++++++++ protocol/cloudflare/router_pipe_test.go | 165 ++++++++++++++++++++ protocol/cloudflare/special_service.go | 17 +- protocol/cloudflare/special_service_test.go | 1 + route/dial.go | 40 ----- 8 files changed, 271 insertions(+), 125 deletions(-) create mode 100644 protocol/cloudflare/router_pipe.go create mode 100644 protocol/cloudflare/router_pipe_test.go diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 746e83bb3d..5cedfe2a22 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -23,7 +23,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" ) const ( @@ -196,14 +195,22 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser } defer i.flowLimiter.Release(limit) - targetConn, err := i.dialWarpTCP(ctx, metadata.Destination) + warpRouting := i.configManager.Snapshot().WarpRouting + targetConn, cleanup, err := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{ + timeout: warpRouting.ConnectTimeout, + onHandshake: func(conn net.Conn) { + _ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive) + }, + }) if err != nil { i.logger.ErrorContext(ctx, "dial tcp origin: ", err) respWriter.WriteResponse(err, nil) return } - defer targetConn.Close() + defer cleanup() + // Cloudflare expects an optimistic ACK here so the routed TCP path can sniff + // the real input stream before the outbound connection is fully established. err = respWriter.WriteResponse(nil, nil) if err != nil { i.logger.ErrorContext(ctx, "write connect response: ", err) @@ -391,12 +398,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, } func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) { - input, output := pipe.Pipe() - done := make(chan struct{}) - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - common.Close(input, output) - close(done) - })) + input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) transport := &http.Transport{ DisableCompression: true, @@ -411,13 +413,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter }, } applyHTTPTransportProxy(transport, originRequest) - return transport, func() { - common.Close(input, output) - select { - case <-done: - case <-time.After(time.Second): - } - } + return transport, cleanup } func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 81f829dafb..253eed5cf8 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -138,10 +138,6 @@ func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketC onClose(nil) } -func (r *testRouter) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { - return net.Dial("tcp", metadata.Destination.String()) -} - func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { conn, err := net.Dial("udp", metadata.Destination.String()) if err != nil { diff --git a/protocol/cloudflare/origin_dial.go b/protocol/cloudflare/origin_dial.go index 5c6f80b7a9..c937aa35b3 100644 --- a/protocol/cloudflare/origin_dial.go +++ b/protocol/cloudflare/origin_dial.go @@ -4,9 +4,7 @@ package cloudflare import ( "context" - "net" "net/netip" - "time" "github.com/sagernet/sing-box/adapter" E "github.com/sagernet/sing/common/exceptions" @@ -14,39 +12,12 @@ import ( N "github.com/sagernet/sing/common/network" ) -type routedOriginDialer interface { - DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) +type routedOriginPacketDialer interface { DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) } -func (i *Inbound) dialWarpTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { - originDialer, ok := i.router.(routedOriginDialer) - if !ok { - return nil, E.New("router does not support cloudflare routed dialing") - } - - warpRouting := i.configManager.Snapshot().WarpRouting - if warpRouting.ConnectTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout) - defer cancel() - } - - conn, err := originDialer.DialRouteConnection(ctx, adapter.InboundContext{ - Inbound: i.Tag(), - InboundType: i.Type(), - Network: N.NetworkTCP, - Destination: destination, - }) - if err != nil { - return nil, err - } - _ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive) - return conn, nil -} - func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) { - originDialer, ok := i.router.(routedOriginDialer) + originDialer, ok := i.router.(routedOriginPacketDialer) if !ok { return nil, E.New("router does not support cloudflare routed packet dialing") } @@ -66,21 +37,3 @@ func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination neti UDPConnect: true, }) } - -func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error { - if keepAlive <= 0 { - return nil - } - type keepAliveConn interface { - SetKeepAlive(bool) error - SetKeepAlivePeriod(time.Duration) error - } - tcpConn, ok := conn.(keepAliveConn) - if !ok { - return nil - } - if err := tcpConn.SetKeepAlive(true); err != nil { - return err - } - return tcpConn.SetKeepAlivePeriod(keepAlive) -} diff --git a/protocol/cloudflare/router_pipe.go b/protocol/cloudflare/router_pipe.go new file mode 100644 index 0000000000..9431fa228b --- /dev/null +++ b/protocol/cloudflare/router_pipe.go @@ -0,0 +1,90 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +type routedPipeTCPOptions struct { + timeout time.Duration + onHandshake func(net.Conn) +} + +type routedPipeTCPConn struct { + net.Conn + handshakeOnce sync.Once + onHandshake func(net.Conn) +} + +func (c *routedPipeTCPConn) ConnHandshakeSuccess(conn net.Conn) error { + if c.onHandshake != nil { + c.handshakeOnce.Do(func() { + c.onHandshake(conn) + }) + } + return nil +} + +func (i *Inbound) dialRouterTCPWithMetadata(ctx context.Context, metadata adapter.InboundContext, options routedPipeTCPOptions) (net.Conn, func(), error) { + input, output := pipe.Pipe() + routerConn := &routedPipeTCPConn{ + Conn: output, + onHandshake: options.onHandshake, + } + done := make(chan struct{}) + + routeCtx := ctx + var cancel context.CancelFunc + if options.timeout > 0 { + routeCtx, cancel = context.WithTimeout(ctx, options.timeout) + } + + var closeOnce sync.Once + closePipe := func() { + closeOnce.Do(func() { + if cancel != nil { + cancel() + } + common.Close(input, routerConn) + }) + } + go i.router.RouteConnectionEx(routeCtx, routerConn, metadata, N.OnceClose(func(it error) { + closePipe() + close(done) + })) + + return input, func() { + closePipe() + select { + case <-done: + case <-time.After(time.Second): + } + }, nil +} + +func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error { + if keepAlive <= 0 { + return nil + } + type keepAliveConn interface { + SetKeepAlive(bool) error + SetKeepAlivePeriod(time.Duration) error + } + tcpConn, ok := conn.(keepAliveConn) + if !ok { + return nil + } + if err := tcpConn.SetKeepAlive(true); err != nil { + return err + } + return tcpConn.SetKeepAlivePeriod(keepAlive) +} diff --git a/protocol/cloudflare/router_pipe_test.go b/protocol/cloudflare/router_pipe_test.go new file mode 100644 index 0000000000..779ade53d7 --- /dev/null +++ b/protocol/cloudflare/router_pipe_test.go @@ -0,0 +1,165 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func TestHandleTCPStreamUsesRouteConnectionEx(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + router := &countingRouter{} + inboundInstance := newSpecialServiceInboundWithRouter(t, router) + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + responseDone := respWriter.done + finished := make(chan struct{}) + go func() { + inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{ + Destination: M.ParseSocksaddr(listener.Addr().String()), + }) + close(finished) + }() + + select { + case <-responseDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for connect response") + } + if respWriter.err != nil { + t.Fatal("unexpected response error: ", respWriter.err) + } + + if err := clientSide.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } + payload := []byte("ping") + if _, err := clientSide.Write(payload); err != nil { + t.Fatal(err) + } + response := make([]byte, len(payload)) + if _, err := io.ReadFull(clientSide, response); err != nil { + t.Fatal(err) + } + if string(response) != string(payload) { + t.Fatalf("unexpected echo payload: %q", string(response)) + } + if router.count.Load() != 1 { + t.Fatalf("expected RouteConnectionEx to be used once, got %d", router.count.Load()) + } + + _ = clientSide.Close() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for TCP stream handler to exit") + } +} + +func TestHandleTCPStreamWritesOptimisticAck(t *testing.T) { + router := &blockingRouteRouter{ + started: make(chan struct{}), + release: make(chan struct{}), + } + inboundInstance := newSpecialServiceInboundWithRouter(t, router) + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + responseDone := respWriter.done + finished := make(chan struct{}) + go func() { + inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{ + Destination: M.ParseSocksaddr("127.0.0.1:443"), + }) + close(finished) + }() + + select { + case <-router.started: + case <-time.After(time.Second): + t.Fatal("timed out waiting for router goroutine to start") + } + select { + case <-responseDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for optimistic connect response") + } + if respWriter.err != nil { + t.Fatal("unexpected response error: ", respWriter.err) + } + + close(router.release) + _ = clientSide.Close() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for TCP stream handler to exit") + } +} + +func TestRoutedPipeTCPConnHandshakeAppliesKeepAlive(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + remoteConn := &keepAliveTestConn{Conn: right} + routerConn := &routedPipeTCPConn{ + Conn: left, + onHandshake: func(conn net.Conn) { + _ = applyTCPKeepAlive(conn, 15*time.Second) + }, + } + if err := routerConn.ConnHandshakeSuccess(remoteConn); err != nil { + t.Fatal(err) + } + if !remoteConn.enabled { + t.Fatal("expected keepalive to be enabled") + } + if remoteConn.period != 15*time.Second { + t.Fatalf("unexpected keepalive period: %s", remoteConn.period) + } +} + +type blockingRouteRouter struct { + testRouter + started chan struct{} + release chan struct{} +} + +func (r *blockingRouteRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + close(r.started) + <-r.release + _ = conn.Close() + onClose(nil) +} + +type keepAliveTestConn struct { + net.Conn + enabled bool + period time.Duration +} + +func (c *keepAliveTestConn) SetKeepAlive(enabled bool) error { + c.enabled = enabled + return nil +} + +func (c *keepAliveTestConn) SetKeepAlivePeriod(period time.Duration) error { + c.period = period + return nil +} diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index e60edfa0d1..c5b5e0f4dd 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -13,16 +13,13 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/transport/v2raywebsocket" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/ws" ) @@ -118,25 +115,13 @@ func requestHeaderValue(request *ConnectRequest, headerName string) string { } func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) { - input, output := pipe.Pipe() - done := make(chan struct{}) metadata := adapter.InboundContext{ Inbound: i.Tag(), InboundType: i.Type(), Network: N.NetworkTCP, Destination: destination, } - go i.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) { - common.Close(input, output) - close(done) - })) - return input, func() { - common.Close(input, output) - select { - case <-done: - case <-time.After(time.Second): - } - }, nil + return i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) } func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error { diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 35d8245fc1..2df966afa3 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -68,6 +68,7 @@ func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *In router: router, logger: logFactory.NewLogger("test"), configManager: configManager, + flowLimiter: &FlowLimiter{}, } } diff --git a/route/dial.go b/route/dial.go index 013a6350aa..48187debec 100644 --- a/route/dial.go +++ b/route/dial.go @@ -6,7 +6,6 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" - tf "github.com/sagernet/sing-box/common/tlsfragment" C "github.com/sagernet/sing-box/constant" R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing/common" @@ -16,45 +15,6 @@ import ( N "github.com/sagernet/sing/common/network" ) -// DialRouteConnection dials a routed TCP connection for metadata without requiring an upstream accepted socket. -func (r *Router) DialRouteConnection(ctx context.Context, metadata adapter.InboundContext) (net.Conn, error) { - metadata.Network = N.NetworkTCP - ctx = adapter.WithContext(ctx, &metadata) - - selectedRule, selectedOutbound, err := r.selectRoutedOutbound(ctx, &metadata, N.NetworkTCP) - if err != nil { - return nil, err - } - - var conn net.Conn - if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { - conn, err = dialer.DialSerialNetwork( - ctx, - selectedOutbound, - N.NetworkTCP, - metadata.Destination, - metadata.DestinationAddresses, - metadata.NetworkStrategy, - metadata.NetworkType, - metadata.FallbackNetworkType, - metadata.FallbackDelay, - ) - } else { - conn, err = selectedOutbound.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return nil, err - } - - if metadata.TLSFragment || metadata.TLSRecordFragment { - conn = tf.NewConn(conn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay) - } - for _, tracker := range r.trackers { - conn = tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound) - } - return conn, nil -} - // DialRoutePacketConnection dials a routed connected UDP packet connection for metadata. func (r *Router) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { metadata.Network = N.NetworkUDP From 1320b737b9e1db9bbd8a18169173d72930df6ea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 21:03:47 +0800 Subject: [PATCH 32/41] Align cloudflare runtime behavior with cloudflared --- protocol/cloudflare/config_decode_test.go | 10 +++ protocol/cloudflare/dispatch.go | 4 +- protocol/cloudflare/edge_discovery_test.go | 27 ++++++++ protocol/cloudflare/inbound.go | 34 +++++++-- protocol/cloudflare/special_service.go | 77 +++++++++++++++++++-- protocol/cloudflare/special_service_test.go | 71 ++++++++++++++++++- 6 files changed, 210 insertions(+), 13 deletions(-) diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go index 4addd1f99b..0c6f834736 100644 --- a/protocol/cloudflare/config_decode_test.go +++ b/protocol/cloudflare/config_decode_test.go @@ -26,3 +26,13 @@ func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestNormalizeProtocolAcceptsAuto(t *testing.T) { + protocol, err := normalizeProtocol("auto") + if err != nil { + t.Fatal(err) + } + if protocol != "" { + t.Fatalf("expected auto protocol to normalize to empty string, got %q", protocol) + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 5cedfe2a22..10b4300768 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -265,7 +265,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos respWriter.WriteResponse(err, nil) return } - i.handleStreamService(ctx, stream, respWriter, request, metadata, service.Destination) + i.handleStreamService(ctx, stream, respWriter, request, metadata, service) case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: if request.Type == ConnectionTypeHTTP { i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) @@ -279,7 +279,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos respWriter.WriteResponse(err, nil) return } - i.handleBastionStream(ctx, stream, respWriter, request, metadata) + i.handleBastionStream(ctx, stream, respWriter, request, metadata, service) case ResolvedServiceSocksProxy: if request.Type != ConnectionTypeWebsocket { err := E.New("socks-proxy service requires websocket request type") diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index 28dda352ec..f3e6e8df54 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -97,3 +97,30 @@ func TestGetRegionalServiceName(t *testing.T) { t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got) } } + +func TestInitialEdgeAddrIndex(t *testing.T) { + if got := initialEdgeAddrIndex(0, 4); got != 0 { + t.Fatalf("expected conn 0 to get index 0, got %d", got) + } + if got := initialEdgeAddrIndex(3, 4); got != 3 { + t.Fatalf("expected conn 3 to get index 3, got %d", got) + } + if got := initialEdgeAddrIndex(5, 4); got != 1 { + t.Fatalf("expected conn 5 to wrap to index 1, got %d", got) + } + if got := initialEdgeAddrIndex(2, 1); got != 0 { + t.Fatalf("expected single-address pool to always return 0, got %d", got) + } +} + +func TestRotateEdgeAddrIndex(t *testing.T) { + if got := rotateEdgeAddrIndex(0, 4); got != 1 { + t.Fatalf("expected index 0 to rotate to 1, got %d", got) + } + if got := rotateEdgeAddrIndex(3, 4); got != 0 { + t.Fatalf("expected last index to wrap to 0, got %d", got) + } + if got := rotateEdgeAddrIndex(0, 1); got != 0 { + t.Fatalf("expected single-address pool to stay at 0, got %d", got) + } +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index ad46458187..42e0b46a3c 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -86,9 +86,9 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo haConnections = 4 } - protocol := options.Protocol - if protocol != "" && protocol != "quic" && protocol != "http2" { - return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2") + protocol, err := normalizeProtocol(options.Protocol) + if err != nil { + return nil, err } edgeIPVersion := options.EdgeIPVersion @@ -283,6 +283,7 @@ const ( func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { defer i.done.Done() + edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs)) retries := 0 for { select { @@ -291,7 +292,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe default: } - edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))] + edgeAddr := edgeAddrs[edgeIndex] err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries)) if err == nil || i.ctx.Err() != nil { return @@ -303,6 +304,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } retries++ + edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs)) backoff := backoffDuration(retries) var retryableErr *RetryableError if errors.As(err, &retryableErr) && retryableErr.Delay > 0 { @@ -410,6 +412,20 @@ func backoffDuration(retries int) time.Duration { return backoff/2 + jitter } +func initialEdgeAddrIndex(connIndex uint8, size int) int { + if size <= 1 { + return 0 + } + return int(connIndex) % size +} + +func rotateEdgeAddrIndex(current int, size int) int { + if size <= 1 { + return 0 + } + return (current + 1) % size +} + func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { var result []*EdgeAddr for _, region := range regions { @@ -430,3 +446,13 @@ func parseToken(token string) (Credentials, error) { } return tunnelToken.ToCredentials(), nil } + +func normalizeProtocol(protocol string) (string, error) { + if protocol == "auto" { + return "", nil + } + if protocol != "" && protocol != "quic" && protocol != "http2" { + return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2") + } + return protocol, nil +} diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index c5b5e0f4dd..6c6d142ae1 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -32,20 +32,20 @@ const ( socksReplyCommandNotSupported = 7 ) -func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { +func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { destination, err := resolveBastionDestination(request) if err != nil { respWriter.WriteResponse(err, nil) return } - i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination)) + i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType) } -func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, destination M.Socksaddr) { - i.handleRouterBackedStream(ctx, stream, respWriter, request, destination) +func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType) } -func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr) { +func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) { targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) if err != nil { respWriter.WriteResponse(err, nil) @@ -61,6 +61,12 @@ func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWr wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) defer wsConn.Close() + if isSocksProxyType(proxyType) { + if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err) + } + return + } _ = bufio.CopyConn(ctx, wsConn, targetConn) } @@ -101,6 +107,67 @@ func websocketResponseHeaders(request *ConnectRequest) http.Header { return header } +func isSocksProxyType(proxyType string) bool { + lower := strings.ToLower(strings.TrimSpace(proxyType)) + return lower == "socks" || lower == "socks5" +} + +func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error { + version := make([]byte, 1) + if _, err := io.ReadFull(conn, version); err != nil { + return err + } + if version[0] != 5 { + return E.New("unsupported SOCKS version: ", version[0]) + } + + methodCount := make([]byte, 1) + if _, err := io.ReadFull(conn, methodCount); err != nil { + return err + } + methods := make([]byte, int(methodCount[0])) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + var supportsNoAuth bool + for _, method := range methods { + if method == 0 { + supportsNoAuth = true + break + } + } + if !supportsNoAuth { + _, err := conn.Write([]byte{5, 255}) + if err != nil { + return err + } + return E.New("unknown authentication type") + } + if _, err := conn.Write([]byte{5, 0}); err != nil { + return err + } + + requestHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, requestHeader); err != nil { + return err + } + if requestHeader[0] != 5 { + return E.New("unsupported SOCKS request version: ", requestHeader[0]) + } + if requestHeader[1] != 1 { + _ = writeSocksReply(conn, socksReplyCommandNotSupported) + return E.New("unsupported SOCKS command: ", requestHeader[1]) + } + if _, err := readSocksDestination(conn, requestHeader[3]); err != nil { + return err + } + if err := writeSocksReply(conn, socksReplySuccess); err != nil { + return err + } + return bufio.CopyConn(ctx, conn, targetConn) +} + func requestHeaderValue(request *ConnectRequest, headerName string) string { for _, entry := range request.Metadata { if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") { diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 2df966afa3..8c39543c29 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -201,7 +201,7 @@ func TestHandleBastionStream(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{}) }() select { @@ -438,7 +438,10 @@ func TestHandleStreamService(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, M.ParseSocksaddr(listener.Addr().String())) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + }) }() select { @@ -473,3 +476,67 @@ func TestHandleStreamService(t *testing.T) { t.Fatal("stream service did not exit") } } + +func TestHandleStreamServiceProxyTypeSocks(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + OriginRequest: OriginRequestConfig{ + ProxyType: "socks", + }, + }) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for stream service connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplySuccess { + t.Fatalf("unexpected socks connect response: %v", data) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks stream service did not exit") + } +} From 7ca692d8c2b7867703c99179bf40b9205c8b1ba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 24 Mar 2026 21:45:14 +0800 Subject: [PATCH 33/41] Remove hello_world cloudflare service --- protocol/cloudflare/dispatch.go | 15 +------ protocol/cloudflare/inbound.go | 49 ---------------------- protocol/cloudflare/ingress_test.go | 26 ------------ protocol/cloudflare/origin_request_test.go | 7 +--- protocol/cloudflare/runtime_config.go | 21 +--------- 5 files changed, 4 insertions(+), 114 deletions(-) diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 10b4300768..07b5b020d6 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -150,14 +150,6 @@ func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string if !loaded { return ResolvedService{}, "", E.New("no ingress rule matched request host/path") } - if service.Kind == ResolvedServiceHelloWorld { - helloURL, err := i.ensureHelloWorldURL() - if err != nil { - return ResolvedService{}, "", err - } - service.BaseURL = helloURL - service.OriginRequest.NoTLSVerify = true - } originURL, err := service.BuildRequestURL(requestURL) if err != nil { return ResolvedService{}, "", E.Cause(err, "build origin request URL") @@ -266,7 +258,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos return } i.handleStreamService(ctx, stream, respWriter, request, metadata, service) - case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: + case ResolvedServiceUnix, ResolvedServiceUnixTLS: if request.Type == ConnectionTypeHTTP { i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) } else { @@ -439,11 +431,6 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { return dialer.DialContext(ctx, "unix", service.UnixPath) } - case ResolvedServiceHelloWorld: - target := service.BaseURL.Host - transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "tcp", target) - } default: return nil, nil, E.New("unsupported direct origin service") } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 42e0b46a3c..a49d478bd3 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -4,21 +4,16 @@ package cloudflare import ( "context" - stdTLS "crypto/tls" "encoding/base64" "errors" "io" "math/rand" - "net" - "net/http" - "net/url" "sync" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" boxDialer "github.com/sagernet/sing-box/common/dialer" - boxTLS "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -63,10 +58,6 @@ type Inbound struct { datagramV3Muxers map[DatagramSender]*DatagramV3Muxer datagramV3Manager *DatagramV3SessionManager - helloWorldAccess sync.Mutex - helloWorldServer *http.Server - helloWorldURL *url.URL - connectedAccess sync.Mutex connectedIndices map[uint8]struct{} connectedNotify chan uint8 @@ -231,49 +222,9 @@ func (i *Inbound) Close() error { } i.connections = nil i.connectionAccess.Unlock() - if i.helloWorldServer != nil { - i.helloWorldServer.Close() - } return nil } -func (i *Inbound) ensureHelloWorldURL() (*url.URL, error) { - i.helloWorldAccess.Lock() - defer i.helloWorldAccess.Unlock() - if i.helloWorldURL != nil { - return i.helloWorldURL, nil - } - - mux := http.NewServeMux() - mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write([]byte("Hello World")) - }) - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, E.Cause(err, "listen hello world server") - } - certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost") - if err != nil { - _ = listener.Close() - return nil, E.Cause(err, "generate hello world certificate") - } - tlsListener := stdTLS.NewListener(listener, &stdTLS.Config{ - Certificates: []stdTLS.Certificate{*certificate}, - }) - server := &http.Server{Handler: mux} - go server.Serve(tlsListener) - - i.helloWorldServer = server - i.helloWorldURL = &url.URL{ - Scheme: "https", - Host: listener.Addr().String(), - } - return i.helloWorldURL, nil -} - const ( backoffBaseTime = time.Second backoffMaxTime = 2 * time.Minute diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index a742e794de..5ff2db3db2 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -162,29 +162,3 @@ func TestResolveHTTPServiceStatus(t *testing.T) { t.Fatalf("status service should keep request URL, got %s", requestURL) } } - -func TestResolveHTTPServiceHelloWorld(t *testing.T) { - inboundInstance := newTestIngressInbound(t) - inboundInstance.configManager.activeConfig = RuntimeConfig{ - Ingress: []compiledIngressRule{ - {Service: mustResolvedService(t, "hello_world")}, - }, - } - - service, requestURL, err := inboundInstance.resolveHTTPService("https://hello.example.com/path") - if err != nil { - t.Fatal(err) - } - if service.Kind != ResolvedServiceHelloWorld { - t.Fatalf("expected hello world service, got %#v", service) - } - if service.BaseURL == nil || service.BaseURL.Scheme != "https" { - t.Fatalf("expected hello world base URL to be https, got %#v", service.BaseURL) - } - if !service.OriginRequest.NoTLSVerify { - t.Fatal("expected hello world to force no_tls_verify") - } - if requestURL == "" || requestURL[:8] != "https://" { - t.Fatalf("expected https request URL, got %s", requestURL) - } -} diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index 72aa3aeabf..17401cd7dd 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -58,11 +58,8 @@ func TestApplyHTTPTransportProxy(t *testing.T) { func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) { inbound := &Inbound{} transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{ - Kind: ResolvedServiceHelloWorld, - BaseURL: &url.URL{ - Scheme: "http", - Host: "127.0.0.1:8080", - }, + Kind: ResolvedServiceUnix, + UnixPath: "/tmp/test.sock", OriginRequest: OriginRequestConfig{ NoHappyEyeballs: true, }, diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 5c5ae88243..ef8c50495e 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -35,7 +35,6 @@ const ( ResolvedServiceHTTP ResolvedServiceKind = iota ResolvedServiceStream ResolvedServiceStatus - ResolvedServiceHelloWorld ResolvedServiceUnix ResolvedServiceUnixTLS ResolvedServiceBastion @@ -70,20 +69,6 @@ func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) { originURL.RawQuery = requestParsed.RawQuery originURL.Fragment = requestParsed.Fragment return originURL.String(), nil - case ResolvedServiceHelloWorld: - if s.BaseURL == nil { - return "", E.New("hello world service is unavailable") - } - requestParsed, err := url.Parse(requestURL) - if err != nil { - return "", err - } - originURL := *s.BaseURL - originURL.Path = requestParsed.Path - originURL.RawPath = requestParsed.RawPath - originURL.RawQuery = requestParsed.RawQuery - originURL.Fragment = requestParsed.Fragment - return originURL.String(), nil default: return requestURL, nil } @@ -413,11 +398,7 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) OriginRequest: originRequest, }, nil case rawService == "hello_world" || rawService == "hello-world": - return ResolvedService{ - Kind: ResolvedServiceHelloWorld, - Service: rawService, - OriginRequest: originRequest, - }, nil + return ResolvedService{}, E.New("unsupported ingress service: hello_world") case rawService == "bastion": return ResolvedService{ Kind: ResolvedServiceBastion, From 4497f6132341369c14aa8b7d50a35476caab9505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 15:52:43 +0800 Subject: [PATCH 34/41] Fix cloudflared test and protocol parity --- protocol/cloudflare/connection_http2.go | 10 +++- protocol/cloudflare/dispatch.go | 33 +++++++++--- protocol/cloudflare/edge_discovery_test.go | 22 ++++++++ protocol/cloudflare/flow_limiter_test.go | 52 +++++++++++++++++++ protocol/cloudflare/helpers_test.go | 2 + protocol/cloudflare/inbound.go | 14 +++++ protocol/cloudflare/origin_request_test.go | 60 ++++++++++++++++++++++ protocol/cloudflare/stream.go | 18 +++++++ 8 files changed, 203 insertions(+), 8 deletions(-) diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index ef7f460359..daa5cfd909 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -24,7 +24,9 @@ import ( ) const ( - h2EdgeSNI = "h2.cftunnel.com" + h2EdgeSNI = "h2.cftunnel.com" + h2ResponseMetaCloudflared = `{"src":"cloudflared"}` + h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}` ) // HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge. @@ -357,7 +359,11 @@ func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Meta w.headersSent = true if responseError != nil { - w.writer.Header().Set(h2HeaderResponseMeta, `{"src":"cloudflared"}`) + if hasFlowConnectRateLimited(metadata) { + w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflaredLimited) + } else { + w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared) + } w.writer.WriteHeader(http.StatusBadGateway) w.flusher.Flush() return nil diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 07b5b020d6..cfadf00c77 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -182,7 +182,7 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser if !i.flowLimiter.Acquire(limit) { err := E.New("too many active flows") i.logger.ErrorContext(ctx, err) - respWriter.WriteResponse(err, nil) + respWriter.WriteResponse(err, flowConnectRateLimitedMetadata()) return } defer i.flowLimiter.Release(limit) @@ -341,7 +341,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, return } - httpRequest = applyOriginRequest(httpRequest, service.OriginRequest) + httpRequest = normalizeOriginRequest(request.Type, httpRequest, service.OriginRequest) requestCtx := httpRequest.Context() if service.OriginRequest.ConnectTimeout > 0 { var cancel context.CancelFunc @@ -489,12 +489,33 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig request.Header.Set("X-Forwarded-Host", request.Host) request.Host = originRequest.HTTPHostHeader } - if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" { - if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil { - request.ContentLength = contentLength - request.TransferEncoding = nil + return request +} + +func normalizeOriginRequest(connectType ConnectionType, request *http.Request, originRequest OriginRequestConfig) *http.Request { + request = applyOriginRequest(request, originRequest) + + switch connectType { + case ConnectionTypeWebsocket: + request.Header.Set("Connection", "Upgrade") + request.Header.Set("Upgrade", "websocket") + request.Header.Set("Sec-Websocket-Version", "13") + request.ContentLength = 0 + request.Body = nil + default: + if originRequest.DisableChunkedEncoding { + request.TransferEncoding = []string{"gzip", "deflate"} + if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil { + request.ContentLength = contentLength + } } + request.Header.Set("Connection", "keep-alive") + } + + if _, exists := request.Header["User-Agent"]; !exists { + request.Header.Set("User-Agent", "") } + return request } diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index f3e6e8df54..a970a8df8b 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -124,3 +124,25 @@ func TestRotateEdgeAddrIndex(t *testing.T) { t.Fatalf("expected single-address pool to stay at 0, got %d", got) } } + +func TestEffectiveHAConnections(t *testing.T) { + tests := []struct { + name string + requested int + available int + expected int + }{ + {name: "requested below available", requested: 2, available: 4, expected: 2}, + {name: "requested equals available", requested: 4, available: 4, expected: 4}, + {name: "requested above available", requested: 5, available: 3, expected: 3}, + {name: "no available edges", requested: 4, available: 0, expected: 0}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if actual := effectiveHAConnections(testCase.requested, testCase.available); actual != testCase.expected { + t.Fatalf("effectiveHAConnections(%d, %d) = %d, want %d", testCase.requested, testCase.available, actual, testCase.expected) + } + }) + } +} diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index f4d24ee0a3..199fdfe18c 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -6,6 +6,8 @@ import ( "context" "encoding/binary" "net" + "net/http" + "net/http/httptest" "testing" "github.com/sagernet/sing-box/adapter" @@ -17,6 +19,17 @@ import ( "github.com/google/uuid" ) +type captureConnectMetadataWriter struct { + err error + metadata []Metadata +} + +func (w *captureConnectMetadataWriter) WriteResponse(responseError error, metadata []Metadata) error { + w.err = responseError + w.metadata = append([]Metadata(nil), metadata...) + return nil +} + func newLimitedInbound(t *testing.T, limit uint64) *Inbound { t.Helper() logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}}) @@ -56,6 +69,45 @@ func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) { } } +func TestHandleTCPStreamRateLimitMetadata(t *testing.T) { + inboundInstance := newLimitedInbound(t, 1) + if !inboundInstance.flowLimiter.Acquire(1) { + t.Fatal("failed to pre-acquire limiter") + } + + stream, peer := net.Pipe() + defer stream.Close() + defer peer.Close() + + respWriter := &captureConnectMetadataWriter{} + inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{}) + if respWriter.err == nil { + t.Fatal("expected too many active flows error") + } + if !hasFlowConnectRateLimited(respWriter.metadata) { + t.Fatal("expected flow rate limit metadata") + } +} + +func TestHTTP2ResponseWriterFlowRateLimitedMeta(t *testing.T) { + recorder := httptest.NewRecorder() + writer := &http2ResponseWriter{ + writer: recorder, + flusher: recorder, + } + + err := writer.WriteResponse(context.DeadlineExceeded, flowConnectRateLimitedMetadata()) + if err != nil { + t.Fatal(err) + } + if recorder.Code != http.StatusBadGateway { + t.Fatalf("expected %d, got %d", http.StatusBadGateway, recorder.Code) + } + if meta := recorder.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflaredLimited { + t.Fatalf("unexpected response meta: %q", meta) + } +} + func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) { inboundInstance := newLimitedInbound(t, 1) if !inboundInstance.flowLimiter.Acquire(1) { diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 253eed5cf8..fa05ca7978 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -202,6 +202,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), datagramV3Manager: NewDatagramV3SessionManager(), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), controlDialer: N.SystemDialer, accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index a49d478bd3..442f700834 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -162,6 +162,10 @@ func (i *Inbound) Start(stage adapter.StartStage) error { if len(edgeAddrs) == 0 { return E.New("no edge addresses available") } + if cappedHAConnections := effectiveHAConnections(i.haConnections, len(edgeAddrs)); cappedHAConnections != i.haConnections { + i.logger.Info("requested ", i.haConnections, " HA connections but only ", cappedHAConnections, " edge addresses are available") + i.haConnections = cappedHAConnections + } i.datagramVersion = resolveDatagramVersion(i.ctx, i.credentials.AccountTag, i.datagramVersion) features := DefaultFeatures(i.datagramVersion) @@ -385,6 +389,16 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { return result } +func effectiveHAConnections(requested, available int) int { + if available <= 0 { + return 0 + } + if requested > available { + return available + } + return requested +} + func parseToken(token string) (Credentials, error) { data, err := base64.StdEncoding.DecodeString(token) if err != nil { diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index 17401cd7dd..c00807ff83 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -3,8 +3,10 @@ package cloudflare import ( + "io" "net/http" "net/url" + "strings" "testing" ) @@ -75,3 +77,61 @@ func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) { t.Fatal("expected custom direct dial context") } } + +func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody) + if err != nil { + t.Fatal(err) + } + + request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{}) + if connection := request.Header.Get("Connection"); connection != "keep-alive" { + t.Fatalf("expected keep-alive connection header, got %q", connection) + } + if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" { + t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"]) + } +} + +func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) { + request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload")) + if err != nil { + t.Fatal(err) + } + request.TransferEncoding = []string{"chunked"} + request.Header.Set("Content-Length", "7") + + request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{ + DisableChunkedEncoding: true, + }) + if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" { + t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding) + } + if request.ContentLength != 7 { + t.Fatalf("expected content length 7, got %d", request.ContentLength) + } +} + +func TestNormalizeOriginRequestWebsocket(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload"))) + if err != nil { + t.Fatal(err) + } + + request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{}) + if connection := request.Header.Get("Connection"); connection != "Upgrade" { + t.Fatalf("expected websocket connection header, got %q", connection) + } + if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" { + t.Fatalf("expected websocket upgrade header, got %q", upgrade) + } + if version := request.Header.Get("Sec-Websocket-Version"); version != "13" { + t.Fatalf("expected websocket version 13, got %q", version) + } + if request.ContentLength != 0 { + t.Fatalf("expected websocket content length 0, got %d", request.ContentLength) + } + if request.Body != nil { + t.Fatal("expected websocket body to be nil") + } +} diff --git a/protocol/cloudflare/stream.go b/protocol/cloudflare/stream.go index 62003d2334..4da783fc13 100644 --- a/protocol/cloudflare/stream.go +++ b/protocol/cloudflare/stream.go @@ -31,6 +31,8 @@ const ( StreamTypeRPC ) +const metadataFlowConnectRateLimited = "FlowConnectRateLimited" + // ConnectionType indicates the proxied connection type within a data stream. type ConnectionType uint16 @@ -59,6 +61,22 @@ type Metadata struct { Val string `capnp:"val"` } +func flowConnectRateLimitedMetadata() []Metadata { + return []Metadata{{ + Key: metadataFlowConnectRateLimited, + Val: "true", + }} +} + +func hasFlowConnectRateLimited(metadata []Metadata) bool { + for _, entry := range metadata { + if entry.Key == metadataFlowConnectRateLimited && entry.Val == "true" { + return true + } + } + return false +} + // ConnectRequest is sent by the edge at the start of a data stream. type ConnectRequest struct { Dest string `capnp:"dest"` From 316c2559b1a0aec9df70d241062d4ef19e798484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 18:57:38 +0800 Subject: [PATCH 35/41] Fix cloudflared compatibility gaps --- protocol/cloudflare/connection_drain_test.go | 232 ++++++++++++++++ protocol/cloudflare/connection_http2.go | 111 ++++++-- protocol/cloudflare/connection_quic.go | 72 +++-- protocol/cloudflare/control.go | 12 + .../cloudflare/datagram_lifecycle_test.go | 96 +++++++ protocol/cloudflare/datagram_v3.go | 59 +++- protocol/cloudflare/datagram_v3_test.go | 85 ++++++ protocol/cloudflare/icmp.go | 255 +++++++++++++++--- protocol/cloudflare/icmp_test.go | 152 +++++++++++ protocol/cloudflare/ingress_test.go | 46 ++++ protocol/cloudflare/runtime_config.go | 16 +- 11 files changed, 1051 insertions(+), 85 deletions(-) create mode 100644 protocol/cloudflare/connection_drain_test.go diff --git a/protocol/cloudflare/connection_drain_test.go b/protocol/cloudflare/connection_drain_test.go new file mode 100644 index 0000000000..0d975a1547 --- /dev/null +++ b/protocol/cloudflare/connection_drain_test.go @@ -0,0 +1,232 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sagernet/quic-go" +) + +type stubNetConn struct { + closed chan struct{} +} + +func newStubNetConn() *stubNetConn { + return &stubNetConn{closed: make(chan struct{})} +} + +func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF } +func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil } +func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil } +func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (c *stubNetConn) SetDeadline(time.Time) error { return nil } +func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil } +func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil } + +type stubQUICConn struct { + closed chan string +} + +func newStubQUICConn() *stubQUICConn { + return &stubQUICConn{closed: make(chan string, 1)} +} + +func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") } +func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) { + return nil, errors.New("unused") +} +func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) { + return nil, errors.New("unused") +} +func (c *stubQUICConn) SendDatagram([]byte) error { return nil } +func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error { + select { + case c.closed <- reason: + default: + } + return nil +} + +type mockRegistrationClient struct { + unregisterCalled chan struct{} + closed chan struct{} +} + +func newMockRegistrationClient() *mockRegistrationClient { + return &mockRegistrationClient{ + unregisterCalled: make(chan struct{}, 1), + closed: make(chan struct{}, 1), + } +} + +func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) { + return &RegistrationResult{}, nil +} + +func (c *mockRegistrationClient) Unregister(context.Context) error { + select { + case c.unregisterCalled <- struct{}{}: + default: + } + return nil +} + +func (c *mockRegistrationClient) Close() error { + select { + case c.closed <- struct{}{}: + default: + } + return nil +} + +func closeOnce(ch chan struct{}) { + select { + case <-ch: + default: + close(ch) + } +} + +func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) { + conn := newStubNetConn() + registrationClient := newMockRegistrationClient() + connection := &HTTP2Connection{ + conn: conn, + gracePeriod: 200 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() {}, + } + connection.activeRequests.Add(1) + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-registrationClient.unregisterCalled: + case <-time.After(time.Second): + t.Fatal("expected unregister call") + } + + select { + case <-conn.closed: + t.Fatal("connection closed before active requests completed") + case <-time.After(50 * time.Millisecond): + } + + connection.activeRequests.Done() + + select { + case <-conn.closed: + case <-time.After(time.Second): + t.Fatal("expected connection close after active requests finished") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish") + } +} + +func TestHTTP2GracefulShutdownTimesOut(t *testing.T) { + conn := newStubNetConn() + registrationClient := newMockRegistrationClient() + connection := &HTTP2Connection{ + conn: conn, + gracePeriod: 50 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() {}, + } + connection.activeRequests.Add(1) + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-conn.closed: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected connection close after grace timeout") + } + + connection.activeRequests.Done() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish after request completion") + } +} + +func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) { + conn := newStubQUICConn() + registrationClient := newMockRegistrationClient() + serveCancelCalled := make(chan struct{}, 1) + connection := &QUICConnection{ + conn: conn, + gracePeriod: 80 * time.Millisecond, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCancel: func() { + select { + case serveCancelCalled <- struct{}{}: + default: + } + }, + } + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-registrationClient.unregisterCalled: + case <-time.After(time.Second): + t.Fatal("expected unregister call") + } + + select { + case <-conn.closed: + t.Fatal("connection closed before grace window elapsed") + case <-time.After(20 * time.Millisecond): + } + + select { + case reason := <-conn.closed: + if reason != "graceful shutdown" { + t.Fatalf("unexpected close reason: %q", reason) + } + case <-time.After(time.Second): + t.Fatal("expected graceful close") + } + + select { + case <-serveCancelCalled: + case <-time.After(time.Second): + t.Fatal("expected serve cancel to be called") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected graceful shutdown to finish") + } +} diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index daa5cfd909..33192806f7 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -44,12 +44,15 @@ type HTTP2Connection struct { inbound *Inbound numPreviousAttempts uint8 - registrationClient *RegistrationClient + registrationClient registrationRPCClient registrationResult *RegistrationResult controlStreamErr error - activeRequests sync.WaitGroup - closeOnce sync.Once + activeRequests sync.WaitGroup + serveCancel context.CancelFunc + registrationClose sync.Once + shutdownOnce sync.Once + closeOnce sync.Once } // NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal. @@ -106,22 +109,28 @@ func NewHTTP2Connection( // Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends. func (c *HTTP2Connection) Serve(ctx context.Context) error { + serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx)) + c.serveCancel = serveCancel + + shutdownDone := make(chan struct{}) go func() { <-ctx.Done() - c.close() + c.gracefulShutdown() + close(shutdownDone) }() c.server.ServeConn(c.conn, &http2.ServeConnOpts{ - Context: ctx, + Context: serveCtx, Handler: c, }) - if c.controlStreamErr != nil { - return c.controlStreamErr - } if ctx.Err() != nil { + <-shutdownDone return ctx.Err() } + if c.controlStreamErr != nil { + return c.controlStreamErr + } if c.registrationResult == nil { return E.New("edge connection closed before registration") } @@ -129,12 +138,15 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error { } func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream { + c.handleControlStream(r.Context(), r, w) + return + } + c.activeRequests.Add(1) defer c.activeRequests.Done() switch { - case r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream: - c.handleControlStream(r.Context(), r, w) case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket: c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket) case r.Header.Get(h2HeaderTCPSrc) != "": @@ -169,17 +181,13 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque if err != nil { c.controlStreamErr = err c.logger.Error("register connection: ", err) - if c.registrationClient != nil { - c.registrationClient.Close() - } - go c.close() + go c.forceClose() return } if err := validateRegistrationResult(result); err != nil { c.controlStreamErr = err c.logger.Error("register connection: ", err) - c.registrationClient.Close() - go c.close() + go c.forceClose() return } c.registrationResult = result @@ -189,13 +197,6 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque " (connection ", result.ConnectionID, ")") <-ctx.Done() - unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) - defer cancel() - err = c.registrationClient.Unregister(unregisterCtx) - if err != nil { - c.logger.Debug("failed to unregister: ", err) - } - c.registrationClient.Close() } func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) { @@ -280,16 +281,74 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`)) } -func (c *HTTP2Connection) close() { +func (c *HTTP2Connection) gracefulShutdown() { + c.shutdownOnce.Do(func() { + if c.registrationClient == nil || c.registrationResult == nil { + c.closeNow() + return + } + + unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod) + err := c.registrationClient.Unregister(unregisterCtx) + cancel() + if err != nil { + c.logger.Debug("failed to unregister: ", err) + } + c.closeRegistrationClient() + c.waitForActiveRequests(c.gracePeriod) + c.closeNow() + }) +} + +func (c *HTTP2Connection) forceClose() { + c.shutdownOnce.Do(func() { + c.closeNow() + }) +} + +func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) { + if timeout <= 0 { + c.activeRequests.Wait() + return + } + + done := make(chan struct{}) + go func() { + c.activeRequests.Wait() + close(done) + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-done: + case <-timer.C: + } +} + +func (c *HTTP2Connection) closeRegistrationClient() { + c.registrationClose.Do(func() { + if c.registrationClient != nil { + _ = c.registrationClient.Close() + } + }) +} + +func (c *HTTP2Connection) closeNow() { c.closeOnce.Do(func() { - c.conn.Close() + _ = c.conn.Close() + if c.serveCancel != nil { + c.serveCancel() + } + c.closeRegistrationClient() c.activeRequests.Wait() }) } // Close closes the HTTP/2 connection. func (c *HTTP2Connection) Close() error { - c.close() + c.forceClose() return nil } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index f654ee4cb8..bb935b4e32 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -48,11 +48,14 @@ type QUICConnection struct { features []string numPreviousAttempts uint8 gracePeriod time.Duration - registrationClient *RegistrationClient + registrationClient registrationRPCClient registrationResult *RegistrationResult onConnected func() - closeOnce sync.Once + serveCancel context.CancelFunc + registrationClose sync.Once + shutdownOnce sync.Once + closeOnce sync.Once } type quicConnection interface { @@ -180,22 +183,29 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error q.logger.Info("connected to ", q.registrationResult.Location, " (connection ", q.registrationResult.ConnectionID, ")") + serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx)) + q.serveCancel = serveCancel + errChan := make(chan error, 2) go func() { - errChan <- q.acceptStreams(ctx, handler) + errChan <- q.acceptStreams(serveCtx, handler) }() go func() { - errChan <- q.handleDatagrams(ctx, handler) + errChan <- q.handleDatagrams(serveCtx, handler) }() select { case <-ctx.Done(): q.gracefulShutdown() + <-errChan return ctx.Err() case err = <-errChan: - q.gracefulShutdown() + q.forceClose() + if ctx.Err() != nil { + return ctx.Err() + } return err } } @@ -285,23 +295,55 @@ func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, } func (q *QUICConnection) gracefulShutdown() { - q.closeOnce.Do(func() { + q.shutdownOnce.Do(func() { + if q.registrationClient == nil || q.registrationResult == nil { + q.closeNow("connection closed") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod) + err := q.registrationClient.Unregister(ctx) + cancel() + if err != nil { + q.logger.Debug("failed to unregister: ", err) + } + q.closeRegistrationClient() + if q.gracePeriod > 0 { + timer := time.NewTimer(q.gracePeriod) + <-timer.C + timer.Stop() + } + q.closeNow("graceful shutdown") + }) +} + +func (q *QUICConnection) forceClose() { + q.shutdownOnce.Do(func() { + q.closeNow("connection closed") + }) +} + +func (q *QUICConnection) closeRegistrationClient() { + q.registrationClose.Do(func() { if q.registrationClient != nil { - ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod) - defer cancel() - err := q.registrationClient.Unregister(ctx) - if err != nil { - q.logger.Debug("failed to unregister: ", err) - } - q.registrationClient.Close() + _ = q.registrationClient.Close() + } + }) +} + +func (q *QUICConnection) closeNow(reason string) { + q.closeOnce.Do(func() { + if q.serveCancel != nil { + q.serveCancel() } - q.conn.CloseWithError(0, "graceful shutdown") + q.closeRegistrationClient() + _ = q.conn.CloseWithError(0, reason) }) } // Close closes the QUIC connection immediately. func (q *QUICConnection) Close() error { - q.gracefulShutdown() + q.forceClose() return nil } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index dd8b99da6e..e6a0b070f7 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -31,6 +31,18 @@ type RegistrationClient struct { transport rpc.Transport } +type registrationRPCClient interface { + RegisterConnection( + ctx context.Context, + auth TunnelAuth, + tunnelID uuid.UUID, + connIndex uint8, + options *RegistrationConnectionOptions, + ) (*RegistrationResult, error) + Unregister(ctx context.Context) error + Close() error +} + // NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. // The stream should be the first QUIC stream (control stream). func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { diff --git a/protocol/cloudflare/datagram_lifecycle_test.go b/protocol/cloudflare/datagram_lifecycle_test.go index 11a98b8bce..b08e3a7e58 100644 --- a/protocol/cloudflare/datagram_lifecycle_test.go +++ b/protocol/cloudflare/datagram_lifecycle_test.go @@ -5,11 +5,16 @@ package cloudflare import ( "context" "encoding/binary" + "io" "net" "testing" "time" "github.com/google/uuid" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type v2UnregisterCall struct { @@ -25,6 +30,43 @@ type captureV2SessionRPCClient struct { unregisterCh chan v2UnregisterCall } +type blockingPacketConn struct { + closed chan struct{} +} + +func newBlockingPacketConn() *blockingPacketConn { + return &blockingPacketConn{closed: make(chan struct{})} +} + +func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) { + <-c.closed + return M.Socksaddr{}, io.EOF +} + +func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return nil +} + +func (c *blockingPacketConn) Close() error { + closeOnce(c.closed) + return nil +} + +func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil } +func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil } + +type packetDialingRouter struct { + testRouter + packetConn N.PacketConn +} + +func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { + return r.packetConn, nil +} + func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error { c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message} return nil @@ -105,3 +147,57 @@ func TestDatagramV3RegistrationMigratesSender(t *testing.T) { session.close() } + +func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) { + packetConn := newBlockingPacketConn() + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.router = &packetDialingRouter{packetConn: packetConn} + sender1 := &captureDatagramSender{} + sender2 := &captureDatagramSender{} + muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger) + muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger) + + requestID := RequestID{} + requestID[15] = 10 + payload := make([]byte, 1+2+2+16+4) + payload[0] = 0 + binary.BigEndian.PutUint16(payload[1:3], 53) + binary.BigEndian.PutUint16(payload[3:5], 30) + copy(payload[5:21], requestID[:]) + copy(payload[21:25], []byte{127, 0, 0, 1}) + + ctx1, cancel1 := context.WithCancel(context.Background()) + muxer1.handleRegistration(ctx1, payload) + + ctx2, cancel2 := context.WithCancel(context.Background()) + muxer2.handleRegistration(ctx2, payload) + + cancel1() + time.Sleep(50 * time.Millisecond) + + session, exists := inboundInstance.datagramV3Manager.Get(requestID) + if !exists { + t.Fatal("expected session to survive old connection context cancellation") + } + + session.senderAccess.RLock() + currentSender := session.sender + session.senderAccess.RUnlock() + if currentSender != sender2 { + t.Fatal("expected migrated sender to stay active") + } + + cancel2() + + deadline := time.After(time.Second) + for { + if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists { + return + } + select { + case <-deadline: + t.Fatal("expected session to be removed after new context cancellation") + case <-time.After(10 * time.Millisecond): + } + } +} diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index 436fc5a33a..42719c5a1e 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -37,6 +37,7 @@ const ( v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22 v3PayloadHeaderLen = 1 + v3RequestIDLength // 17 v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20 + maxV3UDPPayloadLen = 1280 // V3 registration flags v3FlagIPv6 byte = 0x01 @@ -238,6 +239,10 @@ type v3Session struct { senderAccess sync.RWMutex sender DatagramSender + + contextAccess sync.RWMutex + connCtx context.Context + contextChan chan context.Context } var errTooManyActiveFlows = errors.New("too many active flows") @@ -253,11 +258,12 @@ func (m *DatagramV3SessionManager) Register( m.sessionAccess.Lock() if existing, exists := m.sessions[requestID]; exists { if existing.sender == sender { + existing.updateContext(ctx) existing.markActive() m.sessionAccess.Unlock() return existing, v3RegistrationExisting, nil } - existing.setSender(sender) + existing.migrate(sender, ctx) existing.markActive() m.sessionAccess.Unlock() return existing, v3RegistrationMigrated, nil @@ -286,14 +292,17 @@ func (m *DatagramV3SessionManager) Register( closeChan: make(chan struct{}), activeAt: time.Now(), sender: sender, + connCtx: ctx, + contextChan: make(chan context.Context, 1), } m.sessions[requestID] = session m.sessionAccess.Unlock() - sessionCtx := inbound.ctx + sessionCtx := ctx if sessionCtx == nil { sessionCtx = context.Background() } + session.connCtx = sessionCtx go session.serve(sessionCtx, limit) return session, v3RegistrationNew, nil } @@ -320,6 +329,8 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) { go s.readLoop() go s.writeLoop() + connCtx := ctx + tickInterval := s.closeAfterIdle / 2 if tickInterval <= 0 || tickInterval > 10*time.Second { tickInterval = time.Second @@ -329,8 +340,16 @@ func (s *v3Session) serve(ctx context.Context, limit uint64) { for { select { - case <-ctx.Done(): + case <-connCtx.Done(): + if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx { + connCtx = latestCtx + continue + } s.close() + case newCtx := <-s.contextChan: + if newCtx != nil { + connCtx = newCtx + } case <-ticker.C: if time.Since(s.lastActive()) >= s.closeAfterIdle { s.close() @@ -350,6 +369,11 @@ func (s *v3Session) readLoop() { s.close() return } + if buffer.Len() > maxV3UDPPayloadLen { + s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len()) + buffer.Release() + continue + } s.markActive() if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil { buffer.Release() @@ -403,6 +427,35 @@ func (s *v3Session) setSender(sender DatagramSender) { s.senderAccess.Unlock() } +func (s *v3Session) updateContext(ctx context.Context) { + if ctx == nil { + return + } + s.contextAccess.Lock() + s.connCtx = ctx + s.contextAccess.Unlock() + select { + case s.contextChan <- ctx: + default: + select { + case <-s.contextChan: + default: + } + s.contextChan <- ctx + } +} + +func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) { + s.setSender(sender) + s.updateContext(ctx) +} + +func (s *v3Session) currentContext() context.Context { + s.contextAccess.RLock() + defer s.contextAccess.RUnlock() + return s.connCtx +} + func (s *v3Session) markActive() { s.activeAccess.Lock() s.activeAt = time.Now() diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index 87f9148dd5..22f7cc5859 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -5,10 +5,18 @@ package cloudflare import ( "context" "encoding/binary" + "errors" + "io" + "net" + "net/netip" "testing" + "time" "github.com/sagernet/sing-box/adapter/inbound" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" ) func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) { @@ -64,3 +72,80 @@ func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) { t.Fatalf("unexpected datagram response: %v", sender.sent[0]) } } + +type scriptedPacketConn struct { + reads [][]byte + index int +} + +func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + if c.index >= len(c.reads) { + return M.Socksaddr{}, io.EOF + } + _, err := buffer.Write(c.reads[c.index]) + c.index++ + return M.Socksaddr{}, err +} + +func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return nil +} + +func (c *scriptedPacketConn) Close() error { return nil } +func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil } +func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil } + +type sizeLimitedSender struct { + sent [][]byte + max int +} + +func (s *sizeLimitedSender) SendDatagram(data []byte) error { + if len(data) > s.max { + return errors.New("datagram too large") + } + s.sent = append(s.sent, append([]byte(nil), data...)) + return nil +} + +func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) { + logger := log.NewNOPFactory().NewLogger("test") + sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen} + session := &v3Session{ + id: RequestID{}, + destination: netip.MustParseAddrPort("127.0.0.1:53"), + origin: &scriptedPacketConn{reads: [][]byte{ + make([]byte, maxV3UDPPayloadLen+1), + []byte("ok"), + }}, + inbound: &Inbound{ + logger: logger, + }, + writeChan: make(chan []byte, 1), + closeChan: make(chan struct{}), + contextChan: make(chan context.Context, 1), + sender: sender, + } + + done := make(chan struct{}) + go func() { + session.readLoop() + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected read loop to finish") + } + + if len(sender.sent) != 1 { + t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent)) + } + if len(sender.sent[0]) != v3PayloadHeaderLen+2 { + t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0])) + } +} diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go index 1070a2d835..088fd4159f 100644 --- a/protocol/cloudflare/icmp.go +++ b/protocol/cloudflare/icmp.go @@ -20,6 +20,15 @@ import ( const ( icmpFlowTimeout = 30 * time.Second icmpTraceIdentityLength = 16 + 8 + 1 + defaultICMPPacketTTL = 64 + icmpErrorHeaderLen = 8 + + icmpv4TypeEchoRequest = 8 + icmpv4TypeEchoReply = 0 + icmpv4TypeTimeExceeded = 11 + icmpv6TypeEchoRequest = 128 + icmpv6TypeEchoReply = 129 + icmpv6TypeTimeExceeded = 3 ) type ICMPTraceContext struct { @@ -40,15 +49,18 @@ type ICMPRequestKey struct { } type ICMPPacketInfo struct { - IPVersion uint8 - Protocol uint8 - SourceIP netip.Addr - Destination netip.Addr - ICMPType uint8 - ICMPCode uint8 - Identifier uint16 - Sequence uint16 - RawPacket []byte + IPVersion uint8 + Protocol uint8 + SourceIP netip.Addr + Destination netip.Addr + ICMPType uint8 + ICMPCode uint8 + Identifier uint16 + Sequence uint16 + IPv4HeaderLen int + IPv4TTL uint8 + IPv6HopLimit uint8 + RawPacket []byte } func (i ICMPPacketInfo) FlowKey() ICMPFlowKey { @@ -82,9 +94,9 @@ func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey { func (i ICMPPacketInfo) IsEchoRequest() bool { switch i.IPVersion { case 4: - return i.ICMPType == 8 && i.ICMPCode == 0 + return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0 case 6: - return i.ICMPType == 128 && i.ICMPCode == 0 + return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0 default: return false } @@ -93,14 +105,47 @@ func (i ICMPPacketInfo) IsEchoRequest() bool { func (i ICMPPacketInfo) IsEchoReply() bool { switch i.IPVersion { case 4: - return i.ICMPType == 0 && i.ICMPCode == 0 + return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0 case 6: - return i.ICMPType == 129 && i.ICMPCode == 0 + return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0 default: return false } } +func (i ICMPPacketInfo) TTL() uint8 { + if i.IPVersion == 4 { + return i.IPv4TTL + } + return i.IPv6HopLimit +} + +func (i ICMPPacketInfo) TTLExpired() bool { + return i.TTL() <= 1 +} + +func (i *ICMPPacketInfo) DecrementTTL() error { + switch i.IPVersion { + case 4: + if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen { + return E.New("invalid IPv4 packet TTL state") + } + i.IPv4TTL-- + i.RawPacket[8] = i.IPv4TTL + binary.BigEndian.PutUint16(i.RawPacket[10:12], 0) + binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0)) + case 6: + if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 { + return E.New("invalid IPv6 packet hop limit state") + } + i.IPv6HopLimit-- + i.RawPacket[7] = i.IPv6HopLimit + default: + return E.New("unsupported IP version: ", i.IPVersion) + } + return nil +} + type icmpWireVersion uint8 const ( @@ -154,15 +199,7 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error { } w.access.Unlock() - var datagram []byte - switch w.wireVersion { - case icmpWireV2: - datagram, err = encodeV2ICMPDatagram(packetInfo.RawPacket, traceContext) - case icmpWireV3: - datagram = encodeV3ICMPDatagram(packetInfo.RawPacket) - default: - err = E.New("unsupported icmp wire version: ", w.wireVersion) - } + datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext) if err != nil { return err } @@ -218,6 +255,21 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont if !packetInfo.IsEchoRequest() { return nil } + if packetInfo.TTLExpired() { + ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext)) + if err != nil { + return err + } + datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext) + if err != nil { + return err + } + return b.sender.SendDatagram(datagram) + } + + if err := packetInfo.DecrementTTL(); err != nil { + return err + } state := b.getFlowState(packetInfo.FlowKey()) if traceContext.Traced { @@ -294,15 +346,17 @@ func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) { return ICMPPacketInfo{}, E.New("invalid IPv4 destination address") } return ICMPPacketInfo{ - IPVersion: 4, - Protocol: 1, - SourceIP: sourceIP, - Destination: destinationIP, - ICMPType: packet[headerLen], - ICMPCode: packet[headerLen+1], - Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), - Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), - RawPacket: append([]byte(nil), packet...), + IPVersion: 4, + Protocol: 1, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[headerLen], + ICMPCode: packet[headerLen+1], + Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]), + Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]), + IPv4HeaderLen: headerLen, + IPv4TTL: packet[8], + RawPacket: append([]byte(nil), packet...), }, nil } @@ -322,18 +376,139 @@ func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) { return ICMPPacketInfo{}, E.New("invalid IPv6 destination address") } return ICMPPacketInfo{ - IPVersion: 6, - Protocol: 58, - SourceIP: sourceIP, - Destination: destinationIP, - ICMPType: packet[40], - ICMPCode: packet[41], - Identifier: binary.BigEndian.Uint16(packet[44:46]), - Sequence: binary.BigEndian.Uint16(packet[46:48]), - RawPacket: append([]byte(nil), packet...), + IPVersion: 6, + Protocol: 58, + SourceIP: sourceIP, + Destination: destinationIP, + ICMPType: packet[40], + ICMPCode: packet[41], + Identifier: binary.BigEndian.Uint16(packet[44:46]), + Sequence: binary.BigEndian.Uint16(packet[46:48]), + IPv6HopLimit: packet[7], + RawPacket: append([]byte(nil), packet...), }, nil } +func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int { + limit := maxV3UDPPayloadLen + switch wireVersion { + case icmpWireV2: + limit -= typeIDLength + if traceContext.Traced { + limit -= len(traceContext.Identity) + } + case icmpWireV3: + limit -= 1 + default: + return 0 + } + if limit < 0 { + return 0 + } + return limit +} + +func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + switch packetInfo.IPVersion { + case 4: + return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen) + case 6: + return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen) + default: + return nil, E.New("unsupported IP version: ", packetInfo.IPVersion) + } +} + +func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + const headerLen = 20 + if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() { + return nil, E.New("TTL exceeded packet requires IPv4 addresses") + } + if maxPacketLen <= headerLen+icmpErrorHeaderLen { + return nil, E.New("TTL exceeded packet size limit is too small") + } + + quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) + packet[0] = 0x45 + binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet))) + packet[8] = defaultICMPPacketTTL + packet[9] = 1 + copy(packet[12:16], packetInfo.Destination.AsSlice()) + copy(packet[16:20], packetInfo.SourceIP.AsSlice()) + packet[20] = icmpv4TypeTimeExceeded + packet[21] = 0 + copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) + binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0)) + binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0)) + return packet, nil +} + +func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { + const headerLen = 40 + if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() { + return nil, E.New("TTL exceeded packet requires IPv6 addresses") + } + if maxPacketLen <= headerLen+icmpErrorHeaderLen { + return nil, E.New("TTL exceeded packet size limit is too small") + } + + quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) + packet[0] = 0x60 + binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength)) + packet[6] = 58 + packet[7] = defaultICMPPacketTTL + copy(packet[8:24], packetInfo.Destination.AsSlice()) + copy(packet[24:40], packetInfo.SourceIP.AsSlice()) + packet[40] = icmpv6TypeTimeExceeded + packet[41] = 0 + copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength]) + binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58))) + return packet, nil +} + +func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) { + switch wireVersion { + case icmpWireV2: + return encodeV2ICMPDatagram(packet, traceContext) + case icmpWireV3: + return encodeV3ICMPDatagram(packet), nil + default: + return nil, E.New("unsupported icmp wire version: ", wireVersion) + } +} + +func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 { + var sum uint32 + sum = checksumSum(source.AsSlice(), sum) + sum = checksumSum(destination.AsSlice(), sum) + var lengthBytes [4]byte + binary.BigEndian.PutUint32(lengthBytes[:], payloadLength) + sum = checksumSum(lengthBytes[:], sum) + sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum) + return sum +} + +func checksumSum(data []byte, sum uint32) uint32 { + for len(data) >= 2 { + sum += uint32(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + } + if len(data) == 1 { + sum += uint32(data[0]) << 8 + } + return sum +} + +func checksum(data []byte, initial uint32) uint16 { + sum := checksumSum(data, initial) + for sum > 0xffff { + sum = (sum >> 16) + (sum & 0xffff) + } + return ^uint16(sum) +} + func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) { if traceContext.Traced { data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1) diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go index 9557fa16f6..aeecf5751a 100644 --- a/protocol/cloudflare/icmp_test.go +++ b/protocol/cloudflare/icmp_test.go @@ -169,6 +169,158 @@ func TestICMPBridgeHandleV3Reply(t *testing.T) { } } +func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) { + var destination *fakeDirectRouteDestination + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + destination = &fakeDirectRouteDestination{routeContext: routeContext} + return destination, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2) + + packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1) + packet[8] = 5 + + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil { + t.Fatal(err) + } + if len(destination.packets) != 1 { + t.Fatalf("expected one routed packet, got %d", len(destination.packets)) + } + if got := destination.packets[0][8]; got != 4 { + t.Fatalf("expected decremented IPv4 TTL, got %d", got) + } +} + +func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) { + var destination *fakeDirectRouteDestination + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + destination = &fakeDirectRouteDestination{routeContext: routeContext} + return destination, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3) + + packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1) + packet[7] = 3 + + if err := bridge.HandleV3(context.Background(), packet); err != nil { + t.Fatal(err) + } + if len(destination.packets) != 1 { + t.Fatalf("expected one routed packet, got %d", len(destination.packets)) + } + if got := destination.packets[0][7]; got != 2 { + t.Fatalf("expected decremented IPv6 hop limit, got %d", got) + } +} + +func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) { + var preMatchCalls int + traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength) + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + return nil, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2) + + source := netip.MustParseAddr("198.18.0.2") + target := netip.MustParseAddr("1.1.1.1") + packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1) + packet[8] = 1 + packet = append(packet, traceIdentity...) + + if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil { + t.Fatal(err) + } + if preMatchCalls != 0 { + t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent)) + } + reply := sender.sent[0] + if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) { + t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1]) + } + gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1] + if !bytes.Equal(gotIdentity, traceIdentity) { + t.Fatalf("unexpected trace identity: %x", gotIdentity) + } + rawReply := reply[:len(reply)-1-icmpTraceIdentityLength] + packetInfo, err := ParseICMPPacket(rawReply) + if err != nil { + t.Fatal(err) + } + if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 { + t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode) + } + if packetInfo.SourceIP != target || packetInfo.Destination != source { + t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) + } +} + +func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) { + var preMatchCalls int + sender := &captureDatagramSender{} + router := &testRouter{ + preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) { + preMatchCalls++ + return nil, nil + }, + } + inboundInstance := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + router: router, + } + bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3) + + source := netip.MustParseAddr("2001:db8::2") + target := netip.MustParseAddr("2606:4700:4700::1111") + packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1) + packet[7] = 1 + + if err := bridge.HandleV3(context.Background(), packet); err != nil { + t.Fatal(err) + } + if preMatchCalls != 0 { + t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent)) + } + if sender.sent[0][0] != byte(DatagramV3TypeICMP) { + t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0]) + } + packetInfo, err := ParseICMPPacket(sender.sent[0][1:]) + if err != nil { + t.Fatal(err) + } + if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 { + t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode) + } + if packetInfo.SourceIP != target || packetInfo.Destination != source { + t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) + } +} + func TestICMPBridgeDropsNonEcho(t *testing.T) { var preMatchCalls int router := &testRouter{ diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 5ff2db3db2..61111ab77d 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -162,3 +162,49 @@ func TestResolveHTTPServiceStatus(t *testing.T) { t.Fatalf("status service should keep request URL, got %s", requestURL) } } + +func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) { + testCases := []struct { + rawService string + wantScheme string + }{ + {rawService: "ws://127.0.0.1:8080", wantScheme: "http"}, + {rawService: "wss://127.0.0.1:8443", wantScheme: "https"}, + } + + for _, testCase := range testCases { + t.Run(testCase.rawService, func(t *testing.T) { + service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.BaseURL == nil { + t.Fatal("expected base URL") + } + if service.BaseURL.Scheme != testCase.wantScheme { + t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme) + } + if service.Service != testCase.rawService { + t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service) + } + }) + } +} + +func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) { + inboundInstance := newTestIngressInbound(t) + inboundInstance.configManager.activeConfig = RuntimeConfig{ + Ingress: []compiledIngressRule{ + {Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")}, + {Service: mustResolvedService(t, "http_status:404")}, + }, + } + + _, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1") + if err != nil { + t.Fatal(err) + } + if requestURL != "http://127.0.0.1:8083/path?q=1" { + t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL) + } +} diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index ef8c50495e..5b0c8d9edf 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -74,6 +74,20 @@ func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) { } } +func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL { + if parsedURL == nil { + return nil + } + canonicalURL := *parsedURL + switch canonicalURL.Scheme { + case "ws": + canonicalURL.Scheme = "http" + case "wss": + canonicalURL.Scheme = "https" + } + return &canonicalURL +} + type compiledIngressRule struct { Hostname string PunycodeHostname string @@ -451,7 +465,7 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) Kind: ResolvedServiceHTTP, Service: rawService, Destination: parseServiceDestination(parsedURL), - BaseURL: parsedURL, + BaseURL: canonicalizeHTTPOriginURL(parsedURL), OriginRequest: originRequest, }, nil case "tcp", "ssh", "rdp", "smb": From 3eb626581f38123a68462a9d656683f8247c9118 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 21:42:15 +0800 Subject: [PATCH 36/41] Fix cloudflared parity gaps --- protocol/cloudflare/connection_quic.go | 7 + protocol/cloudflare/dispatch.go | 103 +++++++---- protocol/cloudflare/features.go | 82 ++++++++- protocol/cloudflare/features_test.go | 109 ++++++++++- protocol/cloudflare/helpers_test.go | 1 + protocol/cloudflare/inbound.go | 31 +++- protocol/cloudflare/integration_test.go | 17 +- protocol/cloudflare/origin_request_test.go | 191 ++++++++++++++++++-- protocol/cloudflare/request_builder_test.go | 78 ++++++++ 9 files changed, 542 insertions(+), 77 deletions(-) create mode 100644 protocol/cloudflare/request_builder_test.go diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index bb935b4e32..4c5256d09d 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -45,6 +45,7 @@ type QUICConnection struct { connIndex uint8 credentials Credentials connectorID uuid.UUID + datagramVersion string features []string numPreviousAttempts uint8 gracePeriod time.Duration @@ -90,6 +91,7 @@ func NewQUICConnection( connIndex uint8, credentials Credentials, connectorID uuid.UUID, + datagramVersion string, features []string, numPreviousAttempts uint8, gracePeriod time.Duration, @@ -136,6 +138,7 @@ func NewQUICConnection( connIndex: connIndex, credentials: credentials, connectorID: connectorID, + datagramVersion: datagramVersion, features: features, numPreviousAttempts: numPreviousAttempts, gracePeriod: gracePeriod, @@ -281,6 +284,10 @@ func (q *QUICConnection) SendDatagram(data []byte) error { return q.conn.SendDatagram(data) } +func (q *QUICConnection) DatagramVersion() string { + return q.datagramVersion +} + func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error) { stream, err := q.conn.OpenStream() if err != nil { diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index cfadf00c77..f3f92ddcab 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -5,7 +5,6 @@ package cloudflare import ( "context" "crypto/tls" - "crypto/x509" "io" "net" "net/http" @@ -32,6 +31,11 @@ const ( metadataHTTPStatus = "HttpStatus" ) +var ( + loadOriginCABasePool = cloudflareRootCertPool + readOriginCAFile = os.ReadFile +) + // ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2. type ConnectResponseWriter interface { // WriteResponse sends the connect response (ack or error) with optional metadata. @@ -71,7 +75,7 @@ func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadW // HandleDatagram handles an incoming QUIC datagram. func (i *Inbound) HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) { - switch i.datagramVersion { + switch datagramVersionForSender(sender) { case "v3": muxer := i.getOrCreateV3Muxer(sender) muxer.HandleDatagram(ctx, datagram) @@ -291,7 +295,12 @@ func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteClose metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) + transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) + if err != nil { + i.logger.ErrorContext(ctx, "build origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -300,7 +309,12 @@ func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWrite metadata.Network = N.NetworkTCP i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination) - transport, cleanup := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) + transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost]) + if err != nil { + i.logger.ErrorContext(ctx, "build origin transport: ", err) + respWriter.WriteResponse(err, nil) + return + } defer cleanup() i.roundTripHTTP(ctx, stream, respWriter, request, service, transport) } @@ -389,7 +403,11 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, } } -func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func()) { +func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) { + tlsConfig, err := newOriginTLSConfig(originRequest, effectiveOriginHost(originRequest, requestHost)) + if err != nil { + return nil, nil, err + } input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) transport := &http.Transport{ @@ -399,13 +417,13 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter IdleConnTimeout: originRequest.KeepAliveTimeout, MaxIdleConns: originRequest.KeepAliveConnections, MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(originRequest, requestHost), + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: tlsConfig, DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil }, } - applyHTTPTransportProxy(transport, originRequest) - return transport, cleanup + return transport, cleanup, nil } func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { @@ -416,6 +434,10 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost if service.OriginRequest.NoHappyEyeballs { dialer.FallbackDelay = -1 } + tlsConfig, err := newOriginTLSConfig(service.OriginRequest, effectiveOriginHost(service.OriginRequest, requestHost)) + if err != nil { + return nil, nil, err + } transport := &http.Transport{ DisableCompression: true, ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, @@ -423,9 +445,9 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, MaxIdleConns: service.OriginRequest.KeepAliveConnections, MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - TLSClientConfig: buildOriginTLSConfig(service.OriginRequest, requestHost), + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: tlsConfig, } - applyHTTPTransportProxy(transport, service.OriginRequest) switch service.Kind { case ResolvedServiceUnix, ResolvedServiceUnixTLS: transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { @@ -437,37 +459,34 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost return transport, func() {}, nil } -func buildOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) *tls.Config { +func effectiveOriginHost(originRequest OriginRequestConfig, requestHost string) string { + if originRequest.HTTPHostHeader != "" { + return originRequest.HTTPHostHeader + } + return requestHost +} + +func newOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) (*tls.Config, error) { + rootCAs, err := loadOriginCABasePool() + if err != nil { + return nil, E.Cause(err, "load origin root CAs") + } tlsConfig := &tls.Config{ InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec ServerName: originTLSServerName(originRequest, requestHost), + RootCAs: rootCAs, } if originRequest.CAPool == "" { - return tlsConfig + return tlsConfig, nil } - pemData, err := os.ReadFile(originRequest.CAPool) + pemData, err := readOriginCAFile(originRequest.CAPool) if err != nil { - return tlsConfig - } - pool := x509.NewCertPool() - if pool.AppendCertsFromPEM(pemData) { - tlsConfig.RootCAs = pool - } - return tlsConfig -} - -func applyHTTPTransportProxy(transport *http.Transport, originRequest OriginRequestConfig) { - if originRequest.ProxyAddress == "" || originRequest.ProxyPort == 0 { - return + return nil, E.Cause(err, "read origin ca pool") } - switch strings.ToLower(originRequest.ProxyType) { - case "", "http": - proxyURL := &url.URL{ - Scheme: "http", - Host: net.JoinHostPort(originRequest.ProxyAddress, strconv.Itoa(int(originRequest.ProxyPort))), - } - transport.Proxy = http.ProxyURL(proxyURL) + if !tlsConfig.RootCAs.AppendCertsFromPEM(pemData) { + return nil, E.New("parse origin ca pool") } + return tlsConfig, nil } func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string { @@ -591,11 +610,11 @@ func buildHTTPRequestFromMetadata(ctx context.Context, connectRequest *ConnectRe func isTransferEncodingChunked(request *http.Request) bool { for _, encoding := range request.TransferEncoding { - if strings.EqualFold(encoding, "chunked") { + if strings.Contains(strings.ToLower(encoding), "chunked") { return true } } - return false + return strings.Contains(strings.ToLower(request.Header.Get("Transfer-Encoding")), "chunked") } func encodeResponseHeaders(statusCode int, header http.Header) []Metadata { @@ -629,3 +648,19 @@ func (c *streamConn) RemoteAddr() net.Addr { return nil } func (c *streamConn) SetDeadline(_ time.Time) error { return nil } func (c *streamConn) SetReadDeadline(_ time.Time) error { return nil } func (c *streamConn) SetWriteDeadline(_ time.Time) error { return nil } + +type datagramVersionedSender interface { + DatagramVersion() string +} + +func datagramVersionForSender(sender DatagramSender) string { + versioned, ok := sender.(datagramVersionedSender) + if !ok { + return defaultDatagramVersion + } + version := versioned.DatagramVersion() + if version == "" { + return defaultDatagramVersion + } + return version +} diff --git a/protocol/cloudflare/features.go b/protocol/cloudflare/features.go index 5b26336ab5..aa64d51cec 100644 --- a/protocol/cloudflare/features.go +++ b/protocol/cloudflare/features.go @@ -7,12 +7,15 @@ import ( "encoding/json" "hash/fnv" "net" + "sync" "time" ) const ( - featureSelectorHostname = "cfd-features.argotunnel.com" - featureLookupTimeout = 10 * time.Second + featureSelectorHostname = "cfd-features.argotunnel.com" + featureLookupTimeout = 10 * time.Second + defaultDatagramVersion = "v2" + defaultFeatureRefreshInterval = time.Hour ) type cloudflaredFeaturesRecord struct { @@ -30,23 +33,84 @@ var lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) { return []byte(records[0]), nil } -func resolveDatagramVersion(ctx context.Context, accountTag string, configured string) string { +type featureSelector struct { + configured string + accountTag string + lookup func(context.Context) ([]byte, error) + refreshInterval time.Duration + currentDatagramVersion string + + access sync.RWMutex +} + +func newFeatureSelector(ctx context.Context, accountTag string, configured string) *featureSelector { + selector := &featureSelector{ + configured: configured, + accountTag: accountTag, + lookup: lookupCloudflaredFeatures, + refreshInterval: defaultFeatureRefreshInterval, + currentDatagramVersion: defaultDatagramVersion, + } if configured != "" { - return configured + selector.currentDatagramVersion = configured + return selector + } + _ = selector.refresh(ctx) + if selector.refreshInterval > 0 { + go selector.refreshLoop(ctx) } - record, err := lookupCloudflaredFeatures(ctx) + return selector +} + +func (s *featureSelector) Snapshot() (string, []string) { + if s == nil { + return defaultDatagramVersion, DefaultFeatures(defaultDatagramVersion) + } + s.access.RLock() + defer s.access.RUnlock() + return s.currentDatagramVersion, DefaultFeatures(s.currentDatagramVersion) +} + +func (s *featureSelector) refresh(ctx context.Context) error { + if s == nil || s.configured != "" { + return nil + } + record, err := s.lookup(ctx) + if err != nil { + return err + } + version, err := resolveRemoteDatagramVersion(s.accountTag, record) if err != nil { - return "v2" + return err } + s.access.Lock() + s.currentDatagramVersion = version + s.access.Unlock() + return nil +} + +func (s *featureSelector) refreshLoop(ctx context.Context) { + ticker := time.NewTicker(s.refreshInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _ = s.refresh(ctx) + } + } +} +func resolveRemoteDatagramVersion(accountTag string, record []byte) (string, error) { var features cloudflaredFeaturesRecord if err := json.Unmarshal(record, &features); err != nil { - return "v2" + return "", err } if accountEnabled(accountTag, features.DatagramV3Percentage) { - return "v3" + return "v3", nil } - return "v2" + return defaultDatagramVersion, nil } func accountEnabled(accountTag string, percentage uint32) bool { diff --git a/protocol/cloudflare/features_test.go b/protocol/cloudflare/features_test.go index 82534eb47c..560bbd1010 100644 --- a/protocol/cloudflare/features_test.go +++ b/protocol/cloudflare/features_test.go @@ -4,27 +4,116 @@ package cloudflare import ( "context" + "errors" + "slices" "testing" ) -func TestResolveDatagramVersionConfiguredWins(t *testing.T) { - version := resolveDatagramVersion(context.Background(), "account", "v3") +func TestFeatureSelectorConfiguredWins(t *testing.T) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + selector := newFeatureSelector(ctx, "account", "v3") + version, features := selector.Snapshot() if version != "v3" { t.Fatalf("expected configured version to win, got %s", version) } + if !slices.Contains(features, "support_datagram_v3_2") { + t.Fatalf("expected v3 feature list, got %#v", features) + } } -func TestResolveDatagramVersionRemoteSelection(t *testing.T) { - originalLookup := lookupCloudflaredFeatures - lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) { - return []byte(`{"dv3_2":100}`), nil +func TestFeatureSelectorInitialRemoteSelection(t *testing.T) { + selector := &featureSelector{ + accountTag: "account", + lookup: func(context.Context) ([]byte, error) { return []byte(`{"dv3_2":100}`), nil }, + currentDatagramVersion: defaultDatagramVersion, + } + + if err := selector.refresh(context.Background()); err != nil { + t.Fatal(err) } - defer func() { - lookupCloudflaredFeatures = originalLookup - }() - version := resolveDatagramVersion(context.Background(), "account", "") + version, _ := selector.Snapshot() if version != "v3" { t.Fatalf("expected auto-selected v3, got %s", version) } } + +func TestFeatureSelectorRefreshUpdatesSnapshot(t *testing.T) { + record := []byte(`{"dv3_2":0}`) + selector := &featureSelector{ + accountTag: "account", + currentDatagramVersion: defaultDatagramVersion, + lookup: func(context.Context) ([]byte, error) { + return record, nil + }, + } + + if err := selector.refresh(context.Background()); err != nil { + t.Fatal(err) + } + version, _ := selector.Snapshot() + if version != defaultDatagramVersion { + t.Fatalf("expected initial v2, got %s", version) + } + + record = []byte(`{"dv3_2":100}`) + if err := selector.refresh(context.Background()); err != nil { + t.Fatal(err) + } + version, _ = selector.Snapshot() + if version != "v3" { + t.Fatalf("expected refreshed v3, got %s", version) + } +} + +func TestFeatureSelectorRefreshFailureKeepsPreviousValue(t *testing.T) { + selector := &featureSelector{ + accountTag: "account", + currentDatagramVersion: "v3", + lookup: func(context.Context) ([]byte, error) { + return nil, errors.New("lookup failed") + }, + } + + if err := selector.refresh(context.Background()); err == nil { + t.Fatal("expected refresh failure") + } + + version, _ := selector.Snapshot() + if version != "v3" { + t.Fatalf("expected previous version to be retained, got %s", version) + } +} + +func TestInboundUsesFreshFeatureSnapshotOnRetry(t *testing.T) { + inbound := &Inbound{ + featureSelector: &featureSelector{ + accountTag: "account", + currentDatagramVersion: defaultDatagramVersion, + }, + } + + version, features := inbound.currentConnectionFeatures() + if version != defaultDatagramVersion { + t.Fatalf("expected initial v2, got %s", version) + } + if slices.Contains(features, "support_datagram_v3_2") { + t.Fatalf("unexpected v3 feature list: %#v", features) + } + + inbound.featureSelector.access.Lock() + inbound.featureSelector.currentDatagramVersion = "v3" + inbound.featureSelector.access.Unlock() + + version, features = inbound.currentConnectionFeatures() + if version != "v3" { + t.Fatalf("expected refreshed v3, got %s", version) + } + if !slices.Contains(features, "support_datagram_v3_2") { + t.Fatalf("expected v3 feature list, got %#v", features) + } +} diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index fa05ca7978..64c8718cb4 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -197,6 +197,7 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i protocol: protocol, edgeIPVersion: 0, datagramVersion: "", + featureSelector: newFeatureSelector(ctx, credentials.AccountTag, ""), gracePeriod: 5 * time.Second, configManager: configManager, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 442f700834..cd6e38ae54 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -43,6 +43,7 @@ type Inbound struct { region string edgeIPVersion int datagramVersion string + featureSelector *featureSelector gracePeriod time.Duration configManager *ConfigManager flowLimiter *FlowLimiter @@ -133,6 +134,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo region: region, edgeIPVersion: edgeIPVersion, datagramVersion: datagramVersion, + featureSelector: newFeatureSelector(inboundCtx, credentials.AccountTag, datagramVersion), gracePeriod: gracePeriod, configManager: configManager, flowLimiter: &FlowLimiter{}, @@ -167,12 +169,9 @@ func (i *Inbound) Start(stage adapter.StartStage) error { i.haConnections = cappedHAConnections } - i.datagramVersion = resolveDatagramVersion(i.ctx, i.credentials.AccountTag, i.datagramVersion) - features := DefaultFeatures(i.datagramVersion) - for connIndex := 0; connIndex < i.haConnections; connIndex++ { i.done.Add(1) - go i.superviseConnection(uint8(connIndex), edgeAddrs, features) + go i.superviseConnection(uint8(connIndex), edgeAddrs) select { case readyConnIndex := <-i.connectedNotify: if readyConnIndex != uint8(connIndex) { @@ -235,7 +234,7 @@ const ( firstConnectionReadyTimeout = 15 * time.Second ) -func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { +func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { defer i.done.Done() edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs)) @@ -248,7 +247,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } edgeAddr := edgeAddrs[edgeIndex] - err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries)) + err := i.serveConnection(connIndex, edgeAddr, uint8(retries)) if err == nil || i.ctx.Err() != nil { return } @@ -275,15 +274,16 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } } -func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { +func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, numPreviousAttempts uint8) error { protocol := i.protocol if protocol == "" { protocol = "quic" } + datagramVersion, features := i.currentConnectionFeatures() switch protocol { case "quic": - err := i.serveQUIC(connIndex, edgeAddr, features, numPreviousAttempts) + err := i.serveQUIC(connIndex, edgeAddr, datagramVersion, features, numPreviousAttempts) if err == nil || i.ctx.Err() != nil { return err } @@ -299,12 +299,12 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, features } } -func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { +func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion string, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")") connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, - i.credentials, i.connectorID, + i.credentials, i.connectorID, datagramVersion, features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() { i.notifyConnected(connIndex) }, i.logger, @@ -322,6 +322,17 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, features []stri return connection.Serve(i.ctx, i) } +func (i *Inbound) currentConnectionFeatures() (string, []string) { + if i.featureSelector != nil { + return i.featureSelector.Snapshot() + } + version := i.datagramVersion + if version == "" { + version = defaultDatagramVersion + } + return version, DefaultFeatures(version) +} + func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")") diff --git a/protocol/cloudflare/integration_test.go b/protocol/cloudflare/integration_test.go index 8d19000489..80ecd8310c 100644 --- a/protocol/cloudflare/integration_test.go +++ b/protocol/cloudflare/integration_test.go @@ -5,6 +5,7 @@ package cloudflare import ( "io" "net/http" + "strings" "testing" "time" @@ -132,7 +133,21 @@ func TestHTTPResponseCorrectness(t *testing.T) { }) t.Run("PostEcho", func(t *testing.T) { - t.Skip("POST body streaming through QUIC data streams needs further investigation") + resp, err := http.Post(testURL+"/echo", "text/plain", strings.NewReader("payload")) + if err != nil { + t.Fatal("POST /echo: ", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got ", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("read body: ", err) + } + if string(body) != "payload" { + t.Error("unexpected body: ", string(body)) + } }) } diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index c00807ff83..ec94ef1237 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -3,11 +3,23 @@ package cloudflare import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "io" + "math/big" "net/http" "net/url" + "os" "strings" "testing" + "time" + + "github.com/sagernet/sing-box/adapter" ) func TestOriginTLSServerName(t *testing.T) { @@ -30,6 +42,30 @@ func TestOriginTLSServerName(t *testing.T) { } }) + t.Run("match sni to host uses http host header", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + MatchSNIToHost: true, + }, effectiveOriginHost(OriginRequestConfig{ + HTTPHostHeader: "origin.example.com", + MatchSNIToHost: true, + }, "request.example.com")) + if serverName != "origin.example.com" { + t.Fatalf("expected origin.example.com, got %s", serverName) + } + }) + + t.Run("match sni to host strips port from http host header", func(t *testing.T) { + serverName := originTLSServerName(OriginRequestConfig{ + MatchSNIToHost: true, + }, effectiveOriginHost(OriginRequestConfig{ + HTTPHostHeader: "origin.example.com:8443", + MatchSNIToHost: true, + }, "request.example.com")) + if serverName != "origin.example.com" { + t.Fatalf("expected origin.example.com, got %s", serverName) + } + }) + t.Run("disabled match keeps empty server name", func(t *testing.T) { serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com") if serverName != "" { @@ -38,22 +74,86 @@ func TestOriginTLSServerName(t *testing.T) { }) } -func TestApplyHTTPTransportProxy(t *testing.T) { - transport := &http.Transport{} - applyHTTPTransportProxy(transport, OriginRequestConfig{ - ProxyAddress: "127.0.0.1", - ProxyPort: 8080, - ProxyType: "http", - }) - if transport.Proxy == nil { - t.Fatal("expected proxy function to be configured") +func TestNewOriginTLSConfigErrorsOnMissingCAPool(t *testing.T) { + originalBaseLoader := loadOriginCABasePool + loadOriginCABasePool = func() (*x509.CertPool, error) { + return x509.NewCertPool(), nil + } + defer func() { + loadOriginCABasePool = originalBaseLoader + }() + + _, err := newOriginTLSConfig(OriginRequestConfig{ + CAPool: "/path/does/not/exist.pem", + }, "request.example.com") + if err == nil { + t.Fatal("expected error for missing ca pool") + } +} + +func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.T) { + basePEM, baseCert := createTestCertificatePEM(t, "base") + customPEM, customCert := createTestCertificatePEM(t, "custom") + + basePool := x509.NewCertPool() + if !basePool.AppendCertsFromPEM(basePEM) { + t.Fatal("expected base cert to append") + } + + originalBaseLoader := loadOriginCABasePool + loadOriginCABasePool = func() (*x509.CertPool, error) { + return basePool, nil + } + defer func() { + loadOriginCABasePool = originalBaseLoader + }() + + caFile := writeTempPEM(t, customPEM) + tlsConfig, err := newOriginTLSConfig(OriginRequestConfig{ + CAPool: caFile, + }, "request.example.com") + if err != nil { + t.Fatal(err) + } + if tlsConfig.RootCAs == nil { + t.Fatal("expected root CA pool") + } + subjects := tlsConfig.RootCAs.Subjects() + if len(subjects) != 2 { + t.Fatalf("expected 2 subjects, got %d", len(subjects)) + } + if !containsSubject(subjects, baseCert.RawSubject) { + t.Fatal("expected base subject to remain in pool") + } + if !containsSubject(subjects, customCert.RawSubject) { + t.Fatal("expected custom subject to be appended to pool") + } +} + +func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://proxy.example.com:8080") + + inbound := &Inbound{} + transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{ + Kind: ResolvedServiceUnix, + UnixPath: "/tmp/test.sock", + OriginRequest: OriginRequestConfig{ + ProxyAddress: "127.0.0.1", + ProxyPort: 8081, + ProxyType: "http", + }, + }, "") + if err != nil { + t.Fatal(err) } + defer cleanup() + proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}) if err != nil { t.Fatal(err) } - if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:8080" { - t.Fatalf("unexpected proxy URL: %#v", proxyURL) + if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" { + t.Fatalf("expected environment proxy URL, got %#v", proxyURL) } } @@ -70,14 +170,32 @@ func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) { t.Fatal(err) } defer cleanup() - if transport.Proxy != nil { - t.Fatal("expected no proxy when proxy fields are empty") + if transport.Proxy == nil { + t.Fatal("expected proxy function to be configured from environment") } if transport.DialContext == nil { t.Fatal("expected custom direct dial context") } } +func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) { + originalBaseLoader := loadOriginCABasePool + loadOriginCABasePool = func() (*x509.CertPool, error) { + return x509.NewCertPool(), nil + } + defer func() { + loadOriginCABasePool = originalBaseLoader + }() + + inbound := &Inbound{} + _, _, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{ + CAPool: "/path/does/not/exist.pem", + }, "") + if err == nil { + t.Fatal("expected transport build error") + } +} + func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody) if err != nil { @@ -135,3 +253,50 @@ func TestNormalizeOriginRequestWebsocket(t *testing.T) { t.Fatal("expected websocket body to be nil") } } + +func createTestCertificatePEM(t *testing.T, commonName string) ([]byte, *x509.Certificate) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatal(err) + } + certificate, err := x509.ParseCertificate(der) + if err != nil { + t.Fatal(err) + } + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), certificate +} + +func writeTempPEM(t *testing.T, pemData []byte) string { + t.Helper() + path := t.TempDir() + "/ca.pem" + if err := os.WriteFile(path, pemData, 0o600); err != nil { + t.Fatal(err) + } + return path +} + +func containsSubject(subjects [][]byte, want []byte) bool { + for _, subject := range subjects { + if bytes.Equal(subject, want) { + return true + } + } + return false +} diff --git a/protocol/cloudflare/request_builder_test.go b/protocol/cloudflare/request_builder_test.go new file mode 100644 index 0000000000..c8c6d6a058 --- /dev/null +++ b/protocol/cloudflare/request_builder_test.go @@ -0,0 +1,78 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" +) + +func TestBuildHTTPRequestFromMetadataUsesNoBodyWhenLengthZeroWithoutChunked(t *testing.T) { + request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{ + Dest: "http://example.com", + Type: ConnectionTypeHTTP, + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "cf.host"}, + }, + }, io.NopCloser(bytes.NewBuffer(nil))) + if err != nil { + t.Fatal(err) + } + if request.Body != http.NoBody { + t.Fatalf("expected http.NoBody, got %#v", request.Body) + } +} + +func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingChunked(t *testing.T) { + request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{ + Dest: "http://example.com", + Type: ConnectionTypeHTTP, + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodPost}, + {Key: metadataHTTPHost, Val: "cf.host"}, + {Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "chunked"}, + }, + }, io.NopCloser(bytes.NewBufferString("payload"))) + if err != nil { + t.Fatal(err) + } + if request.Body == http.NoBody { + t.Fatal("expected request body to be preserved") + } + body, err := io.ReadAll(request.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "payload" { + t.Fatalf("unexpected body %q", body) + } +} + +func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingContainsChunked(t *testing.T) { + request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{ + Dest: "http://example.com", + Type: ConnectionTypeHTTP, + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodPost}, + {Key: metadataHTTPHost, Val: "cf.host"}, + {Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "gzip,chunked"}, + }, + }, io.NopCloser(bytes.NewBufferString("payload"))) + if err != nil { + t.Fatal(err) + } + if request.Body == http.NoBody { + t.Fatal("expected request body to be preserved") + } + body, err := io.ReadAll(request.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "payload" { + t.Fatalf("unexpected body %q", body) + } +} From c07abeeab3c7ca6054035801b976d3d19af11c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 11:30:22 +0800 Subject: [PATCH 37/41] Fix cloudflared parity regressions --- protocol/cloudflare/access.go | 16 ++- protocol/cloudflare/access_test.go | 37 ++++++ protocol/cloudflare/connection_http2.go | 7 + protocol/cloudflare/connection_quic.go | 33 ++++- protocol/cloudflare/connection_quic_test.go | 56 +++++++- protocol/cloudflare/datagram_rpc_test.go | 133 +++++++++++++++++++ protocol/cloudflare/datagram_rpc_v3.go | 73 ++++++++++ protocol/cloudflare/datagram_v2.go | 6 + protocol/cloudflare/dispatch.go | 25 +++- protocol/cloudflare/origin_request_test.go | 8 +- protocol/cloudflare/response_trailer_test.go | 91 +++++++++++++ 11 files changed, 472 insertions(+), 13 deletions(-) create mode 100644 protocol/cloudflare/datagram_rpc_test.go create mode 100644 protocol/cloudflare/datagram_rpc_v3.go create mode 100644 protocol/cloudflare/response_trailer_test.go diff --git a/protocol/cloudflare/access.go b/protocol/cloudflare/access.go index fc40e72331..f51168e213 100644 --- a/protocol/cloudflare/access.go +++ b/protocol/cloudflare/access.go @@ -56,17 +56,21 @@ func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Reques if err != nil { return err } - if len(v.audTags) == 0 { + if accessTokenAudienceAllowed(token.Audience, v.audTags) { return nil } - for _, jwtAudTag := range token.Audience { - for _, acceptedAudTag := range v.audTags { - if acceptedAudTag == jwtAudTag { - return nil + return E.New("access token audience does not match configured aud_tag") +} + +func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool { + for _, tokenAudTag := range tokenAudience { + for _, configuredAudTag := range configuredAudTags { + if configuredAudTag == tokenAudTag { + return true } } } - return E.New("access token audience does not match configured aud_tag") + return false } func accessIssuerURL(teamName string, environment string) string { diff --git a/protocol/cloudflare/access_test.go b/protocol/cloudflare/access_test.go index 3cceb155e7..5fb6fa1783 100644 --- a/protocol/cloudflare/access_test.go +++ b/protocol/cloudflare/access_test.go @@ -50,6 +50,43 @@ func TestValidateAccessConfiguration(t *testing.T) { } } +func TestAccessTokenAudienceAllowed(t *testing.T) { + testCases := []struct { + name string + tokenAudience []string + configuredTags []string + expected bool + }{ + { + name: "matching audience", + tokenAudience: []string{"aud-1", "aud-2"}, + configuredTags: []string{"aud-2"}, + expected: true, + }, + { + name: "empty configured tags rejected", + tokenAudience: []string{"aud-1"}, + configuredTags: nil, + expected: false, + }, + { + name: "non matching audience rejected", + tokenAudience: []string{"aud-1"}, + configuredTags: []string{"aud-2"}, + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags) + if allowed != testCase.expected { + t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected) + } + }) + } +} + func TestRoundTripHTTPAccessDenied(t *testing.T) { originalFactory := newAccessValidator defer func() { diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 33192806f7..56ac895e5a 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -411,6 +411,13 @@ type http2ResponseWriter struct { headersSent bool } +func (w *http2ResponseWriter) AddTrailer(name, value string) { + if !w.headersSent { + return + } + w.writer.Header().Add(http2.TrailerPrefix+name, value) +} + func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { if w.headersSent { return nil diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 4c5256d09d..e83bf82985 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -9,6 +9,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/sagernet/quic-go" @@ -262,7 +263,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, q.logger.Debug("failed to read connect request: ", err) return } - handler.HandleDataStream(ctx, rwc, request, q.connIndex) + handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex) case StreamTypeRPC: handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q) @@ -388,3 +389,33 @@ func (s *streamReadWriteCloser) Close() error { s.stream.CancelRead(0) return s.stream.Close() } + +// nopCloserReadWriter lets handlers stop consuming the read side without closing +// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior, +// where the request body can be closed before the response is fully written. +type nopCloserReadWriter struct { + io.ReadWriteCloser + + sawEOF bool + closed uint32 +} + +func (n *nopCloserReadWriter) Read(p []byte) (int, error) { + if n.sawEOF { + return 0, io.EOF + } + if atomic.LoadUint32(&n.closed) > 0 { + return 0, fmt.Errorf("closed by handler") + } + + readLen, err := n.ReadWriteCloser.Read(p) + if err == io.EOF { + n.sawEOF = true + } + return readLen, err +} + +func (n *nopCloserReadWriter) Close() error { + atomic.StoreUint32(&n.closed, 1) + return nil +} diff --git a/protocol/cloudflare/connection_quic_test.go b/protocol/cloudflare/connection_quic_test.go index ac7f58aba6..78479dad80 100644 --- a/protocol/cloudflare/connection_quic_test.go +++ b/protocol/cloudflare/connection_quic_test.go @@ -2,7 +2,11 @@ package cloudflare -import "testing" +import ( + "io" + "strings" + "testing" +) func TestQUICInitialPacketSize(t *testing.T) { testCases := []struct { @@ -23,3 +27,53 @@ func TestQUICInitialPacketSize(t *testing.T) { }) } } + +type mockReadWriteCloser struct { + reader strings.Reader + writes []byte +} + +func (m *mockReadWriteCloser) Read(p []byte) (int, error) { + return m.reader.Read(p) +} + +func (m *mockReadWriteCloser) Write(p []byte) (int, error) { + m.writes = append(m.writes, p...) + return len(p), nil +} + +func (m *mockReadWriteCloser) Close() error { + return nil +} + +func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) { + inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")} + wrapper := &nopCloserReadWriter{ReadWriteCloser: inner} + + if err := wrapper.Close(); err != nil { + t.Fatal(err) + } + + if _, err := wrapper.Read(make([]byte, 1)); err == nil { + t.Fatal("expected read to fail after close") + } + + if _, err := wrapper.Write([]byte("response")); err != nil { + t.Fatal(err) + } + if string(inner.writes) != "response" { + t.Fatalf("unexpected writes %q", inner.writes) + } +} + +func TestNOPCloserReadWriterTracksEOF(t *testing.T) { + inner := &mockReadWriteCloser{reader: *strings.NewReader("")} + wrapper := &nopCloserReadWriter{ReadWriteCloser: inner} + + if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } + if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF { + t.Fatalf("expected cached EOF, got %v", err) + } +} diff --git a/protocol/cloudflare/datagram_rpc_test.go b/protocol/cloudflare/datagram_rpc_test.go new file mode 100644 index 0000000000..08974a9cb8 --- /dev/null +++ b/protocol/cloudflare/datagram_rpc_test.go @@ -0,0 +1,133 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + + capnp "zombiezen.com/go/capnproto2" +) + +func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) { + t.Helper() + + _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg) + if err != nil { + t.Fatal(err) + } + sessionID := uuid.New() + if err := params.SetSessionId(sessionID[:]); err != nil { + t.Fatal(err) + } + if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil { + t.Fatal(err) + } + params.SetDstPort(53) + params.SetCloseAfterIdleHint(int64(30)) + if err := params.SetTraceContext(traceContext); err != nil { + t.Fatal(err) + } + + _, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg) + if err != nil { + t.Fatal(err) + } + + call := tunnelrpc.SessionManager_registerUdpSession{ + Ctx: context.Background(), + Params: params, + Results: results, + } + return call, results.Result +} + +func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession { + t.Helper() + + _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg) + if err != nil { + t.Fatal(err) + } + sessionID := uuid.New() + if err := params.SetSessionId(sessionID[:]); err != nil { + t.Fatal(err) + } + if err := params.SetMessage("close"); err != nil { + t.Fatal(err) + } + + _, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg) + if err != nil { + t.Fatal(err) + } + + return tunnelrpc.SessionManager_unregisterUdpSession{ + Ctx: context.Background(), + Params: params, + Results: results, + } +} + +func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) { + server := &cloudflaredV3Server{ + inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")}, + } + call, readResult := newRegisterUDPSessionCall(t, "trace-context") + if err := server.RegisterUdpSession(call); err != nil { + t.Fatal(err) + } + + result, err := readResult() + if err != nil { + t.Fatal(err) + } + resultErr, err := result.Err() + if err != nil { + t.Fatal(err) + } + if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() { + t.Fatalf("unexpected registration error %q", resultErr) + } + spans, err := result.Spans() + if err != nil { + t.Fatal(err) + } + if len(spans) != 0 { + t.Fatalf("expected empty spans, got %x", spans) + } +} + +func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) { + server := &cloudflaredV3Server{ + inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")}, + } + err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t)) + if err == nil { + t.Fatal("expected unsupported unregister error") + } + if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() { + t.Fatalf("unexpected unregister error %v", err) + } +} diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go new file mode 100644 index 0000000000..38af323ff7 --- /dev/null +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -0,0 +1,73 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "errors" + "io" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + E "github.com/sagernet/sing/common/exceptions" + + "zombiezen.com/go/capnproto2/rpc" +) + +var ( + errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC") + errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC") +) + +type cloudflaredV3Server struct { + inbound *Inbound + logger log.ContextLogger +} + +func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error { + result, err := call.Results.NewResult() + if err != nil { + return err + } + if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil { + return err + } + return result.SetSpans([]byte{}) +} + +func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error { + return errUnsupportedDatagramV3UDPUnregistration +} + +func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { + version := call.Params.Version() + configData, _ := call.Params.Config() + updateResult := s.inbound.ApplyConfig(version, configData) + result, err := call.Results.NewResult() + if err != nil { + return err + } + result.SetLatestAppliedVersion(updateResult.LastAppliedVersion) + if updateResult.Err != nil { + result.SetErr(updateResult.Err.Error()) + } else { + result.SetErr("") + } + return nil +} + +// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs. +func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) { + srv := &cloudflaredV3Server{ + inbound: inbound, + logger: logger, + } + client := tunnelrpc.CloudflaredServer_ServerToClient(srv) + transport := rpc.StreamTransport(stream) + rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + <-rpcConn.Done() + E.Errors( + rpcConn.Close(), + transport.Close(), + ) +} diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 9ba52f9731..8fa3ffa625 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -474,6 +474,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg destinationPort := call.Params.DstPort() closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint()) + if _, traceErr := call.Params.TraceContext(); traceErr != nil { + return traceErr + } err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle) @@ -481,6 +484,9 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg if allocErr != nil { return allocErr } + if spansErr := result.SetSpans([]byte{}); spansErr != nil { + return spansErr + } if err != nil { result.SetErr(err.Error()) } diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index f3f92ddcab..77afee784e 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -34,6 +34,7 @@ const ( var ( loadOriginCABasePool = cloudflareRootCertPool readOriginCAFile = os.ReadFile + proxyFromEnvironment = http.ProxyFromEnvironment ) // ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2. @@ -42,6 +43,10 @@ type ConnectResponseWriter interface { WriteResponse(responseError error, metadata []Metadata) error } +type connectResponseTrailerWriter interface { + AddTrailer(name, value string) +} + // quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp). type quicResponseWriter struct { stream io.Writer @@ -69,8 +74,13 @@ func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser // HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup. func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) { - muxer := i.getOrCreateV2Muxer(sender) - ServeRPCStream(ctx, stream, i, muxer, i.logger) + switch datagramVersionForSender(sender) { + case "v3": + ServeV3RPCStream(ctx, stream, i, i.logger) + default: + muxer := i.getOrCreateV2Muxer(sender) + ServeRPCStream(ctx, stream, i, muxer, i.logger) + } } // HandleDatagram handles an incoming QUIC datagram. @@ -401,6 +411,13 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, if err != nil && !E.IsClosedOrCanceled(err) { i.logger.DebugContext(ctx, "copy HTTP response body: ", err) } + if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok { + for name, values := range response.Trailer { + for _, value := range values { + trailerWriter.AddTrailer(name, value) + } + } + } } func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) { @@ -417,7 +434,7 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter IdleConnTimeout: originRequest.KeepAliveTimeout, MaxIdleConns: originRequest.KeepAliveConnections, MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - Proxy: http.ProxyFromEnvironment, + Proxy: proxyFromEnvironment, TLSClientConfig: tlsConfig, DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil @@ -445,7 +462,7 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, MaxIdleConns: service.OriginRequest.KeepAliveConnections, MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - Proxy: http.ProxyFromEnvironment, + Proxy: proxyFromEnvironment, TLSClientConfig: tlsConfig, } switch service.Kind { diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index ec94ef1237..a63c422363 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -131,7 +131,13 @@ func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing. } func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) { - t.Setenv("HTTP_PROXY", "http://proxy.example.com:8080") + originalProxyFromEnvironment := proxyFromEnvironment + proxyFromEnvironment = func(request *http.Request) (*url.URL, error) { + return url.Parse("http://proxy.example.com:8080") + } + defer func() { + proxyFromEnvironment = originalProxyFromEnvironment + }() inbound := &Inbound{} transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{ diff --git a/protocol/cloudflare/response_trailer_test.go b/protocol/cloudflare/response_trailer_test.go new file mode 100644 index 0000000000..5b833972a7 --- /dev/null +++ b/protocol/cloudflare/response_trailer_test.go @@ -0,0 +1,91 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sagernet/sing-box/log" +) + +type trailerCaptureResponseWriter struct { + status int + trailers http.Header +} + +func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error { + for _, entry := range metadata { + if entry.Key == metadataHTTPStatus { + w.status = http.StatusOK + } + } + return nil +} + +func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) { + if w.trailers == nil { + w.trailers = make(http.Header) + } + w.trailers.Add(name, value) +} + +type captureReadWriteCloser struct { + body []byte +} + +func (c *captureReadWriteCloser) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (c *captureReadWriteCloser) Write(p []byte) (int, error) { + c.body = append(c.body, p...) + return len(p), nil +} + +func (c *captureReadWriteCloser) Close() error { + return nil +} + +func TestRoundTripHTTPCopiesTrailers(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Trailer", "X-Test-Trailer") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + w.Header().Set("X-Test-Trailer", "trailer-value") + })) + defer server.Close() + + transport, ok := server.Client().Transport.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type %T", server.Client().Transport) + } + + inboundInstance := &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + } + stream := &captureReadWriteCloser{} + respWriter := &trailerCaptureResponseWriter{} + request := &ConnectRequest{ + Dest: server.URL, + Type: ConnectionTypeHTTP, + Metadata: []Metadata{ + {Key: metadataHTTPMethod, Val: http.MethodGet}, + {Key: metadataHTTPHost, Val: "example.com"}, + }, + } + + inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{ + OriginRequest: defaultOriginRequestConfig(), + }, transport) + + if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" { + t.Fatalf("expected copied trailer, got %q", got) + } + if string(stream.body) != "ok" { + t.Fatalf("unexpected response body %q", stream.body) + } +} From e1847aab638bd465baec4bdbea498f9c8140306b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 12:53:58 +0800 Subject: [PATCH 38/41] Align cloudflared stream scheme handling --- protocol/cloudflare/config_decode_test.go | 4 +- protocol/cloudflare/inbound.go | 10 ++ protocol/cloudflare/ingress_test.go | 61 +++++++++++ protocol/cloudflare/runtime_config.go | 53 +++++++--- protocol/cloudflare/special_service.go | 15 +++ protocol/cloudflare/special_service_test.go | 108 +++++++++++++++++++- 6 files changed, 229 insertions(+), 22 deletions(-) diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go index 0c6f834736..df05d64ecd 100644 --- a/protocol/cloudflare/config_decode_test.go +++ b/protocol/cloudflare/config_decode_test.go @@ -27,12 +27,12 @@ func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) { } } -func TestNormalizeProtocolAcceptsAuto(t *testing.T) { +func TestNormalizeProtocolAutoUsesTokenStyleSentinel(t *testing.T) { protocol, err := normalizeProtocol("auto") if err != nil { t.Fatal(err) } if protocol != "" { - t.Fatalf("expected auto protocol to normalize to empty string, got %q", protocol) + t.Fatalf("expected auto protocol to normalize to token-style empty sentinel, got %q", protocol) } } diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index cd6e38ae54..f405d36366 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -276,6 +276,11 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, numPreviousAttempts uint8) error { protocol := i.protocol + // An empty protocol means the user configured "auto". For the token-provided, + // remotely-managed tunnel mode we implement here, that intentionally matches + // cloudflared's token path: start with QUIC and fall back to HTTP/2 on failure. + // If we ever support non-token/local-config modes, that is where remote + // percentage-based protocol selection should be introduced. if protocol == "" { protocol = "quic" } @@ -423,6 +428,11 @@ func parseToken(token string) (Credentials, error) { return tunnelToken.ToCredentials(), nil } +// "auto" does not choose a transport here. We normalize it to an empty +// sentinel so serveConnection can apply the token-style behavior later. +// In the token-provided, remotely-managed tunnel path supported here, that +// matches cloudflared's NewProtocolSelector(..., tunnelTokenProvided=true) +// branch rather than the non-token remote-percentage selector. func normalizeProtocol(protocol string) (string, error) { if protocol == "auto" { return "", nil diff --git a/protocol/cloudflare/ingress_test.go b/protocol/cloudflare/ingress_test.go index 61111ab77d..bd18eb83d7 100644 --- a/protocol/cloudflare/ingress_test.go +++ b/protocol/cloudflare/ingress_test.go @@ -191,6 +191,67 @@ func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) { } } +func TestParseResolvedServiceGenericStreamSchemeWithoutPort(t *testing.T) { + service, err := parseResolvedService("ftp://127.0.0.1", defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.Kind != ResolvedServiceStream { + t.Fatalf("expected stream service, got %v", service.Kind) + } + if service.Destination.AddrString() != "127.0.0.1" { + t.Fatalf("expected destination host 127.0.0.1, got %s", service.Destination.AddrString()) + } + if service.Destination.Port != 0 { + t.Fatalf("expected destination port 0, got %d", service.Destination.Port) + } + if service.StreamHasPort { + t.Fatal("expected generic stream service without port to report missing port") + } +} + +func TestParseResolvedServiceGenericStreamSchemeWithPort(t *testing.T) { + service, err := parseResolvedService("ftp://127.0.0.1:21", defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.Kind != ResolvedServiceStream { + t.Fatalf("expected stream service, got %v", service.Kind) + } + if service.Destination.String() != "127.0.0.1:21" { + t.Fatalf("expected destination 127.0.0.1:21, got %s", service.Destination) + } + if !service.StreamHasPort { + t.Fatal("expected generic stream service with explicit port to be dialable") + } +} + +func TestParseResolvedServiceSSHDefaultPort(t *testing.T) { + service, err := parseResolvedService("ssh://127.0.0.1", defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.Destination.String() != "127.0.0.1:22" { + t.Fatalf("expected destination 127.0.0.1:22, got %s", service.Destination) + } + if !service.StreamHasPort { + t.Fatal("expected ssh stream service to apply default port") + } +} + +func TestParseResolvedServiceTCPDefaultPort(t *testing.T) { + service, err := parseResolvedService("tcp://127.0.0.1", defaultOriginRequestConfig()) + if err != nil { + t.Fatal(err) + } + if service.Destination.String() != "127.0.0.1:7864" { + t.Fatalf("expected destination 127.0.0.1:7864, got %s", service.Destination) + } + if !service.StreamHasPort { + t.Fatal("expected tcp stream service to apply default port") + } +} + func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) { inboundInstance := newTestIngressInbound(t) inboundInstance.configManager.activeConfig = RuntimeConfig{ diff --git a/protocol/cloudflare/runtime_config.go b/protocol/cloudflare/runtime_config.go index 5b0c8d9edf..e61d18966a 100644 --- a/protocol/cloudflare/runtime_config.go +++ b/protocol/cloudflare/runtime_config.go @@ -45,6 +45,7 @@ type ResolvedService struct { Kind ResolvedServiceKind Service string Destination M.Socksaddr + StreamHasPort bool BaseURL *url.URL UnixPath string StatusCode int @@ -88,6 +89,15 @@ func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL { return &canonicalURL } +func isHTTPServiceScheme(scheme string) bool { + switch scheme { + case "http", "https", "ws", "wss": + return true + default: + return false + } +} + type compiledIngressRule struct { Hostname string PunycodeHostname string @@ -459,35 +469,46 @@ func parseResolvedService(rawService string, originRequest OriginRequestConfig) return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService) } - switch parsedURL.Scheme { - case "http", "https", "ws", "wss": + if isHTTPServiceScheme(parsedURL.Scheme) { return ResolvedService{ Kind: ResolvedServiceHTTP, Service: rawService, - Destination: parseServiceDestination(parsedURL), + Destination: parseHTTPServiceDestination(parsedURL), BaseURL: canonicalizeHTTPOriginURL(parsedURL), OriginRequest: originRequest, }, nil - case "tcp", "ssh", "rdp", "smb": - return ResolvedService{ - Kind: ResolvedServiceStream, - Service: rawService, - Destination: parseServiceDestination(parsedURL), - BaseURL: parsedURL, - OriginRequest: originRequest, - }, nil - default: - return ResolvedService{}, E.New("unsupported ingress service scheme: ", parsedURL.Scheme) } + + destination, hasPort := parseStreamServiceDestination(parsedURL) + return ResolvedService{ + Kind: ResolvedServiceStream, + Service: rawService, + Destination: destination, + StreamHasPort: hasPort, + BaseURL: parsedURL, + OriginRequest: originRequest, + }, nil } -func parseServiceDestination(parsedURL *url.URL) M.Socksaddr { +func parseHTTPServiceDestination(parsedURL *url.URL) M.Socksaddr { host := parsedURL.Hostname() port := parsedURL.Port() if port == "" { switch parsedURL.Scheme { case "https", "wss": port = "443" + default: + port = "80" + } + } + return M.ParseSocksaddr(net.JoinHostPort(host, port)) +} + +func parseStreamServiceDestination(parsedURL *url.URL) (M.Socksaddr, bool) { + host := parsedURL.Hostname() + port := parsedURL.Port() + if port == "" { + switch parsedURL.Scheme { case "ssh": port = "22" case "rdp": @@ -497,10 +518,10 @@ func parseServiceDestination(parsedURL *url.URL) M.Socksaddr { case "tcp": port = "7864" default: - port = "80" + return M.ParseSocksaddrHostPort(host, 0), false } } - return M.ParseSocksaddr(net.JoinHostPort(host, port)) + return M.ParseSocksaddr(net.JoinHostPort(host, port)), true } func validateHostname(hostname string, isLast bool) error { diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index 6c6d142ae1..822dfea018 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -42,6 +42,10 @@ func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCl } func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + if !service.StreamHasPort { + respWriter.WriteResponse(E.New("address ", streamServiceHostname(service), ": missing port in address"), nil) + return + } i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType) } @@ -181,6 +185,17 @@ func requestHeaderValue(request *ConnectRequest, headerName string) string { return "" } +func streamServiceHostname(service ResolvedService) string { + if service.BaseURL != nil && service.BaseURL.Hostname() != "" { + return service.BaseURL.Hostname() + } + parsedURL, err := url.Parse(service.Service) + if err == nil && parsedURL.Hostname() != "" { + return parsedURL.Hostname() + } + return service.Destination.AddrString() +} + func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) { metadata := adapter.InboundContext{ Inbound: i.Tag(), diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 8c39543c29..afaa67a3d2 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "net/url" "strconv" "sync/atomic" "testing" @@ -439,8 +440,9 @@ func TestHandleStreamService(t *testing.T) { go func() { defer close(done) inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ - Kind: ResolvedServiceStream, - Destination: M.ParseSocksaddr(listener.Addr().String()), + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + StreamHasPort: true, }) }() @@ -497,8 +499,9 @@ func TestHandleStreamServiceProxyTypeSocks(t *testing.T) { go func() { defer close(done) inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ - Kind: ResolvedServiceStream, - Destination: M.ParseSocksaddr(listener.Addr().String()), + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + StreamHasPort: true, OriginRequest: OriginRequestConfig{ ProxyType: "socks", }, @@ -540,3 +543,100 @@ func TestHandleStreamServiceProxyTypeSocks(t *testing.T) { t.Fatal("socks stream service did not exit") } } + +func TestHandleStreamServiceGenericSchemeWithPort(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Service: "ftp://" + listener.Addr().String(), + Destination: M.ParseSocksaddr(listener.Addr().String()), + StreamHasPort: true, + }) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for stream service connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("generic stream service did not exit") + } +} + +func TestHandleStreamServiceGenericSchemeWithoutPort(t *testing.T) { + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + defer serverSide.Close() + + router := &countingRouter{} + inboundInstance := newSpecialServiceInboundWithRouter(t, router) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Service: "ftp://127.0.0.1", + Destination: M.ParseSocksaddrHostPort("127.0.0.1", 0), + StreamHasPort: false, + BaseURL: &url.URL{ + Scheme: "ftp", + Host: "127.0.0.1", + }, + }) + + if respWriter.err == nil { + t.Fatal("expected missing port error") + } + if respWriter.err.Error() != "address 127.0.0.1: missing port in address" { + t.Fatalf("unexpected error: %v", respWriter.err) + } + if respWriter.status == http.StatusSwitchingProtocols { + t.Fatalf("expected non-upgrade response on error, got %d", respWriter.status) + } + if router.count.Load() != 0 { + t.Fatalf("expected router not to be used, got %d", router.count.Load()) + } +} From 2edbc42629f4bbce8d100955a34eaf1752004183 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 16:48:09 +0800 Subject: [PATCH 39/41] Fix cloudflared safe stream, safe transport and split dialer --- option/cloudflared.go | 1 + protocol/cloudflare/connection_http2.go | 2 +- protocol/cloudflare/connection_quic.go | 39 +++++++-------- protocol/cloudflare/control.go | 4 +- protocol/cloudflare/datagram_rpc_v3.go | 6 +-- protocol/cloudflare/datagram_v2.go | 8 ++-- protocol/cloudflare/inbound.go | 11 ++++- protocol/cloudflare/safe_transport.go | 63 +++++++++++++++++++++++++ 8 files changed, 103 insertions(+), 31 deletions(-) create mode 100644 protocol/cloudflare/safe_transport.go diff --git a/option/cloudflared.go b/option/cloudflared.go index e597ebb77e..7daaafb9f4 100644 --- a/option/cloudflared.go +++ b/option/cloudflared.go @@ -7,6 +7,7 @@ type CloudflaredInboundOptions struct { HAConnections int `json:"ha_connections,omitempty"` Protocol string `json:"protocol,omitempty"` ControlDialer DialerOptions `json:"control_dialer,omitempty"` + TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"` EdgeIPVersion int `json:"edge_ip_version,omitempty"` DatagramVersion string `json:"datagram_version,omitempty"` GracePeriod badoption.Duration `json:"grace_period,omitempty"` diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 56ac895e5a..0b21cab067 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -78,7 +78,7 @@ func NewHTTP2Connection( ServerName: h2EdgeSNI, } - tcpConn, err := inbound.controlDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) + tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) if err != nil { return nil, E.Cause(err, "dial edge TCP") } diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index e83bf82985..32f872a56b 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -80,10 +80,6 @@ func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reaso return err } -var ( - quicPortByConnIndex = make(map[uint8]int) - quicPortAccess sync.Mutex -) // NewQUICConnection dials the edge and establishes a QUIC connection. func NewQUICConnection( @@ -96,7 +92,7 @@ func NewQUICConnection( features []string, numPreviousAttempts uint8, gracePeriod time.Duration, - controlDialer N.Dialer, + tunnelDialer N.Dialer, onConnected func(), logger log.ContextLogger, ) (*QUICConnection, error) { @@ -121,7 +117,7 @@ func NewQUICConnection( InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion), } - udpConn, err := createUDPConnForConnIndex(ctx, connIndex, edgeAddr, controlDialer) + udpConn, err := createUDPConnForConnIndex(ctx, edgeAddr, tunnelDialer) if err != nil { return nil, E.Cause(err, "listen UDP for QUIC edge") } @@ -147,11 +143,15 @@ func NewQUICConnection( }, nil } -func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *EdgeAddr, controlDialer N.Dialer) (*net.UDPConn, error) { - quicPortAccess.Lock() - defer quicPortAccess.Unlock() - - packetConn, err := controlDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) +// createUDPConnForConnIndex creates a UDP socket for QUIC via the tunnel dialer. +// Unlike cloudflared, we do not attempt to reuse previously-bound ports across +// reconnects — the dialer interface does not support specifying local ports, +// and fixed port binding is not important for our use case. +// We also do not apply Darwin-specific udp4/udp6 network selection to work around +// quic-go#3793 (DF bit on macOS dual-stack); the dialer controls network selection +// and this is a non-critical platform-specific limitation. +func createUDPConnForConnIndex(ctx context.Context, edgeAddr *EdgeAddr, tunnelDialer N.Dialer) (*net.UDPConn, error) { + packetConn, err := tunnelDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port())) if err != nil { return nil, err } @@ -160,12 +160,6 @@ func createUDPConnForConnIndex(ctx context.Context, connIndex uint8, edgeAddr *E packetConn.Close() return nil, fmt.Errorf("unexpected packet conn type %T", packetConn) } - udpAddr, ok := udpConn.LocalAddr().(*net.UDPAddr) - if !ok { - udpConn.Close() - return nil, fmt.Errorf("unexpected local UDP address type %T", udpConn.LocalAddr()) - } - quicPortByConnIndex[connIndex] = udpAddr.Port return udpConn, nil } @@ -368,9 +362,11 @@ type DatagramSender interface { SendDatagram(data []byte) error } -// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser. +// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser +// with mutex-protected writes and safe close semantics. type streamReadWriteCloser struct { - stream *quic.Stream + stream *quic.Stream + writeAccess sync.Mutex } func newStreamReadWriteCloser(stream *quic.Stream) *streamReadWriteCloser { @@ -382,10 +378,15 @@ func (s *streamReadWriteCloser) Read(p []byte) (int, error) { } func (s *streamReadWriteCloser) Write(p []byte) (int, error) { + s.writeAccess.Lock() + defer s.writeAccess.Unlock() return s.stream.Write(p) } func (s *streamReadWriteCloser) Close() error { + _ = s.stream.SetWriteDeadline(time.Now()) + s.writeAccess.Lock() + defer s.writeAccess.Unlock() s.stream.CancelRead(0) return s.stream.Close() } diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index e6a0b070f7..194a722597 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -46,8 +46,8 @@ type registrationRPCClient interface { // NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. // The stream should be the first QUIC stream (control stream). func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { - transport := rpc.StreamTransport(stream) - conn := rpc.NewConn(transport) + transport := safeTransport(stream) + conn := newRPCClientConn(transport, ctx) return &RegistrationClient{ client: tunnelrpc.TunnelServer{Client: conn.Bootstrap(ctx)}, rpcConn: conn, diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go index 38af323ff7..6c40db8826 100644 --- a/protocol/cloudflare/datagram_rpc_v3.go +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -10,8 +10,6 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" E "github.com/sagernet/sing/common/exceptions" - - "zombiezen.com/go/capnproto2/rpc" ) var ( @@ -63,8 +61,8 @@ func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *I logger: logger, } client := tunnelrpc.CloudflaredServer_ServerToClient(srv) - transport := rpc.StreamTransport(stream) - rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + transport := safeTransport(stream) + rpcConn := newRPCServerConn(transport, client.Client) <-rpcConn.Done() E.Errors( rpcConn.Close(), diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index 8fa3ffa625..d7454cab1c 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -76,8 +76,8 @@ var newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2 if err != nil { return nil, err } - transport := rpc.StreamTransport(stream) - conn := rpc.NewConn(transport) + transport := safeTransport(stream) + conn := newRPCClientConn(transport, ctx) return &capnpV2SessionRPCClient{ client: tunnelrpc.SessionManager{Client: conn.Bootstrap(ctx)}, rpcConn: conn, @@ -533,8 +533,8 @@ func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inb logger: logger, } client := tunnelrpc.CloudflaredServer_ServerToClient(srv) - transport := rpc.StreamTransport(stream) - rpcConn := rpc.NewConn(transport, rpc.MainInterface(client.Client)) + transport := safeTransport(stream) + rpcConn := newRPCServerConn(transport, client.Client) <-rpcConn.Done() E.Errors( rpcConn.Close(), diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index f405d36366..674abef16b 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -49,6 +49,7 @@ type Inbound struct { flowLimiter *FlowLimiter accessCache *accessValidatorCache controlDialer N.Dialer + tunnelDialer N.Dialer connectionAccess sync.Mutex connections []io.Closer @@ -110,6 +111,13 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if err != nil { return nil, E.Cause(err, "build cloudflared control dialer") } + tunnelDialer, err := boxDialer.NewWithOptions(boxDialer.Options{ + Context: ctx, + Options: options.TunnelDialer, + }) + if err != nil { + return nil, E.Cause(err, "build cloudflared tunnel dialer") + } region := options.Region if region != "" && credentials.Endpoint != "" { @@ -140,6 +148,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo flowLimiter: &FlowLimiter{}, accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, controlDialer: controlDialer, + tunnelDialer: tunnelDialer, datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), datagramV3Manager: NewDatagramV3SessionManager(), @@ -310,7 +319,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion connection, err := NewQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, datagramVersion, - features, numPreviousAttempts, i.gracePeriod, i.controlDialer, func() { + features, numPreviousAttempts, i.gracePeriod, i.tunnelDialer, func() { i.notifyConnected(connIndex) }, i.logger, ) diff --git a/protocol/cloudflare/safe_transport.go b/protocol/cloudflare/safe_transport.go new file mode 100644 index 0000000000..99b7880d7e --- /dev/null +++ b/protocol/cloudflare/safe_transport.go @@ -0,0 +1,63 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "time" + + E "github.com/sagernet/sing/common/exceptions" + + capnp "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/rpc" +) + +const ( + safeTransportMaxRetries = 3 + safeTransportRetryInterval = 500 * time.Millisecond +) + +type safeReadWriteCloser struct { + io.ReadWriteCloser + retries int +} + +func (s *safeReadWriteCloser) Read(p []byte) (int, error) { + n, err := s.ReadWriteCloser.Read(p) + if n == 0 && err != nil && isTemporaryError(err) { + if s.retries >= safeTransportMaxRetries { + return 0, E.Cause(err, "read capnproto transport after multiple temporary errors") + } + s.retries++ + time.Sleep(safeTransportRetryInterval) + return n, err + } + if err == nil { + s.retries = 0 + } + return n, err +} + +func isTemporaryError(err error) bool { + type temporary interface{ Temporary() bool } + t, ok := err.(temporary) + return ok && t.Temporary() +} + +func safeTransport(stream io.ReadWriteCloser) rpc.Transport { + return rpc.StreamTransport(&safeReadWriteCloser{ReadWriteCloser: stream}) +} + +type noopCapnpLogger struct{} + +func (noopCapnpLogger) Infof(ctx context.Context, format string, args ...interface{}) {} +func (noopCapnpLogger) Errorf(ctx context.Context, format string, args ...interface{}) {} + +func newRPCClientConn(transport rpc.Transport, ctx context.Context) *rpc.Conn { + return rpc.NewConn(transport, rpc.ConnLog(noopCapnpLogger{})) +} + +func newRPCServerConn(transport rpc.Transport, client capnp.Client) *rpc.Conn { + return rpc.NewConn(transport, rpc.MainInterface(client), rpc.ConnLog(noopCapnpLogger{})) +} From e3ed3f00eb6dae4d9591893dda7719bf8a415f51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 27 Mar 2026 16:31:47 +0800 Subject: [PATCH 40/41] Fix cloudflared parity gaps --- protocol/cloudflare/connection_drain_test.go | 4 +- protocol/cloudflare/connection_http2.go | 65 ++++- .../connection_http2_behavior_test.go | 167 +++++++++++++ protocol/cloudflare/connection_quic.go | 18 +- protocol/cloudflare/connection_quic_test.go | 50 ++++ protocol/cloudflare/control.go | 6 +- .../cloudflare/datagram_lifecycle_test.go | 3 +- protocol/cloudflare/datagram_rpc_test.go | 68 +++++- protocol/cloudflare/datagram_rpc_v3.go | 7 +- protocol/cloudflare/datagram_v2.go | 30 ++- protocol/cloudflare/datagram_v3.go | 7 +- protocol/cloudflare/datagram_v3_test.go | 81 ++++++ protocol/cloudflare/direct_origin_test.go | 66 +++++ protocol/cloudflare/dispatch.go | 50 +++- protocol/cloudflare/flow_limiter_test.go | 19 +- protocol/cloudflare/helpers_test.go | 46 ++-- protocol/cloudflare/icmp.go | 97 +++++--- protocol/cloudflare/icmp_test.go | 96 +++++++- protocol/cloudflare/inbound.go | 231 ++++++++++++++---- protocol/cloudflare/inbound_state_test.go | 155 ++++++++++++ protocol/cloudflare/rpc_stream_test.go | 186 ++++++++++++++ protocol/cloudflare/special_service.go | 29 ++- protocol/cloudflare/special_service_test.go | 38 +++ 23 files changed, 1378 insertions(+), 141 deletions(-) create mode 100644 protocol/cloudflare/connection_http2_behavior_test.go create mode 100644 protocol/cloudflare/inbound_state_test.go create mode 100644 protocol/cloudflare/rpc_stream_test.go diff --git a/protocol/cloudflare/connection_drain_test.go b/protocol/cloudflare/connection_drain_test.go index 0d975a1547..756129502f 100644 --- a/protocol/cloudflare/connection_drain_test.go +++ b/protocol/cloudflare/connection_drain_test.go @@ -10,8 +10,9 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/sagernet/quic-go" + + "github.com/google/uuid" ) type stubNetConn struct { @@ -43,6 +44,7 @@ func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.N func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) { return nil, errors.New("unused") } + func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) { return nil, errors.New("unused") } diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index 0b21cab067..a8d4dae40e 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -9,6 +9,7 @@ import ( "math" "net" "net/http" + "runtime/debug" "strconv" "strings" "sync" @@ -27,8 +28,17 @@ const ( h2EdgeSNI = "h2.cftunnel.com" h2ResponseMetaCloudflared = `{"src":"cloudflared"}` h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}` + contentTypeHeader = "content-type" + contentLengthHeader = "content-length" + transferEncodingHeader = "transfer-encoding" + chunkTransferEncoding = "chunked" + sseContentType = "text/event-stream" + grpcContentType = "application/grpc" + ndjsonContentType = "application/x-ndjson" ) +var flushableContentTypes = []string{sseContentType, grpcContentType, ndjsonContentType} + // HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge. // Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server. type HTTP2Connection struct { @@ -191,7 +201,7 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque return } c.registrationResult = result - c.inbound.notifyConnected(c.connIndex) + c.inbound.notifyConnected(c.connIndex, "http2") c.logger.Info("connected to ", result.Location, " (connection ", result.ConnectionID, ")") @@ -246,14 +256,18 @@ func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Reques } } + flushState := &http2FlushState{shouldFlush: connectionType != ConnectionTypeHTTP} stream := &http2DataStream{ reader: r.Body, writer: w, flusher: flusher, + state: flushState, + logger: c.logger, } respWriter := &http2ResponseWriter{ - writer: w, - flusher: flusher, + writer: w, + flusher: flusher, + flushState: flushState, } c.inbound.dispatchRequest(ctx, stream, respWriter, request) @@ -386,15 +400,26 @@ type http2DataStream struct { reader io.ReadCloser writer http.ResponseWriter flusher http.Flusher + state *http2FlushState + logger log.ContextLogger } func (s *http2DataStream) Read(p []byte) (int, error) { return s.reader.Read(p) } -func (s *http2DataStream) Write(p []byte) (int, error) { - n, err := s.writer.Write(p) - if err == nil { +func (s *http2DataStream) Write(p []byte) (n int, err error) { + defer func() { + if recovered := recover(); recovered != nil { + if s.logger != nil { + s.logger.Debug("recovered from HTTP/2 data stream panic: ", recovered, "\n", string(debug.Stack())) + } + n = 0 + err = io.ErrClosedPipe + } + }() + n, err = s.writer.Write(p) + if err == nil && s.state != nil && s.state.shouldFlush { s.flusher.Flush() } return n, err @@ -409,6 +434,7 @@ type http2ResponseWriter struct { writer http.ResponseWriter flusher http.Flusher headersSent bool + flushState *http2FlushState } func (w *http2ResponseWriter) AddTrailer(name, value string) { @@ -462,12 +488,37 @@ func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Meta w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders)) w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin) + if w.flushState != nil && shouldFlushHTTPHeaders(userHeaders) { + w.flushState.shouldFlush = true + } if statusCode == http.StatusSwitchingProtocols { statusCode = http.StatusOK } w.writer.WriteHeader(statusCode) - w.flusher.Flush() + if w.flushState != nil && w.flushState.shouldFlush { + w.flusher.Flush() + } return nil } + +type http2FlushState struct { + shouldFlush bool +} + +func shouldFlushHTTPHeaders(headers http.Header) bool { + if headers.Get(contentLengthHeader) == "" { + return true + } + if transferEncoding := strings.ToLower(headers.Get(transferEncodingHeader)); transferEncoding != "" && strings.Contains(transferEncoding, chunkTransferEncoding) { + return true + } + contentType := strings.ToLower(headers.Get(contentTypeHeader)) + for _, flushable := range flushableContentTypes { + if strings.HasPrefix(contentType, flushable) { + return true + } + } + return false +} diff --git a/protocol/cloudflare/connection_http2_behavior_test.go b/protocol/cloudflare/connection_http2_behavior_test.go new file mode 100644 index 0000000000..04477be338 --- /dev/null +++ b/protocol/cloudflare/connection_http2_behavior_test.go @@ -0,0 +1,167 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "io" + "net/http" + "testing" + + "github.com/sagernet/sing-box/log" +) + +type captureHTTP2Writer struct { + header http.Header + flushCount int + statusCode int + body []byte + panicWrite bool +} + +func (w *captureHTTP2Writer) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *captureHTTP2Writer) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +func (w *captureHTTP2Writer) Write(p []byte) (int, error) { + if w.panicWrite { + panic("write after close") + } + w.body = append(w.body, p...) + return len(p), nil +} + +func (w *captureHTTP2Writer) Flush() { + w.flushCount++ +} + +func TestHTTP2NonStreamingResponseDoesNotFlush(t *testing.T) { + writer := &captureHTTP2Writer{} + flushState := &http2FlushState{} + respWriter := &http2ResponseWriter{ + writer: writer, + flusher: writer, + flushState: flushState, + } + + err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, http.Header{ + "Content-Type": []string{"application/json"}, + "Content-Length": []string{"2"}, + })) + if err != nil { + t.Fatal(err) + } + if writer.flushCount != 0 { + t.Fatalf("expected no header flush for non-streaming response, got %d", writer.flushCount) + } + + stream := &http2DataStream{ + writer: writer, + flusher: writer, + state: flushState, + logger: log.NewNOPFactory().NewLogger("test"), + } + if _, err := stream.Write([]byte("ok")); err != nil { + t.Fatal(err) + } + if writer.flushCount != 0 { + t.Fatalf("expected no body flush for non-streaming response, got %d", writer.flushCount) + } +} + +func TestHTTP2StreamingResponsesFlush(t *testing.T) { + testCases := []struct { + name string + header http.Header + }{ + { + name: "sse", + header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "Content-Length": []string{"1"}, + }, + }, + { + name: "grpc", + header: http.Header{ + "Content-Type": []string{"application/grpc"}, + "Content-Length": []string{"1"}, + }, + }, + { + name: "ndjson", + header: http.Header{ + "Content-Type": []string{"application/x-ndjson"}, + "Content-Length": []string{"1"}, + }, + }, + { + name: "chunked", + header: http.Header{ + "Content-Type": []string{"application/json"}, + "Content-Length": []string{"-1"}, + "Transfer-Encoding": []string{"chunked"}, + }, + }, + { + name: "no-content-length", + header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + writer := &captureHTTP2Writer{} + flushState := &http2FlushState{} + respWriter := &http2ResponseWriter{ + writer: writer, + flusher: writer, + flushState: flushState, + } + + err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, testCase.header)) + if err != nil { + t.Fatal(err) + } + if writer.flushCount == 0 { + t.Fatal("expected header flush for streaming response") + } + + stream := &http2DataStream{ + writer: writer, + flusher: writer, + state: flushState, + logger: log.NewNOPFactory().NewLogger("test"), + } + if _, err := stream.Write([]byte("chunk")); err != nil { + t.Fatal(err) + } + if writer.flushCount < 2 { + t.Fatalf("expected body flush for streaming response, got %d flushes", writer.flushCount) + } + }) + } +} + +func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) { + writer := &captureHTTP2Writer{panicWrite: true} + stream := &http2DataStream{ + writer: writer, + flusher: writer, + state: &http2FlushState{shouldFlush: true}, + logger: log.NewNOPFactory().NewLogger("test"), + } + + _, err := stream.Write([]byte("panic")) + if err != io.ErrClosedPipe { + t.Fatalf("expected io.ErrClosedPipe, got %v", err) + } +} diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 32f872a56b..41afa0e90c 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -60,6 +60,15 @@ type QUICConnection struct { closeOnce sync.Once } +type quicStreamHandle interface { + io.Reader + io.Writer + io.Closer + CancelRead(code quic.StreamErrorCode) + CancelWrite(code quic.StreamErrorCode) + SetWriteDeadline(t time.Time) error +} + type quicConnection interface { OpenStream() (*quic.Stream, error) AcceptStream(ctx context.Context) (*quic.Stream, error) @@ -80,7 +89,6 @@ func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reaso return err } - // NewQUICConnection dials the edge and establishes a QUIC connection. func NewQUICConnection( ctx context.Context, @@ -240,13 +248,14 @@ func (q *QUICConnection) acceptStreams(ctx context.Context, handler StreamHandle } } -func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, handler StreamHandler) { +func (q *QUICConnection) handleStream(ctx context.Context, stream quicStreamHandle, handler StreamHandler) { rwc := newStreamReadWriteCloser(stream) defer rwc.Close() streamType, err := ReadStreamSignature(rwc) if err != nil { q.logger.Debug("failed to read stream signature: ", err) + stream.CancelWrite(0) return } @@ -255,6 +264,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, request, err := ReadConnectRequest(rwc) if err != nil { q.logger.Debug("failed to read connect request: ", err) + stream.CancelWrite(0) return } handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex) @@ -365,11 +375,11 @@ type DatagramSender interface { // streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser // with mutex-protected writes and safe close semantics. type streamReadWriteCloser struct { - stream *quic.Stream + stream quicStreamHandle writeAccess sync.Mutex } -func newStreamReadWriteCloser(stream *quic.Stream) *streamReadWriteCloser { +func newStreamReadWriteCloser(stream quicStreamHandle) *streamReadWriteCloser { return &streamReadWriteCloser{stream: stream} } diff --git a/protocol/cloudflare/connection_quic_test.go b/protocol/cloudflare/connection_quic_test.go index 78479dad80..5c7fe864ff 100644 --- a/protocol/cloudflare/connection_quic_test.go +++ b/protocol/cloudflare/connection_quic_test.go @@ -3,9 +3,14 @@ package cloudflare import ( + "context" "io" "strings" "testing" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-box/log" ) func TestQUICInitialPacketSize(t *testing.T) { @@ -77,3 +82,48 @@ func TestNOPCloserReadWriterTracksEOF(t *testing.T) { t.Fatalf("expected cached EOF, got %v", err) } } + +type fakeQUICStream struct { + reader strings.Reader + cancelWriteCount int +} + +func (s *fakeQUICStream) Read(p []byte) (int, error) { return s.reader.Read(p) } +func (s *fakeQUICStream) Write(p []byte) (int, error) { return len(p), nil } +func (s *fakeQUICStream) Close() error { return nil } +func (s *fakeQUICStream) CancelRead(quic.StreamErrorCode) {} +func (s *fakeQUICStream) CancelWrite(quic.StreamErrorCode) { + s.cancelWriteCount++ +} +func (s *fakeQUICStream) SetWriteDeadline(time.Time) error { return nil } + +func TestHandleStreamCancelsWriteOnSignatureError(t *testing.T) { + stream := &fakeQUICStream{reader: *strings.NewReader("broken")} + connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")} + + connection.handleStream(context.Background(), stream, nil) + if stream.cancelWriteCount != 1 { + t.Fatalf("expected CancelWrite on signature error, got %d", stream.cancelWriteCount) + } +} + +type nopStreamHandler struct{} + +func (nopStreamHandler) HandleDataStream(context.Context, io.ReadWriteCloser, *ConnectRequest, uint8) { +} +func (nopStreamHandler) HandleRPCStream(context.Context, io.ReadWriteCloser, uint8) {} +func (nopStreamHandler) HandleRPCStreamWithSender(context.Context, io.ReadWriteCloser, uint8, DatagramSender) { +} +func (nopStreamHandler) HandleDatagram(context.Context, []byte, DatagramSender) {} + +func TestHandleStreamCancelsWriteOnConnectRequestError(t *testing.T) { + stream := &fakeQUICStream{ + reader: *strings.NewReader(string(dataStreamSignature[:])), + } + connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")} + + connection.handleStream(context.Background(), stream, nopStreamHandler{}) + if stream.cancelWriteCount != 1 { + t.Fatalf("expected CancelWrite on connect request error, got %d", stream.cancelWriteCount) + } +} diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index 194a722597..a68648bb6f 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -19,7 +19,7 @@ import ( ) const ( - registrationTimeout = 10 * time.Second + rpcTimeout = 5 * time.Second ) var clientVersion = "sing-box " + C.Version @@ -63,7 +63,7 @@ func (c *RegistrationClient) RegisterConnection( connIndex uint8, options *RegistrationConnectionOptions, ) (*RegistrationResult, error) { - ctx, cancel := context.WithTimeout(ctx, registrationTimeout) + ctx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() promise := c.client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error { @@ -147,8 +147,6 @@ func (c *RegistrationClient) RegisterConnection( // Unregister sends the UnregisterConnection RPC. func (c *RegistrationClient) Unregister(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, registrationTimeout) - defer cancel() promise := c.client.UnregisterConnection(ctx, nil) _, err := promise.Struct() return err diff --git a/protocol/cloudflare/datagram_lifecycle_test.go b/protocol/cloudflare/datagram_lifecycle_test.go index b08e3a7e58..bfcde80be2 100644 --- a/protocol/cloudflare/datagram_lifecycle_test.go +++ b/protocol/cloudflare/datagram_lifecycle_test.go @@ -10,11 +10,12 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + + "github.com/google/uuid" ) type v2UnregisterCall struct { diff --git a/protocol/cloudflare/datagram_rpc_test.go b/protocol/cloudflare/datagram_rpc_test.go index 08974a9cb8..33ab397903 100644 --- a/protocol/cloudflare/datagram_rpc_test.go +++ b/protocol/cloudflare/datagram_rpc_test.go @@ -4,13 +4,15 @@ package cloudflare import ( "context" + "net" "testing" + "time" - "github.com/google/uuid" "github.com/sagernet/sing-box/adapter/inbound" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + "github.com/google/uuid" capnp "zombiezen.com/go/capnproto2" ) @@ -90,6 +92,40 @@ func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregist } } +func newUnregisterUDPSessionCallForSession(t *testing.T, sessionID uuid.UUID, message string) tunnelrpc.SessionManager_unregisterUdpSession { + t.Helper() + + _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg) + if err != nil { + t.Fatal(err) + } + if err := params.SetSessionId(sessionID[:]); err != nil { + t.Fatal(err) + } + if err := params.SetMessage(message); err != nil { + t.Fatal(err) + } + + _, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) + if err != nil { + t.Fatal(err) + } + results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg) + if err != nil { + t.Fatal(err) + } + + return tunnelrpc.SessionManager_unregisterUdpSession{ + Ctx: context.Background(), + Params: params, + Results: results, + } +} + func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) { server := &cloudflaredV3Server{ inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")}, @@ -131,3 +167,33 @@ func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) { t.Fatalf("unexpected unregister error %v", err) } } + +func TestV2RPCUnregisterUDPSessionPropagatesMessage(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()} + muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger) + + sessionID := uuid.New() + if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil { + t.Fatal(err) + } + muxer.sessionAccess.RLock() + session := muxer.sessions[sessionID] + muxer.sessionAccess.RUnlock() + if session == nil { + t.Fatal("expected registered session") + } + + server := &cloudflaredServer{ + inbound: inboundInstance, + muxer: muxer, + ctx: context.Background(), + logger: inboundInstance.logger, + } + if err := server.UnregisterUdpSession(newUnregisterUDPSessionCallForSession(t, sessionID, "edge close")); err != nil { + t.Fatal(err) + } + if reason := session.closeReason(); reason != "edge close" { + t.Fatalf("expected close reason propagated from edge, got %q", reason) + } +} diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go index 6c40db8826..071a7a88f1 100644 --- a/protocol/cloudflare/datagram_rpc_v3.go +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -63,7 +63,12 @@ func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *I client := tunnelrpc.CloudflaredServer_ServerToClient(srv) transport := safeTransport(stream) rpcConn := newRPCServerConn(transport, client.Client) - <-rpcConn.Done() + rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) + defer cancel() + select { + case <-rpcConn.Done(): + case <-rpcCtx.Done(): + } E.Errors( rpcConn.Close(), transport.Close(), diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index d7454cab1c..ae7a5e2d33 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -19,6 +19,7 @@ import ( "github.com/google/uuid" "zombiezen.com/go/capnproto2/rpc" + "zombiezen.com/go/capnproto2/server" ) // V2 wire format: [payload | 16B sessionID | 1B type] (suffix-based) @@ -204,7 +205,7 @@ func (m *DatagramV2Muxer) RegisterSession( } // UnregisterSession removes a UDP session. -func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID) { +func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID, message string) { m.sessionAccess.Lock() session, exists := m.sessions[sessionID] if exists { @@ -213,7 +214,7 @@ func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID) { m.sessionAccess.Unlock() if exists { - session.markRemoteClosed() + session.markRemoteClosed(message) session.close() m.logger.Info("unregistered V2 UDP session ", sessionID) } @@ -231,7 +232,7 @@ func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession, m.sessionAccess.Unlock() if !session.remoteClosed() { - unregisterCtx, cancel := context.WithTimeout(context.Background(), registrationTimeout) + unregisterCtx, cancel := context.WithTimeout(context.Background(), rpcTimeout) defer cancel() if err := m.unregisterRemoteSession(unregisterCtx, session.id, session.closeReason()); err != nil { m.logger.Debug("failed to unregister V2 UDP session ", session.id, ": ", err) @@ -388,10 +389,12 @@ func (s *udpSession) closeWithReason(reason string) { s.close() } -func (s *udpSession) markRemoteClosed() { +func (s *udpSession) markRemoteClosed(message string) { s.stateAccess.Lock() s.closedByRemote = true - if s.closeReasonString == "" { + if message != "" { + s.closeReasonString = message + } else if s.closeReasonString == "" { s.closeReasonString = "unregistered by edge" } s.stateAccess.Unlock() @@ -458,6 +461,7 @@ type cloudflaredServer struct { } func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error { + server.Ack(call.Options) sessionIDBytes, err := call.Params.SessionId() if err != nil { return err @@ -494,6 +498,7 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg } func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error { + server.Ack(call.Options) sessionIDBytes, err := call.Params.SessionId() if err != nil { return err @@ -503,11 +508,17 @@ func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_u return err } - s.muxer.UnregisterSession(sessionID) + message, err := call.Params.Message() + if err != nil { + return err + } + + s.muxer.UnregisterSession(sessionID, message) return nil } func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { + server.Ack(call.Options) version := call.Params.Version() configData, _ := call.Params.Config() updateResult := s.inbound.ApplyConfig(version, configData) @@ -535,7 +546,12 @@ func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inb client := tunnelrpc.CloudflaredServer_ServerToClient(srv) transport := safeTransport(stream) rpcConn := newRPCServerConn(transport, client.Client) - <-rpcConn.Done() + rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) + defer cancel() + select { + case <-rpcConn.Done(): + case <-rpcCtx.Done(): + } E.Errors( rpcConn.Close(), transport.Close(), diff --git a/protocol/cloudflare/datagram_v3.go b/protocol/cloudflare/datagram_v3.go index 42719c5a1e..ee0bca1ff6 100644 --- a/protocol/cloudflare/datagram_v3.go +++ b/protocol/cloudflare/datagram_v3.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "net/netip" + "os" "sync" "time" @@ -183,7 +184,7 @@ func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) { } func (m *DatagramV3Muxer) handlePayload(data []byte) { - if len(data) < v3RequestIDLength { + if len(data) < v3RequestIDLength || len(data) > v3RequestIDLength+maxV3UDPPayloadLen { return } @@ -390,6 +391,10 @@ func (s *v3Session) writeLoop() { case payload := <-s.writeChan: err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination)) if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + s.inbound.logger.Debug("drop V3 UDP payload due to write deadline exceeded") + continue + } s.close() return } diff --git a/protocol/cloudflare/datagram_v3_test.go b/protocol/cloudflare/datagram_v3_test.go index 22f7cc5859..5b0ab08478 100644 --- a/protocol/cloudflare/datagram_v3_test.go +++ b/protocol/cloudflare/datagram_v3_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/netip" + "os" "testing" "time" @@ -149,3 +150,83 @@ func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) { t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0])) } } + +func TestDatagramV3HandlePayloadDropsOversizedPayload(t *testing.T) { + requestID := RequestID{} + requestID[15] = 9 + session := &v3Session{ + id: requestID, + writeChan: make(chan []byte, 1), + } + manager := NewDatagramV3SessionManager() + manager.sessions[requestID] = session + muxer := &DatagramV3Muxer{ + inbound: &Inbound{ + datagramV3Manager: manager, + }, + } + + payload := make([]byte, v3RequestIDLength+maxV3UDPPayloadLen+1) + copy(payload[:v3RequestIDLength], requestID[:]) + muxer.handlePayload(payload) + + select { + case <-session.writeChan: + t.Fatal("expected oversized payload to be dropped") + default: + } +} + +type deadlinePacketConn struct { + err error +} + +func (c *deadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + buffer.Release() + return M.Socksaddr{}, io.EOF +} + +func (c *deadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return c.err +} + +func (c *deadlinePacketConn) Close() error { return nil } +func (c *deadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *deadlinePacketConn) SetDeadline(time.Time) error { return nil } +func (c *deadlinePacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *deadlinePacketConn) SetWriteDeadline(time.Time) error { return nil } + +func TestDatagramV3WriteLoopDropsDeadlineExceeded(t *testing.T) { + session := &v3Session{ + destination: netip.MustParseAddrPort("127.0.0.1:53"), + origin: &deadlinePacketConn{err: os.ErrDeadlineExceeded}, + inbound: &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + }, + writeChan: make(chan []byte, 1), + closeChan: make(chan struct{}), + } + + done := make(chan struct{}) + go func() { + session.writeLoop() + close(done) + }() + + session.writeToOrigin([]byte("payload")) + time.Sleep(50 * time.Millisecond) + + select { + case <-session.closeChan: + t.Fatal("expected session to remain open after deadline exceeded") + default: + } + + session.close() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected write loop to exit after manual close") + } +} diff --git a/protocol/cloudflare/direct_origin_test.go b/protocol/cloudflare/direct_origin_test.go index 85e1b9d8c7..1c62e230d8 100644 --- a/protocol/cloudflare/direct_origin_test.go +++ b/protocol/cloudflare/direct_origin_test.go @@ -14,6 +14,7 @@ import ( "time" boxTLS "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/log" ) func TestNewDirectOriginTransportUnix(t *testing.T) { @@ -118,3 +119,68 @@ func serveTestHTTPOverListener(listener net.Listener, handler func(http.Response server := &http.Server{Handler: http.HandlerFunc(handler)} _ = server.Serve(listener) } + +func TestDirectOriginTransportCacheReusesMatchingTransports(t *testing.T) { + inboundInstance := &Inbound{ + directTransports: make(map[string]*http.Transport), + } + service := ResolvedService{ + Kind: ResolvedServiceUnix, + UnixPath: "/tmp/test.sock", + BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, + } + + transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com") + if err != nil { + t.Fatal(err) + } + transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com") + if err != nil { + t.Fatal(err) + } + if transport1 != transport2 { + t.Fatal("expected matching direct-origin transports to be reused") + } + + transport3, _, err := inboundInstance.newDirectOriginTransport(service, "other.example.com") + if err != nil { + t.Fatal(err) + } + if transport3 == transport1 { + t.Fatal("expected different cache keys to produce different transports") + } +} + +func TestApplyConfigClearsDirectOriginTransportCache(t *testing.T) { + configManager, err := NewConfigManager() + if err != nil { + t.Fatal(err) + } + inboundInstance := &Inbound{ + logger: log.NewNOPFactory().NewLogger("test"), + configManager: configManager, + directTransports: make(map[string]*http.Transport), + } + service := ResolvedService{ + Kind: ResolvedServiceUnix, + UnixPath: "/tmp/test.sock", + BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, + } + + transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com") + if err != nil { + t.Fatal(err) + } + result := inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`)) + if result.Err != nil { + t.Fatal(result.Err) + } + + transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com") + if err != nil { + t.Fatal(err) + } + if transport1 == transport2 { + t.Fatal("expected ApplyConfig to clear direct-origin transport cache") + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 77afee784e..a2918c003c 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -5,6 +5,7 @@ package cloudflare import ( "context" "crypto/tls" + "encoding/json" "io" "net" "net/http" @@ -380,7 +381,6 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, return http.ErrUseLastResponse }, } - defer httpClient.CloseIdleConnections() response, err := httpClient.Do(httpRequest) if err != nil { @@ -444,6 +444,21 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter } func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) { + cacheKey, err := directOriginTransportKey(service, requestHost) + if err != nil { + return nil, nil, E.Cause(err, "marshal direct origin transport key") + } + + i.directTransportAccess.Lock() + if i.directTransports == nil { + i.directTransports = make(map[string]*http.Transport) + } + if transport, exists := i.directTransports[cacheKey]; exists { + i.directTransportAccess.Unlock() + return transport, func() {}, nil + } + i.directTransportAccess.Unlock() + dialer := &net.Dialer{ Timeout: service.OriginRequest.ConnectTimeout, KeepAlive: service.OriginRequest.TCPKeepAlive, @@ -473,9 +488,42 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost default: return nil, nil, E.New("unsupported direct origin service") } + + i.directTransportAccess.Lock() + if i.directTransports == nil { + i.directTransports = make(map[string]*http.Transport) + } + if cached, exists := i.directTransports[cacheKey]; exists { + i.directTransportAccess.Unlock() + transport.CloseIdleConnections() + return cached, func() {}, nil + } + i.directTransports[cacheKey] = transport + i.directTransportAccess.Unlock() return transport, func() {}, nil } +type directOriginTransportCacheKey struct { + Kind ResolvedServiceKind `json:"kind"` + UnixPath string `json:"unix_path,omitempty"` + RequestHost string `json:"request_host,omitempty"` + Origin OriginRequestConfig `json:"origin"` +} + +func directOriginTransportKey(service ResolvedService, requestHost string) (string, error) { + key := directOriginTransportCacheKey{ + Kind: service.Kind, + UnixPath: service.UnixPath, + RequestHost: effectiveOriginHost(service.OriginRequest, requestHost), + Origin: service.OriginRequest, + } + data, err := json.Marshal(key) + if err != nil { + return "", err + } + return string(data), nil +} + func effectiveOriginHost(originRequest OriginRequestConfig, requestHost string) string { if originRequest.HTTPHostHeader != "" { return originRequest.HTTPHostHeader diff --git a/protocol/cloudflare/flow_limiter_test.go b/protocol/cloudflare/flow_limiter_test.go index 199fdfe18c..628c7a3ed8 100644 --- a/protocol/cloudflare/flow_limiter_test.go +++ b/protocol/cloudflare/flow_limiter_test.go @@ -40,16 +40,23 @@ func newLimitedInbound(t *testing.T, limit uint64) *Inbound { if err != nil { t.Fatal(err) } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) config := configManager.Snapshot() config.WarpRouting.MaxActiveFlows = limit configManager.activeConfig = config return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - router: &testRouter{}, - logger: logFactory.NewLogger("test"), - configManager: configManager, - flowLimiter: &FlowLimiter{}, - datagramV3Manager: NewDatagramV3SessionManager(), + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + ctx: ctx, + cancel: cancel, + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + configManager: configManager, + flowLimiter: &FlowLimiter{}, + datagramV3Manager: NewDatagramV3SessionManager(), + connectionStates: make([]connectionState, 1), + successfulProtocols: make(map[string]struct{}), + directTransports: make(map[string]*http.Transport), } } diff --git a/protocol/cloudflare/helpers_test.go b/protocol/cloudflare/helpers_test.go index 64c8718cb4..2f595aecf2 100644 --- a/protocol/cloudflare/helpers_test.go +++ b/protocol/cloudflare/helpers_test.go @@ -186,27 +186,31 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i ctx, cancel := context.WithCancel(context.Background()) inboundInstance := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), - ctx: ctx, - cancel: cancel, - router: &testRouter{}, - logger: logFactory.NewLogger("test"), - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - edgeIPVersion: 0, - datagramVersion: "", - featureSelector: newFeatureSelector(ctx, credentials.AccountTag, ""), - gracePeriod: 5 * time.Second, - configManager: configManager, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), - datagramV3Manager: NewDatagramV3SessionManager(), - connectedIndices: make(map[uint8]struct{}), - connectedNotify: make(chan uint8, haConnections), - controlDialer: N.SystemDialer, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, + Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"), + ctx: ctx, + cancel: cancel, + router: &testRouter{}, + logger: logFactory.NewLogger("test"), + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + edgeIPVersion: 0, + datagramVersion: "", + featureSelector: newFeatureSelector(ctx, credentials.AccountTag, ""), + gracePeriod: 5 * time.Second, + configManager: configManager, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + datagramV3Manager: NewDatagramV3SessionManager(), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), + controlDialer: N.SystemDialer, + tunnelDialer: N.SystemDialer, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer}, + connectionStates: make([]connectionState, haConnections), + successfulProtocols: make(map[string]struct{}), + directTransports: make(map[string]*http.Transport), } t.Cleanup(func() { diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go index 088fd4159f..c58c373205 100644 --- a/protocol/cloudflare/icmp.go +++ b/protocol/cloudflare/icmp.go @@ -20,8 +20,10 @@ import ( const ( icmpFlowTimeout = 30 * time.Second icmpTraceIdentityLength = 16 + 8 + 1 - defaultICMPPacketTTL = 64 + defaultICMPPacketTTL = 255 icmpErrorHeaderLen = 8 + ipv4TTLExceededQuoteLen = 548 + ipv6TTLExceededQuoteLen = 1232 icmpv4TypeEchoRequest = 8 icmpv4TypeEchoReply = 0 @@ -154,7 +156,13 @@ const ( ) type icmpFlowState struct { - writer *ICMPReplyWriter + writer *ICMPReplyWriter + lastActive time.Time +} + +type traceEntry struct { + context ICMPTraceContext + createdAt time.Time } type ICMPReplyWriter struct { @@ -162,14 +170,14 @@ type ICMPReplyWriter struct { wireVersion icmpWireVersion access sync.Mutex - traces map[ICMPRequestKey]ICMPTraceContext + traces map[ICMPRequestKey]traceEntry } func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter { return &ICMPReplyWriter{ sender: sender, wireVersion: wireVersion, - traces: make(map[ICMPRequestKey]ICMPTraceContext), + traces: make(map[ICMPRequestKey]traceEntry), } } @@ -178,7 +186,10 @@ func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceC return } w.access.Lock() - w.traces[packetInfo.RequestKey()] = traceContext + w.traces[packetInfo.RequestKey()] = traceEntry{ + context: traceContext, + createdAt: time.Now(), + } w.access.Unlock() } @@ -193,11 +204,12 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error { requestKey := packetInfo.ReplyRequestKey() w.access.Lock() - traceContext, loaded := w.traces[requestKey] + entry, loaded := w.traces[requestKey] if loaded { delete(w.traces, requestKey) } w.access.Unlock() + traceContext := entry.context datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext) if err != nil { @@ -206,6 +218,16 @@ func (w *ICMPReplyWriter) WritePacket(packet []byte) error { return w.sender.SendDatagram(datagram) } +func (w *ICMPReplyWriter) cleanupExpired(now time.Time) { + w.access.Lock() + defer w.access.Unlock() + for key, entry := range w.traces { + if now.After(entry.createdAt.Add(icmpFlowTimeout)) { + delete(w.traces, key) + } + } +} + type ICMPBridge struct { inbound *Inbound sender DatagramSender @@ -217,13 +239,17 @@ type ICMPBridge struct { } func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge { - return &ICMPBridge{ + bridge := &ICMPBridge{ inbound: inbound, sender: sender, wireVersion: wireVersion, routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout), flows: make(map[ICMPFlowKey]*icmpFlowState), } + if inbound != nil && inbound.ctx != nil { + go bridge.cleanupLoop(inbound.ctx) + } + return bridge } func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error { @@ -256,7 +282,7 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont return nil } if packetInfo.TTLExpired() { - ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo, maxEncodedICMPPacketLen(b.wireVersion, traceContext)) + ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo) if err != nil { return err } @@ -272,6 +298,7 @@ func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceCont } state := b.getFlowState(packetInfo.FlowKey()) + state.lastActive = time.Now() if traceContext.Traced { state.writer.RegisterRequestTrace(packetInfo, traceContext) } @@ -311,6 +338,31 @@ func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState { return state } +func (b *ICMPBridge) cleanupLoop(ctx context.Context) { + ticker := time.NewTicker(icmpFlowTimeout) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + b.cleanupExpired(now) + } + } +} + +func (b *ICMPBridge) cleanupExpired(now time.Time) { + b.flowAccess.Lock() + defer b.flowAccess.Unlock() + for key, state := range b.flows { + state.writer.cleanupExpired(now) + if now.After(state.lastActive.Add(icmpFlowTimeout)) { + delete(b.flows, key) + } + } +} + func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) { if len(packet) < 1 { return ICMPPacketInfo{}, E.New("empty IP packet") @@ -408,27 +460,24 @@ func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTrace return limit } -func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { +func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) { switch packetInfo.IPVersion { case 4: - return buildIPv4ICMPTTLExceededPacket(packetInfo, maxPacketLen) + return buildIPv4ICMPTTLExceededPacket(packetInfo) case 6: - return buildIPv6ICMPTTLExceededPacket(packetInfo, maxPacketLen) + return buildIPv6ICMPTTLExceededPacket(packetInfo) default: return nil, E.New("unsupported IP version: ", packetInfo.IPVersion) } } -func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { +func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) { const headerLen = 20 if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() { return nil, E.New("TTL exceeded packet requires IPv4 addresses") } - if maxPacketLen <= headerLen+icmpErrorHeaderLen { - return nil, E.New("TTL exceeded packet size limit is too small") - } - quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + quotedLength := min(len(packetInfo.RawPacket), ipv4TTLExceededQuoteLen) packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) packet[0] = 0x45 binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet))) @@ -444,16 +493,13 @@ func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) return packet, nil } -func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo, maxPacketLen int) ([]byte, error) { +func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) { const headerLen = 40 if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() { return nil, E.New("TTL exceeded packet requires IPv6 addresses") } - if maxPacketLen <= headerLen+icmpErrorHeaderLen { - return nil, E.New("TTL exceeded packet size limit is too small") - } - quotedLength := min(len(packetInfo.RawPacket), maxPacketLen-headerLen-icmpErrorHeaderLen) + quotedLength := min(len(packetInfo.RawPacket), ipv6TTLExceededQuoteLen) packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength) packet[0] = 0x60 binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength)) @@ -509,14 +555,7 @@ func checksum(data []byte, initial uint32) uint16 { return ^uint16(sum) } -func encodeV2ICMPDatagram(packet []byte, traceContext ICMPTraceContext) ([]byte, error) { - if traceContext.Traced { - data := make([]byte, 0, len(packet)+len(traceContext.Identity)+1) - data = append(data, packet...) - data = append(data, traceContext.Identity...) - data = append(data, byte(DatagramV2TypeIPWithTrace)) - return data, nil - } +func encodeV2ICMPDatagram(packet []byte, _ ICMPTraceContext) ([]byte, error) { data := make([]byte, 0, len(packet)+1) data = append(data, packet...) data = append(data, byte(DatagramV2TypeIP)) diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go index aeecf5751a..05bc0b594f 100644 --- a/protocol/cloudflare/icmp_test.go +++ b/protocol/cloudflare/icmp_test.go @@ -131,12 +131,11 @@ func TestICMPBridgeHandleV2TracedReply(t *testing.T) { t.Fatalf("expected one reply datagram, got %d", len(sender.sent)) } reply := sender.sent[0] - if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) { - t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1]) + if reply[len(reply)-1] != byte(DatagramV2TypeIP) { + t.Fatalf("expected plain v2 IP reply, got type %d", reply[len(reply)-1]) } - gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1] - if !bytes.Equal(gotIdentity, traceIdentity) { - t.Fatalf("unexpected trace identity: %x", gotIdentity) + if len(reply) != len(buildEchoReply(buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)))+1 { + t.Fatalf("unexpected traced reply size: %d", len(reply)) } } @@ -257,14 +256,10 @@ func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) { t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent)) } reply := sender.sent[0] - if reply[len(reply)-1] != byte(DatagramV2TypeIPWithTrace) { - t.Fatalf("expected traced v2 reply, got type %d", reply[len(reply)-1]) + if reply[len(reply)-1] != byte(DatagramV2TypeIP) { + t.Fatalf("expected plain v2 reply, got type %d", reply[len(reply)-1]) } - gotIdentity := reply[len(reply)-1-icmpTraceIdentityLength : len(reply)-1] - if !bytes.Equal(gotIdentity, traceIdentity) { - t.Fatalf("unexpected trace identity: %x", gotIdentity) - } - rawReply := reply[:len(reply)-1-icmpTraceIdentityLength] + rawReply := reply[:len(reply)-1] packetInfo, err := ParseICMPPacket(rawReply) if err != nil { t.Fatal(err) @@ -275,6 +270,9 @@ func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) { if packetInfo.SourceIP != target || packetInfo.Destination != source { t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) } + if packetInfo.TTL() != 255 { + t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL()) + } } func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) { @@ -319,6 +317,9 @@ func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) { if packetInfo.SourceIP != target || packetInfo.Destination != source { t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination) } + if packetInfo.TTL() != 255 { + t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL()) + } } func TestICMPBridgeDropsNonEcho(t *testing.T) { @@ -348,6 +349,77 @@ func TestICMPBridgeDropsNonEcho(t *testing.T) { } } +func TestBuildICMPTTLExceededPacketUsesRFCQuoteLengths(t *testing.T) { + ipv4Packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1) + ipv4Packet = append(ipv4Packet, bytes.Repeat([]byte{0xaa}, 4096)...) + ipv4Info, err := ParseICMPPacket(ipv4Packet) + if err != nil { + t.Fatal(err) + } + ipv4Reply, err := buildICMPTTLExceededPacket(ipv4Info) + if err != nil { + t.Fatal(err) + } + if len(ipv4Reply) != 20+icmpErrorHeaderLen+ipv4TTLExceededQuoteLen { + t.Fatalf("unexpected IPv4 TTL exceeded size: %d", len(ipv4Reply)) + } + + ipv6Packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1) + ipv6Packet = append(ipv6Packet, bytes.Repeat([]byte{0xbb}, 4096)...) + ipv6Info, err := ParseICMPPacket(ipv6Packet) + if err != nil { + t.Fatal(err) + } + ipv6Reply, err := buildICMPTTLExceededPacket(ipv6Info) + if err != nil { + t.Fatal(err) + } + if len(ipv6Reply) != 40+icmpErrorHeaderLen+ipv6TTLExceededQuoteLen { + t.Fatalf("unexpected IPv6 TTL exceeded size: %d", len(ipv6Reply)) + } +} + +func TestICMPBridgeCleanupExpired(t *testing.T) { + bridge := NewICMPBridge(&Inbound{}, &captureDatagramSender{}, icmpWireV2) + now := time.Now() + + expiredKey := ICMPFlowKey{ + IPVersion: 4, + SourceIP: netip.MustParseAddr("198.18.0.2"), + Destination: netip.MustParseAddr("1.1.1.1"), + } + expiredState := bridge.getFlowState(expiredKey) + expiredState.lastActive = now.Add(-icmpFlowTimeout - time.Second) + expiredState.writer.traces[ICMPRequestKey{Flow: expiredKey, Identifier: 1, Sequence: 1}] = traceEntry{ + context: ICMPTraceContext{Traced: true, Identity: []byte{1}}, + createdAt: now.Add(-icmpFlowTimeout - time.Second), + } + + activeKey := ICMPFlowKey{ + IPVersion: 6, + SourceIP: netip.MustParseAddr("2001:db8::2"), + Destination: netip.MustParseAddr("2606:4700:4700::1111"), + } + activeState := bridge.getFlowState(activeKey) + activeState.lastActive = now + activeState.writer.traces[ICMPRequestKey{Flow: activeKey, Identifier: 2, Sequence: 2}] = traceEntry{ + context: ICMPTraceContext{Traced: true, Identity: []byte{2}}, + createdAt: now, + } + + bridge.cleanupExpired(now) + + if _, exists := bridge.flows[expiredKey]; exists { + t.Fatal("expected expired flow to be removed") + } + if _, exists := bridge.flows[activeKey]; !exists { + t.Fatal("expected active flow to remain") + } + if len(activeState.writer.traces) != 1 { + t.Fatalf("expected active trace to remain, got %d", len(activeState.writer.traces)) + } +} + func buildEchoReply(packet []byte) []byte { info, err := ParseICMPPacket(packet) if err != nil { diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index 674abef16b..f5ec07ec4a 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -8,6 +8,8 @@ import ( "errors" "io" "math/rand" + "net/http" + "runtime/debug" "sync" "time" @@ -30,6 +32,17 @@ func RegisterInbound(registry *inbound.Registry) { var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflared only supports remote-managed tunnels") +var ( + newQUICConnection = NewQUICConnection + newHTTP2Connection = NewHTTP2Connection + serveQUICConnection = func(connection *QUICConnection, ctx context.Context, handler StreamHandler) error { + return connection.Serve(ctx, handler) + } + serveHTTP2Connection = func(connection *HTTP2Connection, ctx context.Context) error { + return connection.Serve(ctx) + } +) + type Inbound struct { inbound.Adapter ctx context.Context @@ -63,6 +76,19 @@ type Inbound struct { connectedAccess sync.Mutex connectedIndices map[uint8]struct{} connectedNotify chan uint8 + + stateAccess sync.Mutex + connectionStates []connectionState + successfulProtocols map[string]struct{} + firstSuccessfulProtocol string + + directTransportAccess sync.Mutex + directTransports map[string]*http.Transport +} + +type connectionState struct { + protocol string + retries uint8 } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) { @@ -130,30 +156,33 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo inboundCtx, cancel := context.WithCancel(ctx) return &Inbound{ - Adapter: inbound.NewAdapter(C.TypeCloudflared, tag), - ctx: inboundCtx, - cancel: cancel, - router: router, - logger: logger, - credentials: credentials, - connectorID: uuid.New(), - haConnections: haConnections, - protocol: protocol, - region: region, - edgeIPVersion: edgeIPVersion, - datagramVersion: datagramVersion, - featureSelector: newFeatureSelector(inboundCtx, credentials.AccountTag, datagramVersion), - gracePeriod: gracePeriod, - configManager: configManager, - flowLimiter: &FlowLimiter{}, - accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, - controlDialer: controlDialer, - tunnelDialer: tunnelDialer, - datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), - datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), - datagramV3Manager: NewDatagramV3SessionManager(), - connectedIndices: make(map[uint8]struct{}), - connectedNotify: make(chan uint8, haConnections), + Adapter: inbound.NewAdapter(C.TypeCloudflared, tag), + ctx: inboundCtx, + cancel: cancel, + router: router, + logger: logger, + credentials: credentials, + connectorID: uuid.New(), + haConnections: haConnections, + protocol: protocol, + region: region, + edgeIPVersion: edgeIPVersion, + datagramVersion: datagramVersion, + featureSelector: newFeatureSelector(inboundCtx, credentials.AccountTag, datagramVersion), + gracePeriod: gracePeriod, + configManager: configManager, + flowLimiter: &FlowLimiter{}, + accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer}, + controlDialer: controlDialer, + tunnelDialer: tunnelDialer, + datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer), + datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer), + datagramV3Manager: NewDatagramV3SessionManager(), + connectedIndices: make(map[uint8]struct{}), + connectedNotify: make(chan uint8, haConnections), + connectionStates: make([]connectionState, haConnections), + successfulProtocols: make(map[string]struct{}), + directTransports: make(map[string]*http.Transport), }, nil } @@ -179,6 +208,7 @@ func (i *Inbound) Start(stage adapter.StartStage) error { } for connIndex := 0; connIndex < i.haConnections; connIndex++ { + i.initializeConnectionState(uint8(connIndex)) i.done.Add(1) go i.superviseConnection(uint8(connIndex), edgeAddrs) select { @@ -197,7 +227,24 @@ func (i *Inbound) Start(stage adapter.StartStage) error { return nil } -func (i *Inbound) notifyConnected(connIndex uint8) { +func (i *Inbound) notifyConnected(connIndex uint8, protocol string) { + i.stateAccess.Lock() + if i.successfulProtocols == nil { + i.successfulProtocols = make(map[string]struct{}) + } + i.ensureConnectionStateLocked(connIndex) + state := i.connectionStates[connIndex] + state.retries = 0 + state.protocol = protocol + i.connectionStates[connIndex] = state + if protocol != "" { + i.successfulProtocols[protocol] = struct{}{} + if i.firstSuccessfulProtocol == "" { + i.firstSuccessfulProtocol = protocol + } + } + i.stateAccess.Unlock() + if i.connectedNotify == nil { return } @@ -217,6 +264,7 @@ func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult { i.logger.Error("update ingress configuration: ", result.Err) return result } + i.resetDirectOriginTransports() i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")") return result } @@ -234,6 +282,7 @@ func (i *Inbound) Close() error { } i.connections = nil i.connectionAccess.Unlock() + i.resetDirectOriginTransports() return nil } @@ -247,7 +296,6 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { defer i.done.Done() edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs)) - retries := 0 for { select { case <-i.ctx.Done(): @@ -256,7 +304,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { } edgeAddr := edgeAddrs[edgeIndex] - err := i.serveConnection(connIndex, edgeAddr, uint8(retries)) + err := i.safeServeConnection(connIndex, edgeAddr) if err == nil || i.ctx.Err() != nil { return } @@ -266,9 +314,9 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { return } - retries++ + retries := i.incrementConnectionRetries(connIndex) edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs)) - backoff := backoffDuration(retries) + backoff := backoffDuration(int(retries)) var retryableErr *RetryableError if errors.As(err, &retryableErr) && retryableErr.Delay > 0 { backoff = retryableErr.Delay @@ -283,16 +331,10 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { } } -func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, numPreviousAttempts uint8) error { - protocol := i.protocol - // An empty protocol means the user configured "auto". For the token-provided, - // remotely-managed tunnel mode we implement here, that intentionally matches - // cloudflared's token path: start with QUIC and fall back to HTTP/2 on failure. - // If we ever support non-token/local-config modes, that is where remote - // percentage-based protocol selection should be introduced. - if protocol == "" { - protocol = "quic" - } +func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr) error { + state := i.connectionState(connIndex) + protocol := state.protocol + numPreviousAttempts := state.retries datagramVersion, features := i.currentConnectionFeatures() switch protocol { @@ -304,6 +346,13 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, numPrevio if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { return err } + if !i.protocolIsAuto() { + return err + } + if i.hasSuccessfulProtocol("quic") { + return err + } + i.setConnectionProtocol(connIndex, "http2") i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err) return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts) case "http2": @@ -313,14 +362,23 @@ func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr, numPrevio } } +func (i *Inbound) safeServeConnection(connIndex uint8, edgeAddr *EdgeAddr) (err error) { + defer func() { + if recovered := recover(); recovered != nil { + err = E.New("panic in serve connection: ", recovered, "\n", string(debug.Stack())) + } + }() + return i.serveConnection(connIndex, edgeAddr) +} + func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion string, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")") - connection, err := NewQUICConnection( + connection, err := newQUICConnection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, datagramVersion, features, numPreviousAttempts, i.gracePeriod, i.tunnelDialer, func() { - i.notifyConnected(connIndex) + i.notifyConnected(connIndex, "quic") }, i.logger, ) if err != nil { @@ -333,7 +391,7 @@ func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion i.RemoveDatagramMuxer(connection) }() - return connection.Serve(i.ctx, i) + return serveQUICConnection(connection, i.ctx, i) } func (i *Inbound) currentConnectionFeatures() (string, []string) { @@ -350,7 +408,7 @@ func (i *Inbound) currentConnectionFeatures() (string, []string) { func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error { i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")") - connection, err := NewHTTP2Connection( + connection, err := newHTTP2Connection( i.ctx, edgeAddr, connIndex, i.credentials, i.connectorID, features, numPreviousAttempts, i.gracePeriod, i, i.logger, @@ -362,7 +420,92 @@ func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []str i.trackConnection(connection) defer i.untrackConnection(connection) - return connection.Serve(i.ctx) + return serveHTTP2Connection(connection, i.ctx) +} + +func (i *Inbound) initializeConnectionState(connIndex uint8) { + i.stateAccess.Lock() + defer i.stateAccess.Unlock() + i.ensureConnectionStateLocked(connIndex) + if i.connectionStates[connIndex].protocol == "" { + i.connectionStates[connIndex].protocol = i.initialProtocolLocked() + } +} + +func (i *Inbound) connectionState(connIndex uint8) connectionState { + i.stateAccess.Lock() + defer i.stateAccess.Unlock() + i.ensureConnectionStateLocked(connIndex) + state := i.connectionStates[connIndex] + if state.protocol == "" { + state.protocol = i.initialProtocolLocked() + i.connectionStates[connIndex] = state + } + return state +} + +func (i *Inbound) incrementConnectionRetries(connIndex uint8) uint8 { + i.stateAccess.Lock() + defer i.stateAccess.Unlock() + i.ensureConnectionStateLocked(connIndex) + state := i.connectionStates[connIndex] + state.retries++ + i.connectionStates[connIndex] = state + return state.retries +} + +func (i *Inbound) setConnectionProtocol(connIndex uint8, protocol string) { + i.stateAccess.Lock() + defer i.stateAccess.Unlock() + i.ensureConnectionStateLocked(connIndex) + state := i.connectionStates[connIndex] + state.protocol = protocol + i.connectionStates[connIndex] = state +} + +func (i *Inbound) hasSuccessfulProtocol(protocol string) bool { + i.stateAccess.Lock() + defer i.stateAccess.Unlock() + if i.successfulProtocols == nil { + return false + } + _, ok := i.successfulProtocols[protocol] + return ok +} + +func (i *Inbound) protocolIsAuto() bool { + return i.protocol == "" +} + +func (i *Inbound) ensureConnectionStateLocked(connIndex uint8) { + requiredLen := int(connIndex) + 1 + if len(i.connectionStates) >= requiredLen { + return + } + grown := make([]connectionState, requiredLen) + copy(grown, i.connectionStates) + i.connectionStates = grown +} + +func (i *Inbound) initialProtocolLocked() string { + if i.protocol != "" { + return i.protocol + } + if i.firstSuccessfulProtocol != "" { + return i.firstSuccessfulProtocol + } + return "quic" +} + +func (i *Inbound) resetDirectOriginTransports() { + i.directTransportAccess.Lock() + transports := i.directTransports + i.directTransports = make(map[string]*http.Transport) + i.directTransportAccess.Unlock() + + for _, transport := range transports { + transport.CloseIdleConnections() + } } func (i *Inbound) trackConnection(connection io.Closer) { diff --git a/protocol/cloudflare/inbound_state_test.go b/protocol/cloudflare/inbound_state_test.go new file mode 100644 index 0000000000..fe045c00c1 --- /dev/null +++ b/protocol/cloudflare/inbound_state_test.go @@ -0,0 +1,155 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/sagernet/sing-box/log" + N "github.com/sagernet/sing/common/network" + + "github.com/google/uuid" +) + +func restoreConnectionHooks(t *testing.T) { + t.Helper() + + originalNewQUICConnection := newQUICConnection + originalNewHTTP2Connection := newHTTP2Connection + originalServeQUICConnection := serveQUICConnection + originalServeHTTP2Connection := serveHTTP2Connection + t.Cleanup(func() { + newQUICConnection = originalNewQUICConnection + newHTTP2Connection = originalNewHTTP2Connection + serveQUICConnection = originalServeQUICConnection + serveHTTP2Connection = originalServeHTTP2Connection + }) +} + +func TestServeConnectionAutoFallbackSticky(t *testing.T) { + restoreConnectionHooks(t) + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "" + inboundInstance.initializeConnectionState(0) + + var quicCalls, http2Calls int + newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) { + quicCalls++ + return &QUICConnection{}, nil + } + serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error { + return errors.New("quic failed") + } + newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) { + http2Calls++ + return &HTTP2Connection{}, nil + } + serveHTTP2Connection = func(*HTTP2Connection, context.Context) error { + return errors.New("http2 failed") + } + + if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" { + t.Fatalf("expected HTTP/2 fallback error, got %v", err) + } + if state := inboundInstance.connectionState(0); state.protocol != "http2" { + t.Fatalf("expected sticky HTTP/2 fallback, got %#v", state) + } + + if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" { + t.Fatalf("expected second HTTP/2 error, got %v", err) + } + if quicCalls != 1 { + t.Fatalf("expected QUIC to be attempted once, got %d", quicCalls) + } + if http2Calls != 2 { + t.Fatalf("expected HTTP/2 to be attempted twice, got %d", http2Calls) + } +} + +func TestSecondConnectionInitialProtocolUsesFirstSuccess(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "" + + inboundInstance.notifyConnected(0, "http2") + inboundInstance.initializeConnectionState(1) + + if state := inboundInstance.connectionState(1); state.protocol != "http2" { + t.Fatalf("expected second connection to inherit HTTP/2, got %#v", state) + } +} + +func TestServeConnectionSkipsFallbackWhenQUICAlreadySucceeded(t *testing.T) { + restoreConnectionHooks(t) + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "" + inboundInstance.notifyConnected(0, "quic") + inboundInstance.initializeConnectionState(1) + + var http2Calls int + quicErr := errors.New("quic failed") + newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) { + return &QUICConnection{}, nil + } + serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error { + return quicErr + } + newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) { + http2Calls++ + return &HTTP2Connection{}, nil + } + + err := inboundInstance.serveConnection(1, &EdgeAddr{}) + if !errors.Is(err, quicErr) { + t.Fatalf("expected QUIC error without fallback, got %v", err) + } + if http2Calls != 0 { + t.Fatalf("expected no HTTP/2 fallback, got %d calls", http2Calls) + } + if state := inboundInstance.connectionState(1); state.protocol != "quic" { + t.Fatalf("expected connection to remain on QUIC, got %#v", state) + } +} + +func TestNotifyConnectedResetsRetries(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "" + inboundInstance.initializeConnectionState(0) + inboundInstance.incrementConnectionRetries(0) + inboundInstance.incrementConnectionRetries(0) + + inboundInstance.notifyConnected(0, "http2") + + state := inboundInstance.connectionState(0) + if state.retries != 0 { + t.Fatalf("expected retries reset after success, got %d", state.retries) + } + if state.protocol != "http2" { + t.Fatalf("expected protocol to be pinned to success, got %q", state.protocol) + } +} + +func TestSafeServeConnectionRecoversPanic(t *testing.T) { + restoreConnectionHooks(t) + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "quic" + inboundInstance.initializeConnectionState(0) + + newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) { + return &QUICConnection{}, nil + } + serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error { + panic("boom") + } + + err := inboundInstance.safeServeConnection(0, &EdgeAddr{}) + if err == nil || !strings.Contains(err.Error(), "panic in serve connection") { + t.Fatalf("expected recovered panic error, got %v", err) + } +} diff --git a/protocol/cloudflare/rpc_stream_test.go b/protocol/cloudflare/rpc_stream_test.go new file mode 100644 index 0000000000..07576407cf --- /dev/null +++ b/protocol/cloudflare/rpc_stream_test.go @@ -0,0 +1,186 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" + N "github.com/sagernet/sing/common/network" + + "github.com/google/uuid" +) + +type blockingRPCStream struct { + closed chan struct{} +} + +func newBlockingRPCStream() *blockingRPCStream { + return &blockingRPCStream{closed: make(chan struct{})} +} + +func (s *blockingRPCStream) Read(_ []byte) (int, error) { + <-s.closed + return 0, io.EOF +} + +func (s *blockingRPCStream) Write(p []byte) (int, error) { + return len(p), nil +} + +func (s *blockingRPCStream) Close() error { + select { + case <-s.closed: + default: + close(s.closed) + } + return nil +} + +type blockingPacketDialRouter struct { + testRouter + entered chan struct{} + release chan struct{} +} + +func (r *blockingPacketDialRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) { + select { + case <-r.entered: + default: + close(r.entered) + } + + select { + case <-r.release: + return newBlockingPacketConn(), nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func newRPCInbound(t *testing.T, router adapter.Router) *Inbound { + t.Helper() + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.router = router + return inboundInstance +} + +func newRPCClientPair(t *testing.T, ctx context.Context) (tunnelrpc.CloudflaredServer, io.Closer, io.Closer, net.Conn, net.Conn) { + t.Helper() + + serverSide, clientSide := net.Pipe() + transport := safeTransport(clientSide) + clientConn := newRPCClientConn(transport, ctx) + client := tunnelrpc.CloudflaredServer{Client: clientConn.Bootstrap(ctx)} + return client, clientConn, transport, serverSide, clientSide +} + +func TestServeRPCStreamRespectsContextDeadline(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + stream := newBlockingRPCStream() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + ServeRPCStream(ctx, stream, inboundInstance, NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger), inboundInstance.logger) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected ServeRPCStream to exit after context deadline") + } +} + +func TestServeV3RPCStreamRespectsContextDeadline(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + stream := newBlockingRPCStream() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + ServeV3RPCStream(ctx, stream, inboundInstance, inboundInstance.logger) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected ServeV3RPCStream to exit after context deadline") + } +} + +func TestV2RPCAckAllowsConcurrentDispatch(t *testing.T) { + router := &blockingPacketDialRouter{ + entered: make(chan struct{}), + release: make(chan struct{}), + } + inboundInstance := newRPCInbound(t, router) + muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx) + defer clientConn.Close() + defer transport.Close() + defer clientSide.Close() + + done := make(chan struct{}) + go func() { + ServeRPCStream(ctx, serverSide, inboundInstance, muxer, inboundInstance.logger) + close(done) + }() + + registerPromise := client.RegisterUdpSession(ctx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error { + sessionID := uuid.New() + if err := p.SetSessionId(sessionID[:]); err != nil { + return err + } + if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil { + return err + } + p.SetDstPort(53) + p.SetCloseAfterIdleHint(int64(time.Second)) + return p.SetTraceContext("") + }) + + select { + case <-router.entered: + case <-time.After(time.Second): + t.Fatal("expected register RPC to enter the blocking dial") + } + + updateCtx, updateCancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer updateCancel() + updatePromise := client.UpdateConfiguration(updateCtx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error { + p.SetVersion(1) + return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`)) + }) + if _, err := updatePromise.Result().Struct(); err != nil { + t.Fatalf("expected concurrent update RPC to succeed, got %v", err) + } + + close(router.release) + if _, err := registerPromise.Result().Struct(); err != nil { + t.Fatalf("expected register RPC to complete, got %v", err) + } + + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected ServeRPCStream to exit") + } +} diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index 822dfea018..080d397b25 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -28,7 +28,9 @@ var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") const ( socksReplySuccess = 0 socksReplyRuleFailure = 2 + socksReplyNetworkUnreachable = 3 socksReplyHostUnreachable = 4 + socksReplyConnectionRefused = 5 socksReplyCommandNotSupported = 7 ) @@ -223,6 +225,19 @@ func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ip if _, err := io.ReadFull(conn, methods); err != nil { return err } + var supportsNoAuth bool + for _, method := range methods { + if method == 0 { + supportsNoAuth = true + break + } + } + if !supportsNoAuth { + if _, err := conn.Write([]byte{5, 255}); err != nil { + return err + } + return E.New("unknown authentication type") + } if _, err := conn.Write([]byte{5, 0}); err != nil { return err } @@ -254,7 +269,7 @@ func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ip } targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) if err != nil { - _ = writeSocksReply(conn, socksReplyHostUnreachable) + _ = writeSocksReply(conn, socksReplyForDialError(err)) return err } defer cleanup() @@ -270,6 +285,18 @@ func writeSocksReply(conn net.Conn, reply byte) error { return err } +func socksReplyForDialError(err error) byte { + lower := strings.ToLower(err.Error()) + switch { + case strings.Contains(lower, "refused"): + return socksReplyConnectionRefused + case strings.Contains(lower, "network is unreachable"): + return socksReplyNetworkUnreachable + default: + return socksReplyHostUnreachable + } +} + func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) { switch addressType { case 1: diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index afaa67a3d2..98b92a8397 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -4,6 +4,7 @@ package cloudflare import ( "context" + "errors" "io" "net" "net/http" @@ -182,6 +183,43 @@ func writeSocksConnectIPv4(t *testing.T, conn net.Conn, address string) []byte { return data } +func TestServeSocksProxyRejectsMissingNoAuth(t *testing.T) { + inboundInstance := newSpecialServiceInbound(t) + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + errCh := make(chan error, 1) + go func() { + errCh <- inboundInstance.serveSocksProxy(context.Background(), serverSide, nil) + }() + + if _, err := clientSide.Write([]byte{5, 1, 2}); err != nil { + t.Fatal(err) + } + response := make([]byte, 2) + if _, err := io.ReadFull(clientSide, response); err != nil { + t.Fatal(err) + } + if string(response) != string([]byte{5, 255}) { + t.Fatalf("unexpected auth rejection response: %v", response) + } + if err := <-errCh; err == nil { + t.Fatal("expected socks auth rejection error") + } +} + +func TestSocksReplyForDialError(t *testing.T) { + if reply := socksReplyForDialError(io.EOF); reply != socksReplyHostUnreachable { + t.Fatalf("expected host unreachable for generic error, got %d", reply) + } + if reply := socksReplyForDialError(errors.New("connection refused")); reply != 5 { + t.Fatalf("expected connection refused reply, got %d", reply) + } + if reply := socksReplyForDialError(errors.New("network is unreachable")); reply != 3 { + t.Fatalf("expected network unreachable reply, got %d", reply) + } +} + func TestHandleBastionStream(t *testing.T) { listener := startEchoListener(t) defer listener.Close() From d5a2fd7e95a159cbf5cf29265816541b0813ef9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 27 Mar 2026 19:11:49 +0800 Subject: [PATCH 41/41] Fix low-cost cloudflared parity gaps --- option/cloudflared.go | 18 ++--- protocol/cloudflare/config_decode_test.go | 34 +++++++++ protocol/cloudflare/connection_drain_test.go | 34 +++++++++ protocol/cloudflare/connection_http2.go | 8 +- .../connection_http2_behavior_test.go | 24 ++++++ protocol/cloudflare/connection_quic.go | 20 +++-- protocol/cloudflare/control.go | 26 ++++++- protocol/cloudflare/datagram_rpc_test.go | 34 ++++++++- protocol/cloudflare/datagram_rpc_v3.go | 2 + protocol/cloudflare/datagram_v2.go | 13 +++- protocol/cloudflare/direct_origin_test.go | 20 +++++ protocol/cloudflare/dispatch.go | 32 ++++---- protocol/cloudflare/edge_tls.go | 17 +++++ protocol/cloudflare/edge_tls_test.go | 31 ++++++++ protocol/cloudflare/icmp.go | 13 +++- protocol/cloudflare/icmp_test.go | 12 +++ protocol/cloudflare/inbound.go | 33 +++++++-- protocol/cloudflare/inbound_state_test.go | 73 +++++++++++++++++++ protocol/cloudflare/origin_dial.go | 22 +++++- protocol/cloudflare/origin_dial_test.go | 56 ++++++++++++++ protocol/cloudflare/origin_request_test.go | 29 ++++++++ protocol/cloudflare/rpc_stream_test.go | 65 +++++++++++++++++ 22 files changed, 565 insertions(+), 51 deletions(-) create mode 100644 protocol/cloudflare/edge_tls.go create mode 100644 protocol/cloudflare/edge_tls_test.go create mode 100644 protocol/cloudflare/origin_dial_test.go diff --git a/option/cloudflared.go b/option/cloudflared.go index 7daaafb9f4..ff50ce1532 100644 --- a/option/cloudflared.go +++ b/option/cloudflared.go @@ -3,13 +3,13 @@ package option import "github.com/sagernet/sing/common/json/badoption" type CloudflaredInboundOptions struct { - Token string `json:"token,omitempty"` - HAConnections int `json:"ha_connections,omitempty"` - Protocol string `json:"protocol,omitempty"` - ControlDialer DialerOptions `json:"control_dialer,omitempty"` - TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"` - EdgeIPVersion int `json:"edge_ip_version,omitempty"` - DatagramVersion string `json:"datagram_version,omitempty"` - GracePeriod badoption.Duration `json:"grace_period,omitempty"` - Region string `json:"region,omitempty"` + Token string `json:"token,omitempty"` + HAConnections int `json:"ha_connections,omitempty"` + Protocol string `json:"protocol,omitempty"` + ControlDialer DialerOptions `json:"control_dialer,omitempty"` + TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"` + EdgeIPVersion int `json:"edge_ip_version,omitempty"` + DatagramVersion string `json:"datagram_version,omitempty"` + GracePeriod *badoption.Duration `json:"grace_period,omitempty"` + Region string `json:"region,omitempty"` } diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go index df05d64ecd..b68fc9ca27 100644 --- a/protocol/cloudflare/config_decode_test.go +++ b/protocol/cloudflare/config_decode_test.go @@ -5,9 +5,11 @@ package cloudflare import ( "context" "testing" + "time" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" ) func TestNewInboundRequiresToken(t *testing.T) { @@ -36,3 +38,35 @@ func TestNormalizeProtocolAutoUsesTokenStyleSentinel(t *testing.T) { t.Fatalf("expected auto protocol to normalize to token-style empty sentinel, got %q", protocol) } } + +func TestResolveGracePeriodDefaultsToThirtySeconds(t *testing.T) { + if got := resolveGracePeriod(nil); got != 30*time.Second { + t.Fatalf("expected default grace period 30s, got %s", got) + } +} + +func TestResolveGracePeriodPreservesExplicitZero(t *testing.T) { + var options option.CloudflaredInboundOptions + if err := json.Unmarshal([]byte(`{"grace_period":"0s"}`), &options); err != nil { + t.Fatal(err) + } + if options.GracePeriod == nil { + t.Fatal("expected explicit grace period to be set") + } + if got := resolveGracePeriod(options.GracePeriod); got != 0 { + t.Fatalf("expected explicit zero grace period, got %s", got) + } +} + +func TestResolveGracePeriodPreservesNonZeroValue(t *testing.T) { + var options option.CloudflaredInboundOptions + if err := json.Unmarshal([]byte(`{"grace_period":"45s"}`), &options); err != nil { + t.Fatal(err) + } + if options.GracePeriod == nil { + t.Fatal("expected explicit grace period to be set") + } + if got := resolveGracePeriod(options.GracePeriod); got != 45*time.Second { + t.Fatalf("expected grace period 45s, got %s", got) + } +} diff --git a/protocol/cloudflare/connection_drain_test.go b/protocol/cloudflare/connection_drain_test.go index 756129502f..8911ba1955 100644 --- a/protocol/cloudflare/connection_drain_test.go +++ b/protocol/cloudflare/connection_drain_test.go @@ -232,3 +232,37 @@ func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) { t.Fatal("expected graceful shutdown to finish") } } + +func TestQUICGracefulShutdownStopsWaitingWhenServeContextEnds(t *testing.T) { + conn := newStubQUICConn() + registrationClient := newMockRegistrationClient() + serveCtx, cancelServe := context.WithCancel(context.Background()) + connection := &QUICConnection{ + conn: conn, + gracePeriod: time.Second, + registrationClient: registrationClient, + registrationResult: &RegistrationResult{}, + serveCtx: serveCtx, + serveCancel: func() {}, + } + + done := make(chan struct{}) + go func() { + connection.gracefulShutdown() + close(done) + }() + + select { + case <-registrationClient.unregisterCalled: + case <-time.After(time.Second): + t.Fatal("expected unregister call") + } + + cancelServe() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatal("expected graceful shutdown to stop waiting once serve context ends") + } +} diff --git a/protocol/cloudflare/connection_http2.go b/protocol/cloudflare/connection_http2.go index a8d4dae40e..8b00404f9d 100644 --- a/protocol/cloudflare/connection_http2.go +++ b/protocol/cloudflare/connection_http2.go @@ -83,10 +83,7 @@ func NewHTTP2Connection( return nil, E.Cause(err, "load Cloudflare root CAs") } - tlsConfig := &tls.Config{ - RootCAs: rootCAs, - ServerName: h2EdgeSNI, - } + tlsConfig := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil) tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port())) if err != nil { @@ -283,7 +280,8 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp err := json.NewDecoder(r.Body).Decode(&body) if err != nil { c.logger.Error("decode configuration update: ", err) - w.WriteHeader(http.StatusBadRequest) + w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared) + w.WriteHeader(http.StatusBadGateway) return } result := c.inbound.ApplyConfig(body.Version, body.Config) diff --git a/protocol/cloudflare/connection_http2_behavior_test.go b/protocol/cloudflare/connection_http2_behavior_test.go index 04477be338..4d875d1aac 100644 --- a/protocol/cloudflare/connection_http2_behavior_test.go +++ b/protocol/cloudflare/connection_http2_behavior_test.go @@ -3,6 +3,7 @@ package cloudflare import ( + "bytes" "io" "net/http" "testing" @@ -165,3 +166,26 @@ func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) { t.Fatalf("expected io.ErrClosedPipe, got %v", err) } } + +func TestHandleConfigurationUpdateDecodeFailureReturnsBadGateway(t *testing.T) { + writer := &captureHTTP2Writer{} + connection := &HTTP2Connection{ + logger: log.NewNOPFactory().NewLogger("test"), + } + request, err := http.NewRequest(http.MethodPost, "https://example.com", bytes.NewBufferString("{")) + if err != nil { + t.Fatal(err) + } + + connection.handleConfigurationUpdate(request, writer) + + if writer.statusCode != http.StatusBadGateway { + t.Fatalf("expected status %d, got %d", http.StatusBadGateway, writer.statusCode) + } + if meta := writer.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflared { + t.Fatalf("unexpected response meta: %q", meta) + } + if len(writer.body) != 0 { + t.Fatalf("expected empty response body, got %q", string(writer.body)) + } +} diff --git a/protocol/cloudflare/connection_quic.go b/protocol/cloudflare/connection_quic.go index 41afa0e90c..3f66eb632c 100644 --- a/protocol/cloudflare/connection_quic.go +++ b/protocol/cloudflare/connection_quic.go @@ -4,7 +4,6 @@ package cloudflare import ( "context" - "crypto/tls" "fmt" "io" "net" @@ -54,6 +53,7 @@ type QUICConnection struct { registrationResult *RegistrationResult onConnected func() + serveCtx context.Context serveCancel context.CancelFunc registrationClose sync.Once shutdownOnce sync.Once @@ -109,11 +109,7 @@ func NewQUICConnection( return nil, E.Cause(err, "load Cloudflare root CAs") } - tlsConfig := &tls.Config{ - RootCAs: rootCAs, - ServerName: quicEdgeSNI, - NextProtos: []string{quicEdgeALPN}, - } + tlsConfig := newEdgeTLSConfig(rootCAs, quicEdgeSNI, []string{quicEdgeALPN}) quicConfig := &quic.Config{ HandshakeIdleTimeout: quicHandshakeIdleTimeout, @@ -190,6 +186,7 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error " (connection ", q.registrationResult.ConnectionID, ")") serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx)) + q.serveCtx = serveCtx q.serveCancel = serveCancel errChan := make(chan error, 2) @@ -321,9 +318,16 @@ func (q *QUICConnection) gracefulShutdown() { } q.closeRegistrationClient() if q.gracePeriod > 0 { + waitCtx := q.serveCtx + if waitCtx == nil { + waitCtx = context.Background() + } timer := time.NewTimer(q.gracePeriod) - <-timer.C - timer.Stop() + defer timer.Stop() + select { + case <-timer.C: + case <-waitCtx.Done(): + } } q.closeNow("graceful shutdown") }) diff --git a/protocol/cloudflare/control.go b/protocol/cloudflare/control.go index a68648bb6f..7bdb2c0a4e 100644 --- a/protocol/cloudflare/control.go +++ b/protocol/cloudflare/control.go @@ -4,6 +4,7 @@ package cloudflare import ( "context" + "errors" "io" "net" "runtime" @@ -43,6 +44,29 @@ type registrationRPCClient interface { Close() error } +type permanentRegistrationError struct { + Err error +} + +func (e *permanentRegistrationError) Error() string { + if e == nil || e.Err == nil { + return "permanent registration error" + } + return e.Err.Error() +} + +func (e *permanentRegistrationError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func isPermanentRegistrationError(err error) bool { + var permanentErr *permanentRegistrationError + return errors.As(err, &permanentErr) +} + // NewRegistrationClient creates a Cap'n Proto RPC client over the given stream. // The stream should be the first QUIC stream (control stream). func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient { @@ -118,7 +142,7 @@ func (c *RegistrationClient) RegisterConnection( Delay: time.Duration(resultError.RetryAfter()), } } - return nil, registrationError + return nil, &permanentRegistrationError{Err: registrationError} case tunnelrpc.ConnectionResponse_result_Which_connectionDetails: connDetails, err := result.ConnectionDetails() diff --git a/protocol/cloudflare/datagram_rpc_test.go b/protocol/cloudflare/datagram_rpc_test.go index 33ab397903..f5e13e31ad 100644 --- a/protocol/cloudflare/datagram_rpc_test.go +++ b/protocol/cloudflare/datagram_rpc_test.go @@ -17,6 +17,10 @@ import ( ) func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) { + return newRegisterUDPSessionCallWithDstIP(t, []byte{127, 0, 0, 1}, traceContext) +} + +func newRegisterUDPSessionCallWithDstIP(t *testing.T, dstIP []byte, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) { t.Helper() _, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil)) @@ -31,7 +35,7 @@ func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.Ses if err := params.SetSessionId(sessionID[:]); err != nil { t.Fatal(err) } - if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil { + if err := params.SetDstIp(dstIP); err != nil { t.Fatal(err) } params.SetDstPort(53) @@ -197,3 +201,31 @@ func TestV2RPCUnregisterUDPSessionPropagatesMessage(t *testing.T) { t.Fatalf("expected close reason propagated from edge, got %q", reason) } } + +func TestV2RPCRegisterUDPSessionRejectsMissingDestinationIP(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()} + server := &cloudflaredServer{ + inbound: inboundInstance, + muxer: NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger), + ctx: context.Background(), + logger: inboundInstance.logger, + } + call, readResult := newRegisterUDPSessionCallWithDstIP(t, nil, "") + + if err := server.RegisterUdpSession(call); err != nil { + t.Fatal(err) + } + + result, err := readResult() + if err != nil { + t.Fatal(err) + } + resultErr, err := result.Err() + if err != nil { + t.Fatal(err) + } + if resultErr != "missing destination IP" { + t.Fatalf("unexpected result error %q", resultErr) + } +} diff --git a/protocol/cloudflare/datagram_rpc_v3.go b/protocol/cloudflare/datagram_rpc_v3.go index 071a7a88f1..73aa53ab74 100644 --- a/protocol/cloudflare/datagram_rpc_v3.go +++ b/protocol/cloudflare/datagram_rpc_v3.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc" E "github.com/sagernet/sing/common/exceptions" + "zombiezen.com/go/capnproto2/server" ) var ( @@ -38,6 +39,7 @@ func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager } func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error { + server.Ack(call.Options) version := call.Params.Version() configData, _ := call.Params.Config() updateResult := s.inbound.ApplyConfig(version, configData) diff --git a/protocol/cloudflare/datagram_v2.go b/protocol/cloudflare/datagram_v2.go index ae7a5e2d33..eaa1e61a2a 100644 --- a/protocol/cloudflare/datagram_v2.go +++ b/protocol/cloudflare/datagram_v2.go @@ -164,11 +164,16 @@ func (m *DatagramV2Muxer) RegisterSession( destinationPort uint16, closeAfterIdle time.Duration, ) error { + if destinationIP == nil { + return E.New("missing destination IP") + } var destinationAddr netip.Addr if ip4 := destinationIP.To4(); ip4 != nil { destinationAddr = netip.AddrFrom4([4]byte(ip4)) + } else if ip16 := destinationIP.To16(); ip16 != nil { + destinationAddr = netip.AddrFrom16([16]byte(ip16)) } else { - destinationAddr = netip.AddrFrom16([16]byte(destinationIP.To16())) + return E.New("invalid destination IP") } destination := netip.AddrPortFrom(destinationAddr, destinationPort) @@ -482,7 +487,11 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg return traceErr } - err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle) + if len(destinationIP) == 0 { + err = E.New("missing destination IP") + } else { + err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle) + } result, allocErr := call.Results.NewResult() if allocErr != nil { diff --git a/protocol/cloudflare/direct_origin_test.go b/protocol/cloudflare/direct_origin_test.go index 1c62e230d8..6b7786f44c 100644 --- a/protocol/cloudflare/direct_origin_test.go +++ b/protocol/cloudflare/direct_origin_test.go @@ -184,3 +184,23 @@ func TestApplyConfigClearsDirectOriginTransportCache(t *testing.T) { t.Fatal("expected ApplyConfig to clear direct-origin transport cache") } } + +func TestNewDirectOriginTransportUsesCloudflaredDefaults(t *testing.T) { + inboundInstance := &Inbound{} + transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{ + Kind: ResolvedServiceUnix, + UnixPath: "/tmp/test.sock", + BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, + }, "") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + if transport.ExpectContinueTimeout != time.Second { + t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout) + } + if transport.DisableCompression { + t.Fatal("expected compression to remain enabled by default") + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index a2918c003c..dc3bba85c0 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -428,14 +428,14 @@ func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{}) transport := &http.Transport{ - DisableCompression: true, - ForceAttemptHTTP2: originRequest.HTTP2Origin, - TLSHandshakeTimeout: originRequest.TLSTimeout, - IdleConnTimeout: originRequest.KeepAliveTimeout, - MaxIdleConns: originRequest.KeepAliveConnections, - MaxIdleConnsPerHost: originRequest.KeepAliveConnections, - Proxy: proxyFromEnvironment, - TLSClientConfig: tlsConfig, + ExpectContinueTimeout: time.Second, + ForceAttemptHTTP2: originRequest.HTTP2Origin, + TLSHandshakeTimeout: originRequest.TLSTimeout, + IdleConnTimeout: originRequest.KeepAliveTimeout, + MaxIdleConns: originRequest.KeepAliveConnections, + MaxIdleConnsPerHost: originRequest.KeepAliveConnections, + Proxy: proxyFromEnvironment, + TLSClientConfig: tlsConfig, DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return input, nil }, @@ -471,14 +471,14 @@ func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost return nil, nil, err } transport := &http.Transport{ - DisableCompression: true, - ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, - TLSHandshakeTimeout: service.OriginRequest.TLSTimeout, - IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, - MaxIdleConns: service.OriginRequest.KeepAliveConnections, - MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, - Proxy: proxyFromEnvironment, - TLSClientConfig: tlsConfig, + ExpectContinueTimeout: time.Second, + ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin, + TLSHandshakeTimeout: service.OriginRequest.TLSTimeout, + IdleConnTimeout: service.OriginRequest.KeepAliveTimeout, + MaxIdleConns: service.OriginRequest.KeepAliveConnections, + MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections, + Proxy: proxyFromEnvironment, + TLSClientConfig: tlsConfig, } switch service.Kind { case ResolvedServiceUnix, ResolvedServiceUnixTLS: diff --git a/protocol/cloudflare/edge_tls.go b/protocol/cloudflare/edge_tls.go new file mode 100644 index 0000000000..7381d53b42 --- /dev/null +++ b/protocol/cloudflare/edge_tls.go @@ -0,0 +1,17 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "crypto/tls" + "crypto/x509" +) + +func newEdgeTLSConfig(rootCAs *x509.CertPool, serverName string, nextProtos []string) *tls.Config { + return &tls.Config{ + RootCAs: rootCAs, + ServerName: serverName, + NextProtos: nextProtos, + CurvePreferences: []tls.CurveID{tls.CurveP256}, + } +} diff --git a/protocol/cloudflare/edge_tls_test.go b/protocol/cloudflare/edge_tls_test.go new file mode 100644 index 0000000000..759bd9cd0c --- /dev/null +++ b/protocol/cloudflare/edge_tls_test.go @@ -0,0 +1,31 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "crypto/tls" + "crypto/x509" + "testing" +) + +func TestNewEdgeTLSConfigUsesP256(t *testing.T) { + rootCAs := x509.NewCertPool() + config := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil) + + if config.RootCAs != rootCAs { + t.Fatal("expected root CA pool to be preserved") + } + if config.ServerName != h2EdgeSNI { + t.Fatalf("expected server name %q, got %q", h2EdgeSNI, config.ServerName) + } + if len(config.CurvePreferences) != 1 || config.CurvePreferences[0] != tls.CurveP256 { + t.Fatalf("unexpected curve preferences: %#v", config.CurvePreferences) + } +} + +func TestNewEdgeTLSConfigPreservesNextProtos(t *testing.T) { + config := newEdgeTLSConfig(x509.NewCertPool(), quicEdgeSNI, []string{quicEdgeALPN}) + if len(config.NextProtos) != 1 || config.NextProtos[0] != quicEdgeALPN { + t.Fatalf("unexpected next protos: %#v", config.NextProtos) + } +} diff --git a/protocol/cloudflare/icmp.go b/protocol/cloudflare/icmp.go index c58c373205..71b7cfe4c4 100644 --- a/protocol/cloudflare/icmp.go +++ b/protocol/cloudflare/icmp.go @@ -24,6 +24,7 @@ const ( icmpErrorHeaderLen = 8 ipv4TTLExceededQuoteLen = 548 ipv6TTLExceededQuoteLen = 1232 + maxICMPPayloadLen = 1280 icmpv4TypeEchoRequest = 8 icmpv4TypeEchoReply = 0 @@ -519,7 +520,7 @@ func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext case icmpWireV2: return encodeV2ICMPDatagram(packet, traceContext) case icmpWireV3: - return encodeV3ICMPDatagram(packet), nil + return encodeV3ICMPDatagram(packet) default: return nil, E.New("unsupported icmp wire version: ", wireVersion) } @@ -562,9 +563,15 @@ func encodeV2ICMPDatagram(packet []byte, _ ICMPTraceContext) ([]byte, error) { return data, nil } -func encodeV3ICMPDatagram(packet []byte) []byte { +func encodeV3ICMPDatagram(packet []byte) ([]byte, error) { + if len(packet) == 0 { + return nil, E.New("icmp payload is missing") + } + if len(packet) > maxICMPPayloadLen { + return nil, E.New("icmp payload is too large") + } data := make([]byte, 0, len(packet)+1) data = append(data, byte(DatagramV3TypeICMP)) data = append(data, packet...) - return data + return data, nil } diff --git a/protocol/cloudflare/icmp_test.go b/protocol/cloudflare/icmp_test.go index 05bc0b594f..df0b14677d 100644 --- a/protocol/cloudflare/icmp_test.go +++ b/protocol/cloudflare/icmp_test.go @@ -379,6 +379,18 @@ func TestBuildICMPTTLExceededPacketUsesRFCQuoteLengths(t *testing.T) { } } +func TestEncodeV3ICMPDatagramRejectsEmptyPayload(t *testing.T) { + if _, err := encodeV3ICMPDatagram(nil); err == nil { + t.Fatal("expected empty payload to be rejected") + } +} + +func TestEncodeV3ICMPDatagramRejectsOversizedPayload(t *testing.T) { + if _, err := encodeV3ICMPDatagram(make([]byte, maxICMPPayloadLen+1)); err == nil { + t.Fatal("expected oversized payload to be rejected") + } +} + func TestICMPBridgeCleanupExpired(t *testing.T) { bridge := NewICMPBridge(&Inbound{}, &captureDatagramSender{}, icmpWireV2) now := time.Now() diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index f5ec07ec4a..8f5e0b7c90 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -21,6 +21,7 @@ import ( "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badoption" N "github.com/sagernet/sing/common/network" "github.com/google/uuid" @@ -91,6 +92,26 @@ type connectionState struct { retries uint8 } +func resolveGracePeriod(value *badoption.Duration) time.Duration { + if value == nil { + return 30 * time.Second + } + return time.Duration(*value) +} + +func connectionRetryDecision(err error) (retry bool, cancelAll bool) { + switch { + case err == nil: + return false, false + case errors.Is(err, ErrNonRemoteManagedTunnelUnsupported): + return false, true + case isPermanentRegistrationError(err): + return false, false + default: + return true, false + } +} + func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) { if options.Token == "" { return nil, E.New("missing token") @@ -120,10 +141,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3") } - gracePeriod := time.Duration(options.GracePeriod) - if gracePeriod == 0 { - gracePeriod = 30 * time.Second - } + gracePeriod := resolveGracePeriod(options.GracePeriod) configManager, err := NewConfigManager() if err != nil { @@ -308,11 +326,16 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) { if err == nil || i.ctx.Err() != nil { return } - if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) { + retry, cancelAll := connectionRetryDecision(err) + if cancelAll { i.logger.Error("connection ", connIndex, " failed permanently: ", err) i.cancel() return } + if !retry { + i.logger.Error("connection ", connIndex, " failed permanently: ", err) + return + } retries := i.incrementConnectionRetries(connIndex) edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs)) diff --git a/protocol/cloudflare/inbound_state_test.go b/protocol/cloudflare/inbound_state_test.go index fe045c00c1..c397b81815 100644 --- a/protocol/cloudflare/inbound_state_test.go +++ b/protocol/cloudflare/inbound_state_test.go @@ -153,3 +153,76 @@ func TestSafeServeConnectionRecoversPanic(t *testing.T) { t.Fatalf("expected recovered panic error, got %v", err) } } + +func TestSuperviseConnectionStopsOnPermanentRegistrationError(t *testing.T) { + restoreConnectionHooks(t) + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "quic" + inboundInstance.initializeConnectionState(0) + + permanentErr := &permanentRegistrationError{Err: errors.New("permanent register error")} + newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) { + return &QUICConnection{}, nil + } + serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error { + return permanentErr + } + + inboundInstance.done.Add(1) + done := make(chan struct{}) + go func() { + inboundInstance.superviseConnection(0, []*EdgeAddr{{}}) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected connection supervision to stop") + } + + if retries := inboundInstance.connectionState(0).retries; retries != 0 { + t.Fatalf("expected no retries for permanent registration error, got %d", retries) + } + + select { + case <-inboundInstance.ctx.Done(): + t.Fatal("expected permanent registration error to stop only this connection") + default: + } +} + +func TestSuperviseConnectionCancelsInboundOnNonRemoteManagedError(t *testing.T) { + restoreConnectionHooks(t) + + inboundInstance := newLimitedInbound(t, 0) + inboundInstance.protocol = "quic" + inboundInstance.initializeConnectionState(0) + + newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) { + return &QUICConnection{}, nil + } + serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error { + return ErrNonRemoteManagedTunnelUnsupported + } + + inboundInstance.done.Add(1) + done := make(chan struct{}) + go func() { + inboundInstance.superviseConnection(0, []*EdgeAddr{{}}) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected connection supervision to stop") + } + + select { + case <-inboundInstance.ctx.Done(): + case <-time.After(time.Second): + t.Fatal("expected inbound cancellation on non-remote-managed tunnel error") + } +} diff --git a/protocol/cloudflare/origin_dial.go b/protocol/cloudflare/origin_dial.go index c937aa35b3..fcc5acefcc 100644 --- a/protocol/cloudflare/origin_dial.go +++ b/protocol/cloudflare/origin_dial.go @@ -5,13 +5,29 @@ package cloudflare import ( "context" "net/netip" + "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) +const originUDPWriteTimeout = 200 * time.Millisecond + +type udpWriteDeadlinePacketConn struct { + N.PacketConn +} + +func (c *udpWriteDeadlinePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + _ = c.PacketConn.SetWriteDeadline(time.Now().Add(originUDPWriteTimeout)) + defer func() { + _ = c.PacketConn.SetWriteDeadline(time.Time{}) + }() + return c.PacketConn.WritePacket(buffer, destination) +} + type routedOriginPacketDialer interface { DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) } @@ -29,11 +45,15 @@ func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination neti defer cancel() } - return originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{ + packetConn, err := originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{ Inbound: i.Tag(), InboundType: i.Type(), Network: N.NetworkUDP, Destination: M.SocksaddrFromNetIP(destination), UDPConnect: true, }) + if err != nil { + return nil, err + } + return &udpWriteDeadlinePacketConn{PacketConn: packetConn}, nil } diff --git a/protocol/cloudflare/origin_dial_test.go b/protocol/cloudflare/origin_dial_test.go new file mode 100644 index 0000000000..e5c979dbe2 --- /dev/null +++ b/protocol/cloudflare/origin_dial_test.go @@ -0,0 +1,56 @@ +//go:build with_cloudflared + +package cloudflare + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type captureDeadlinePacketConn struct { + err error + deadlines []time.Time +} + +func (c *captureDeadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + buffer.Release() + return M.Socksaddr{}, errors.New("unused") +} + +func (c *captureDeadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error { + buffer.Release() + return c.err +} + +func (c *captureDeadlinePacketConn) Close() error { return nil } +func (c *captureDeadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (c *captureDeadlinePacketConn) SetDeadline(time.Time) error { return nil } +func (c *captureDeadlinePacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *captureDeadlinePacketConn) SetWriteDeadline(t time.Time) error { + c.deadlines = append(c.deadlines, t) + return nil +} + +func TestDeadlinePacketConnWrapsWriteDeadline(t *testing.T) { + packetConn := &captureDeadlinePacketConn{} + wrapped := &udpWriteDeadlinePacketConn{PacketConn: packetConn} + + if err := wrapped.WritePacket(buf.As([]byte("payload")), M.Socksaddr{}); err != nil { + t.Fatal(err) + } + + if len(packetConn.deadlines) != 2 { + t.Fatalf("expected two deadline updates, got %d", len(packetConn.deadlines)) + } + if packetConn.deadlines[0].IsZero() { + t.Fatal("expected first deadline to set a timeout") + } + if !packetConn.deadlines[1].IsZero() { + t.Fatal("expected second deadline to clear the timeout") + } +} diff --git a/protocol/cloudflare/origin_request_test.go b/protocol/cloudflare/origin_request_test.go index a63c422363..efc94e5e41 100644 --- a/protocol/cloudflare/origin_request_test.go +++ b/protocol/cloudflare/origin_request_test.go @@ -12,6 +12,7 @@ import ( "encoding/pem" "io" "math/big" + "net" "net/http" "net/url" "os" @@ -20,8 +21,18 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + N "github.com/sagernet/sing/common/network" ) +type noopRouteConnectionRouter struct { + testRouter +} + +func (r *noopRouteConnectionRouter) RouteConnectionEx(_ context.Context, conn net.Conn, _ adapter.InboundContext, onClose N.CloseHandlerFunc) { + _ = conn.Close() + onClose(nil) +} + func TestOriginTLSServerName(t *testing.T) { t.Run("origin server name overrides host", func(t *testing.T) { serverName := originTLSServerName(OriginRequestConfig{ @@ -202,6 +213,24 @@ func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) { } } +func TestNewRouterOriginTransportUsesCloudflaredDefaults(t *testing.T) { + inbound := &Inbound{ + router: &noopRouteConnectionRouter{}, + } + transport, cleanup, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{}, "") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + if transport.ExpectContinueTimeout != time.Second { + t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout) + } + if transport.DisableCompression { + t.Fatal("expected compression to remain enabled by default") + } +} + func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody) if err != nil { diff --git a/protocol/cloudflare/rpc_stream_test.go b/protocol/cloudflare/rpc_stream_test.go index 07576407cf..e97eba7e91 100644 --- a/protocol/cloudflare/rpc_stream_test.go +++ b/protocol/cloudflare/rpc_stream_test.go @@ -184,3 +184,68 @@ func TestV2RPCAckAllowsConcurrentDispatch(t *testing.T) { t.Fatal("expected ServeRPCStream to exit") } } + +func TestV3RPCAckAllowsConcurrentDispatch(t *testing.T) { + inboundInstance := newLimitedInbound(t, 0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx) + defer clientConn.Close() + defer transport.Close() + defer clientSide.Close() + + done := make(chan struct{}) + go func() { + ServeV3RPCStream(ctx, serverSide, inboundInstance, inboundInstance.logger) + close(done) + }() + + inboundInstance.configManager.access.Lock() + updatePromise := client.UpdateConfiguration(ctx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error { + p.SetVersion(1) + return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`)) + }) + + time.Sleep(20 * time.Millisecond) + + registerCtx, registerCancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer registerCancel() + registerPromise := client.RegisterUdpSession(registerCtx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error { + sessionID := uuid.New() + if err := p.SetSessionId(sessionID[:]); err != nil { + return err + } + if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil { + return err + } + p.SetDstPort(53) + p.SetCloseAfterIdleHint(int64(time.Second)) + return p.SetTraceContext("") + }) + + registerResult, err := registerPromise.Result().Struct() + if err != nil { + t.Fatalf("expected concurrent v3 register RPC to succeed, got %v", err) + } + resultErr, err := registerResult.Err() + if err != nil { + t.Fatal(err) + } + if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() { + t.Fatalf("unexpected registration error %q", resultErr) + } + + inboundInstance.configManager.access.Unlock() + if _, err := updatePromise.Result().Struct(); err != nil { + t.Fatalf("expected update RPC to complete, got %v", err) + } + + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected ServeV3RPCStream to exit") + } +}