diff --git a/sip/transport_layer.go b/sip/transport_layer.go index ba3ad59..cea89e6 100644 --- a/sip/transport_layer.go +++ b/sip/transport_layer.go @@ -43,6 +43,10 @@ type TransportLayer struct { // dnsPreferSRV does always SRV lookup first dnsPreferSRV bool dnsPreferIP int // 0 - no preference , 1 -ip4, 2 - ip6 + + // forceLocalReplySocket forces using the connection that received the request + // for UDP responses, instead of trying to match by Via header + forceLocalReplySocket bool } type TransportLayerOption func(l *TransportLayer) @@ -67,6 +71,12 @@ func WithTransportLayerDNSLookupSRV(preferSRV bool) TransportLayerOption { } } +func WithTransportLayerForceLocalReplySocket(force bool) TransportLayerOption { + return func(l *TransportLayer) { + l.forceLocalReplySocket = force + } +} + // TODO will be exposed // withTransportLayerDNSLookupIP allows to set which ip4 or ip6 to prefer on resolve // default is ip4 @@ -465,6 +475,8 @@ func (l *TransportLayer) serverRequestConnection(ctx context.Context, req *Reque } sourceAddr := req.MessageData.Source() + + // For reliable transports, always use the connection that received the request if IsReliable(network) && sourceAddr != "" { // If the "sent-protocol" is a reliable transport protocol such as // TCP or SCTP, or TLS over those, the response MUST be sent using @@ -477,6 +489,16 @@ func (l *TransportLayer) serverRequestConnection(ctx context.Context, req *Reque } } + // For unreliable transports (UDP), if forceLocalReplySocket is enabled, + // always try to use the connection that received the request first + // The connection pool stores connections by remote address (source address) + if !IsReliable(network) && l.forceLocalReplySocket && sourceAddr != "" { + conn := transport.GetConnection(sourceAddr) + if conn != nil { + return conn, nil + } + } + viaHop := req.Via() if viaHop == nil { return nil, fmt.Errorf("no Via Header present") @@ -771,6 +793,12 @@ func (l *TransportLayer) allTransports() []transport { return []transport{l.udp, l.tcp, l.tls, l.ws, l.wss} } +// SetForceLocalReplySocket sets the forceLocalReplySocket flag on the transport layer. +// This allows setting the flag after the transport layer is created. +func (l *TransportLayer) SetForceLocalReplySocket(force bool) { + l.forceLocalReplySocket = force +} + func IsReliable(network string) bool { switch network { case "udp", "UDP": diff --git a/ua.go b/ua.go index b6b7aa0..30a1961 100644 --- a/ua.go +++ b/ua.go @@ -80,6 +80,18 @@ func WithUserAgentTransportLayerOptions(o ...sip.TransportLayerOption) UserAgent } } +// WithForceLocalReplySocket forces the transport layer to use the connection +// that received the request for UDP responses, instead of trying to match by Via header. +// This ensures responses are sent from the same socket that received the request. +// This applies to both server and client transactions. +func WithForceLocalReplySocket() UserAgentOption { + return func(ua *UserAgent) error { + // Add the transport layer option + ua.tpOptions = append(ua.tpOptions, sip.WithTransportLayerForceLocalReplySocket(true)) + return nil + } +} + // NewUA creates User Agent // User Agent will create transport and transaction layer // Check options for customizing user agent