Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package zio.http.netty.client

import java.net.InetSocketAddress
import java.net.{Inet6Address, InetAddress, InetSocketAddress}
import java.util.concurrent.TimeUnit

import zio._
Expand All @@ -36,6 +36,8 @@ private[netty] trait NettyConnectionPool extends ConnectionPool[JChannel]

private[netty] object NettyConnectionPool {

private val HappyEyeballsDelay: Duration = 250.millis

protected def createChannel(
channelFactory: JChannelFactory[JChannel],
eventLoopGroup: JEventLoopGroup,
Expand Down Expand Up @@ -106,38 +108,150 @@ private[netty] object NettyConnectionPool {

for {
resolvedHosts <- dnsResolver.resolve(location.host)
hosts <- Random.shuffle(resolvedHosts.toList)
hostsNec <- ZIO.succeed(NonEmptyChunk.fromIterable(hosts.head, hosts.tail))
ch <- collectFirstSuccess(hostsNec) { host =>
ZIO.suspend {
val bootstrap = new Bootstrap()
.channelFactory(channelFactory)
.group(eventLoopGroup)
.remoteAddress(new InetSocketAddress(host, location.port))
.withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt))
.handler(initializer)
localAddress.foreach(bootstrap.localAddress)

val channelFuture = bootstrap.connect()
val ch = channelFuture.channel()
Scope.addFinalizer {
NettyFutureExecutor.executed {
channelFuture.cancel(true)
ch.close()
}.when(ch.isOpen).ignoreLogged
} *> NettyFutureExecutor.executed(channelFuture).as(ch)
}
}
ch <-
// Use Happy Eyeballs algorithm
happyEyeballsConnect(
resolvedHosts,
channelFactory,
eventLoopGroup,
location,
initializer,
connectionTimeout,
localAddress,
)
} yield ch
}

private def collectFirstSuccess[R, E, A, B](
as: NonEmptyChunk[A],
)(f: A => ZIO[R, E, B])(implicit trace: Trace): ZIO[R, E, B] = {
ZIO.suspendSucceed {
val it = as.iterator
def loop: ZIO[R, E, B] = f(it.next()).catchAll(e => if (it.hasNext) loop else ZIO.fail(e))
loop
/**
* Attempts to connect to a single address.
*/
private def connectToAddress(
host: InetAddress,
channelFactory: JChannelFactory[JChannel],
eventLoopGroup: JEventLoopGroup,
location: URL.Location.Absolute,
initializer: ChannelInitializer[JChannel],
connectionTimeout: Option[Duration],
localAddress: Option[InetSocketAddress],
)(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = {
ZIO.suspend {
val bootstrap = new Bootstrap()
.channelFactory(channelFactory)
.group(eventLoopGroup)
.remoteAddress(new InetSocketAddress(host, location.port))
.withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt))
.handler(initializer)
localAddress.foreach(bootstrap.localAddress)

val channelFuture = bootstrap.connect()
val ch = channelFuture.channel()
Scope.addFinalizer {
NettyFutureExecutor.executed {
channelFuture.cancel(true)
ch.close()
}.whenDiscard(ch.isOpen).ignoreLogged
}.uninterruptible *> NettyFutureExecutor.executed(channelFuture).as(ch)
}
}

/**
* Returns a sequence of connection attempts with their delays. Per RFC 8305,
* we start with IPv6, then after firstAddressFamilyDelay we try IPv4, then
* alternate between families.
*/
private def sortAddresses(resolvedHosts: Chunk[InetAddress]): List[InetAddress] = {
val (ipv6Addresses, ipv4Addresses) = resolvedHosts.partition(_.isInstanceOf[Inet6Address])
val ipv6Iter = ipv6Addresses.iterator
val ipv4Iter = ipv4Addresses.iterator
val builder = List.newBuilder[InetAddress]
builder.sizeHint(resolvedHosts.size)

// Alternate between families
var useIpv6 = true
while (ipv6Iter.hasNext || ipv4Iter.hasNext) {

if (useIpv6 && ipv6Iter.hasNext) {
builder += ipv6Iter.next()
} else if (ipv4Iter.hasNext) {
builder += ipv4Iter.next()
} else if (ipv6Iter.hasNext) {
builder += ipv6Iter.next()
}

useIpv6 = !useIpv6
}

builder.result()
}

/**
* Implements Happy Eyeballs (RFC 8305) connection algorithm. Races connection
* attempts to IPv6 and IPv4 addresses with staggered delays.
*/
private def happyEyeballsConnect(
resolvedHosts: Chunk[InetAddress],
channelFactory: JChannelFactory[JChannel],
eventLoopGroup: JEventLoopGroup,
location: URL.Location.Absolute,
initializer: ChannelInitializer[JChannel],
connectionTimeout: Option[Duration],
localAddress: Option[InetSocketAddress],
)(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = {

if (resolvedHosts.isEmpty) {
ZIO.fail(new RuntimeException("No addresses to connect to"))
} else if (resolvedHosts.size == 1) {
connectToAddress(
resolvedHosts.head,
channelFactory,
eventLoopGroup,
location,
initializer,
connectionTimeout,
localAddress,
)
} else {
val addresses = sortAddresses(resolvedHosts)
for {
lastFailed <- Queue.dropping[Unit](requestedCapacity = 1)
successful <- Ref.make(List.empty[JChannel])
_ <- ZIO.raceAll(
connectToAddress(
addresses.head,
channelFactory,
eventLoopGroup,
location,
initializer,
connectionTimeout,
localAddress,
).onExit {
case e: Exit.Success[JChannel] => successful.update(channels => channels :+ e.value)
case _: Exit.Failure[_] => lastFailed.offer(()).unit
},
addresses.tail.zipWithIndex.map { case (address, index) =>
ZIO.sleep(HappyEyeballsDelay * index.toDouble).raceFirst(lastFailed.take).ignore *>
connectToAddress(
address,
channelFactory,
eventLoopGroup,
location,
initializer,
connectionTimeout,
localAddress,
).onExit {
case e: Exit.Success[JChannel] => successful.update(channels => channels :+ e.value)
case _: Exit.Failure[_] => lastFailed.offer(()).unit
}
},
)
channels <- successful.get
channel <- channels.collectFirst { case c if c.isOpen => c } match {
case ch: Some[JChannel] => ZIO.succeed(ch.value)
case None => ZIO.fail(new RuntimeException("All connection attempts failed"))
}
_ <- ZIO.foreachDiscard(channels.tail)(ch => ZIO.ignore(ch.close()))
} yield channel

}
}

Expand Down
Loading