Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ interface WebSocket {
* Returns the size in bytes of all messages enqueued to be transmitted to the server. This
* doesn't include framing overhead. It also doesn't include any bytes buffered by the operating
* system or network intermediaries. This method returns 0 if no messages are waiting
* in the queue. If may return a nonzero value after the web socket has been canceled; this
* in the queue. It may return a nonzero value after the web socket has been canceled; this
* indicates that enqueued messages were not transmitted.
*/
fun queueSize(): Long

/**
* Attempts to enqueue {@code text} to be UTF-8 encoded and sent as a the data of a text (type
* {@code 0x1}) message.
* Attempts to enqueue {@code bytes} to be sent as the data of a binary (type {@code 0x2})
* message.
*
* <p>This method returns true if the message was enqueued. Messages that would overflow the
* outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of
Expand All @@ -58,8 +58,8 @@ interface WebSocket {
fun send(bytes: ByteString): Boolean

/**
* Attempts to enqueue {@code bytes} to be sent as a the data of a binary (type {@code 0x2})
* message.
* Attempts to enqueue {@code text} to be UTF-8 encoded and sent as the data of a text (type
* {@code 0x1}) message.
*
* <p>This method returns true if the message was enqueued. Messages that would overflow the
* outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package misk.web

import okhttp3.WebSocket
import okio.ByteString
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.TimeUnit

class FakeWebSocketListener : okhttp3.WebSocketListener() {
val messages = LinkedBlockingDeque<String>()
val binaryMessages = LinkedBlockingDeque<ByteString>()

override fun onMessage(webSocket: okhttp3.WebSocket, text: String) {
messages.add(text)
}

override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
binaryMessages.add(bytes)
}

fun takeMessage() = messages.pollFirst(2, TimeUnit.SECONDS)
fun takeBinaryMessage() = binaryMessages.pollFirst(2, TimeUnit.SECONDS)
}
68 changes: 46 additions & 22 deletions misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@ internal class JettyWebSocket(
val response: JettyServerUpgradeResponse
) : WebSocket {

internal sealed interface Message {
val byteCount: Long
data class Text(val text: String) : Message {
override val byteCount = text.utf8Size()
}
data class Binary(val data: ByteString) : Message {
override val byteCount = data.size.toLong()
}
}

/** Total size of messages enqueued and not yet transmitted by Jetty. */
private var outgoingQueueSize = 0L

/** Messages to send when the Web Socket connects. */
private var queue = ArrayDeque<String>()
private var queue = ArrayDeque<Message>()

/** Application's listener to notify of incoming messages from the client. */
private var listener: WebSocketListener? = null
Expand Down Expand Up @@ -93,41 +103,55 @@ internal class JettyWebSocket(
}

override fun send(text: String): Boolean {
val byteCount = text.utf8Size()
return enqueue(Message.Text(text))
}

override fun send(bytes: ByteString): Boolean {
return enqueue(Message.Binary(bytes))
}

private fun enqueue(message: Message): Boolean {
val byteCount = message.byteCount
if (outgoingQueueSize + byteCount > MAX_QUEUE_SIZE) {
close(1001, null)
return false
}

outgoingQueueSize += byteCount
queue.add(text)
queue.add(message)
sendQueue()

return true
}

private fun sendQueue() {
while (adapter.isConnected && queue.isNotEmpty()) {
val text = queue.pop()
val byteCount = text.utf8Size()

adapter.remote.sendString(
text,
object : WriteCallback {
override fun writeSuccess() {
outgoingQueueSize -= byteCount
}

override fun writeFailed(x: Throwable?) {
outgoingQueueSize -= byteCount
}
val message = queue.pop()
val byteCount = message.byteCount
val callback = object : WriteCallback {
override fun writeSuccess() {
outgoingQueueSize -= byteCount
}
)
}
}

override fun send(bytes: ByteString): Boolean {
TODO()
override fun writeFailed(x: Throwable?) {
outgoingQueueSize -= byteCount
}
}
when (message) {
is Message.Text -> {
adapter.remote.sendString(
message.text,
callback
)
}
is Message.Binary -> {
adapter.remote.sendBytes(
message.data.asByteBuffer(),
callback
)
}
}
}
}

override fun close(code: Int, reason: String?): Boolean {
Expand Down Expand Up @@ -167,7 +191,7 @@ internal class JettyWebSocket(
it.match(DispatchMechanism.WEBSOCKET, null, listOf(), httpCall.url)
}

val bestAction = candidateActions.sorted().firstOrNull() ?: return null
val bestAction = candidateActions.minOrNull() ?: return null
bestAction.action.scopeAndHandle(request.httpServletRequest, httpCall, bestAction.pathMatcher)
return realWebSocket.adapter
}
Expand Down
37 changes: 37 additions & 0 deletions misk/src/test/kotlin/misk/web/WebSocketsTest.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package misk.web

import com.squareup.protos.test.parsing.Warehouse
import misk.MiskTestingServiceModule
import misk.inject.KAbstractModule
import misk.logging.LogCollectorModule
Expand All @@ -19,6 +20,7 @@ import org.junit.jupiter.api.Test
import misk.logging.LogCollector
import jakarta.inject.Inject
import jakarta.inject.Singleton
import okio.ByteString

@MiskTest(startService = true)
internal class WebSocketsTest {
Expand Down Expand Up @@ -50,6 +52,30 @@ internal class WebSocketsTest {
)
}

@Test
fun binaryWebSocket() {
val client = OkHttpClient()

val request = Request.Builder()
.url(jettyService.httpServerUrl.resolve("/echo")!!)
.build()

val webSocket = client.newWebSocket(request, listener)
val warehouse = Warehouse.Builder()
.warehouse_token("WH_1")
.warehouse_id(42)
.build()

webSocket.send(warehouse.encodeByteString())

val expected = Warehouse.Builder()
.warehouse_token("ACK WH_1")
.warehouse_id(43)
.build()
val actual = listener.takeBinaryMessage()?.let { Warehouse.ADAPTER.decode(it) }
assertEquals(actual, expected)
}

@Test
fun loggingDisabledByEnv() {
val client = OkHttpClient()
Expand Down Expand Up @@ -87,6 +113,17 @@ class EchoWebSocket @Inject constructor() : WebAction {
webSocket.send("ACK $text")
}

override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
val message = Warehouse.ADAPTER.decode(bytes)
webSocket.send(
message.newBuilder()
.warehouse_token("ACK ${message.warehouse_token}")
.warehouse_id(message.warehouse_id + 1)
.build()
.encodeByteString()
)
}

override fun toString() = "EchoListener"
}
}
Expand Down