@@ -21,6 +21,8 @@ import io.reactivex.Single
21
21
import io.rsocket.kotlin.interceptors.GlobalInterceptors
22
22
import io.rsocket.kotlin.internal.*
23
23
import io.rsocket.kotlin.internal.fragmentation.FragmentationInterceptor
24
+ import io.rsocket.kotlin.internal.lease.ClientLeaseSupport
25
+ import io.rsocket.kotlin.internal.lease.ServerLeaseSupport
24
26
import io.rsocket.kotlin.transport.ClientTransport
25
27
import io.rsocket.kotlin.transport.ServerTransport
26
28
import io.rsocket.kotlin.util.AbstractRSocket
@@ -47,6 +49,7 @@ object RSocketFactory {
47
49
private var acceptor: ClientAcceptor = { { emptyRSocket } }
48
50
private var errorConsumer: (Throwable ) -> Unit = { it.printStackTrace() }
49
51
private var mtu = 0
52
+ private var leaseRefConsumer: ((LeaseRef ) -> Unit )? = null
50
53
private val interceptors = GlobalInterceptors .create()
51
54
private var flags = 0
52
55
private var setupPayload: Payload = DefaultPayload .EMPTY
@@ -76,6 +79,12 @@ object RSocketFactory {
76
79
return this
77
80
}
78
81
82
+ fun enableLease (leaseRefConsumer : (LeaseRef ) -> Unit ): ClientRSocketFactory {
83
+ this .flags = Frame .Setup .enableLease(flags)
84
+ this .leaseRefConsumer = leaseRefConsumer
85
+ return this
86
+ }
87
+
79
88
fun errorConsumer (errorConsumer : (Throwable ) -> Unit ): ClientRSocketFactory {
80
89
this .errorConsumer = errorConsumer
81
90
return this
@@ -93,49 +102,78 @@ object RSocketFactory {
93
102
}
94
103
95
104
fun transport (transport : () -> ClientTransport ): Start <RSocket > =
96
- ClientStart (transport, interceptors())
105
+ clientStart(acceptor, transport)
106
+
107
+ fun transport (transport : ClientTransport ): Start <RSocket > =
108
+ transport { transport }
97
109
98
110
fun acceptor (acceptor : ClientAcceptor ): ClientTransportAcceptor {
99
111
this .acceptor = acceptor
100
112
return object : ClientTransportAcceptor {
101
113
override fun transport (transport : () -> ClientTransport )
102
- : Start <RSocket > =
103
- ClientStart (transport, interceptors())
114
+ : Start <RSocket > = clientStart(acceptor, transport)
104
115
}
105
116
}
106
117
107
- private fun interceptors (): InterceptorRegistry =
108
- interceptors.copyWith {
109
- it.connectionFirst(
110
- FragmentationInterceptor (mtu))
111
- }
118
+ private fun clientStart (acceptor : ClientAcceptor ,
119
+ transport : () -> ClientTransport ): ClientStart =
112
120
113
- private inner class ClientStart (private val transportClient : () -> ClientTransport ,
114
- private val interceptors : InterceptorRegistry )
121
+ ClientStart (acceptor,
122
+ errorConsumer,
123
+ mtu,
124
+ leaseRefConsumer,
125
+ flags,
126
+ setupPayload,
127
+ keepAlive.copy(),
128
+ mediaType.copy(),
129
+ streamRequestLimit,
130
+ transport,
131
+ interceptors.copy())
132
+
133
+ private class ClientStart (
134
+ private val acceptor : ClientAcceptor ,
135
+ private val errorConsumer : (Throwable ) -> Unit ,
136
+ private var mtu : Int ,
137
+ private val leaseRef : ((LeaseRef ) -> Unit )? ,
138
+ private val flags : Int ,
139
+ private val setupPayload : Payload ,
140
+ private val keepAlive : KeepAlive ,
141
+ private val mediaType : MediaType ,
142
+ private val streamRequestLimit : Int ,
143
+ private val transportClient : () -> ClientTransport ,
144
+ private val parentInterceptors : InterceptorRegistry )
115
145
: Start <RSocket > {
116
146
117
147
override fun start (): Single <RSocket > {
118
148
return transportClient()
119
149
.connect()
120
150
.flatMap { connection ->
121
- val setupFrame = createSetupFrame()
151
+
152
+ val withLease =
153
+ enableLease(parentInterceptors)
154
+
155
+ val interceptors =
156
+ enableFragmentation(withLease)
157
+
158
+ val interceptConnection = interceptors as InterceptConnection
159
+ val interceptRSocket = interceptors as InterceptRSocket
122
160
123
161
val demuxer = ClientConnectionDemuxer (
124
162
connection,
125
- interceptors )
163
+ interceptConnection )
126
164
127
165
val rSocketRequester = RSocketRequester (
128
166
demuxer.requesterConnection(),
129
167
errorConsumer,
130
168
ClientStreamIds (),
131
169
streamRequestLimit)
132
170
133
- val wrappedRequester = interceptors
171
+ val wrappedRequester = interceptRSocket
134
172
.interceptRequester(rSocketRequester)
135
173
136
174
val handlerRSocket = acceptor()(wrappedRequester)
137
175
138
- val wrappedHandler = interceptors
176
+ val wrappedHandler = interceptRSocket
139
177
.interceptHandler(handlerRSocket)
140
178
141
179
RSocketResponder (
@@ -149,12 +187,21 @@ object RSocketFactory {
149
187
keepAlive,
150
188
errorConsumer)
151
189
190
+ val setupFrame = createSetupFrame()
191
+
152
192
connection
153
193
.sendOne(setupFrame)
154
194
.andThen(Single .just(wrappedRequester))
155
195
}
156
196
}
157
197
198
+ private fun enableFragmentation (parentInterceptors : InterceptorRegistry )
199
+ : InterceptorRegistry {
200
+ parentInterceptors.connectionFirst(
201
+ FragmentationInterceptor (mtu))
202
+ return parentInterceptors
203
+ }
204
+
158
205
private fun createSetupFrame (): Frame {
159
206
return Frame .Setup .from(
160
207
flags,
@@ -164,14 +211,23 @@ object RSocketFactory {
164
211
mediaType.dataMimeType(),
165
212
setupPayload)
166
213
}
214
+
215
+ private fun enableLease (parentInterceptors : InterceptorRegistry )
216
+ : InterceptorRegistry =
217
+ if (leaseRef != null ) {
218
+ parentInterceptors.copyWith(
219
+ ClientLeaseSupport .enable(leaseRef)())
220
+ } else {
221
+ parentInterceptors.copy()
222
+ }
167
223
}
168
224
}
169
225
170
226
class ServerRSocketFactory internal constructor() {
171
227
172
- private var acceptor: ServerAcceptor = { { _, _ -> Single .just(emptyRSocket) } }
173
228
private var errorConsumer: (Throwable ) -> Unit = { it.printStackTrace() }
174
229
private var mtu = 0
230
+ private var leaseRefConsumer: ((LeaseRef ) -> Unit )? = null
175
231
private val interceptors = GlobalInterceptors .create()
176
232
private var streamRequestLimit = defaultStreamRequestLimit
177
233
@@ -186,6 +242,11 @@ object RSocketFactory {
186
242
return this
187
243
}
188
244
245
+ fun enableLease (leaseRefConsumer : (LeaseRef ) -> Unit ): ServerRSocketFactory {
246
+ this .leaseRefConsumer = leaseRefConsumer
247
+ return this
248
+ }
249
+
189
250
fun errorConsumer (errorConsumer : (Throwable ) -> Unit ): ServerRSocketFactory {
190
251
this .errorConsumer = errorConsumer
191
252
return this
@@ -197,26 +258,28 @@ object RSocketFactory {
197
258
}
198
259
199
260
fun acceptor (acceptor : ServerAcceptor ): ServerTransportAcceptor {
200
- this .acceptor = acceptor
201
261
return object : ServerTransportAcceptor {
262
+
202
263
override fun <T : Closeable > transport (
203
264
transport : () -> ServerTransport <T >): Start <T > =
204
- ServerStart (transport, interceptors())
205
- }
206
- }
207
-
208
- private fun interceptors (): InterceptorRegistry {
209
- return interceptors.copyWith {
210
- it.connectionFirst(
211
- ServerContractInterceptor (errorConsumer))
212
- it.connectionFirst(
213
- FragmentationInterceptor (mtu))
265
+ ServerStart (transport,
266
+ acceptor,
267
+ errorConsumer,
268
+ mtu,
269
+ leaseRefConsumer,
270
+ interceptors.copy(),
271
+ streamRequestLimit)
214
272
}
215
273
}
216
274
217
- private inner class ServerStart <T : Closeable >(
275
+ private class ServerStart <T : Closeable >(
218
276
private val transportServer : () -> ServerTransport <T >,
219
- private val interceptors : InterceptorRegistry ) : Start<T> {
277
+ private val acceptor : ServerAcceptor ,
278
+ private val errorConsumer : (Throwable ) -> Unit ,
279
+ private val mtu : Int ,
280
+ private val leaseRef : ((LeaseRef ) -> Unit )? ,
281
+ private val parentInterceptors : InterceptorRegistry ,
282
+ private val streamRequestLimit : Int ) : Start<T> {
220
283
221
284
override fun start (): Single <T > {
222
285
return transportServer().start(object
@@ -225,25 +288,37 @@ object RSocketFactory {
225
288
override fun invoke (duplexConnection : DuplexConnection )
226
289
: Completable {
227
290
291
+ val withLease =
292
+ enableLease(parentInterceptors)
293
+
294
+ val withServerContract =
295
+ enableServerContract(withLease)
296
+
297
+ val interceptors =
298
+ enableFragmentation(withServerContract)
299
+
228
300
val demuxer = ServerConnectionDemuxer (
229
301
duplexConnection,
230
- interceptors)
302
+ interceptors as InterceptConnection )
231
303
232
304
return demuxer
233
305
.setupConnection()
234
306
.receive()
235
307
.firstOrError()
236
308
.flatMapCompletable { setup ->
237
- accept(setup, demuxer)
309
+ accept(setup,
310
+ interceptors as InterceptRSocket ,
311
+ demuxer)
238
312
}
239
313
}
240
314
})
241
315
}
242
316
243
317
private fun accept (setupFrame : Frame ,
318
+ interceptors : InterceptRSocket ,
244
319
demuxer : ConnectionDemuxer ): Completable {
245
320
246
- val setup = Setup .create(setupFrame)
321
+ val setup = SetupContents .create(setupFrame)
247
322
248
323
val rSocketRequester = RSocketRequester (
249
324
demuxer.requesterConnection(),
@@ -272,6 +347,30 @@ object RSocketFactory {
272
347
}
273
348
.ignoreElement()
274
349
}
350
+
351
+ private fun enableLease (parentInterceptors : InterceptorRegistry )
352
+ : InterceptorRegistry =
353
+ if (leaseRef != null ) {
354
+ parentInterceptors.copyWith(
355
+ ServerLeaseSupport .enable(leaseRef)())
356
+ } else {
357
+ parentInterceptors.copy()
358
+ }
359
+
360
+ private fun enableServerContract (parentInterceptors : InterceptorRegistry )
361
+ : InterceptorRegistry {
362
+
363
+ parentInterceptors.connectionFirst(
364
+ ServerContractInterceptor (errorConsumer, leaseRef != null ))
365
+ return parentInterceptors
366
+ }
367
+
368
+ private fun enableFragmentation (parentInterceptors : InterceptorRegistry )
369
+ : InterceptorRegistry {
370
+ parentInterceptors.connectionFirst(
371
+ FragmentationInterceptor (mtu))
372
+ return parentInterceptors
373
+ }
275
374
}
276
375
}
277
376
0 commit comments