Skip to content

Commit 6cfec1e

Browse files
committed
wip
1 parent 0365478 commit 6cfec1e

File tree

1 file changed

+51
-47
lines changed

1 file changed

+51
-47
lines changed

src/Compiler/Utilities/Async2.fs

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type IAsync2Invocation<'t> =
1717
and Async2<'t> =
1818
abstract StartImmediate: CancellationToken -> IAsync2Invocation<'t>
1919
abstract StartBound: CancellationToken -> TaskAwaiter<'t>
20-
abstract TailCall: CancellationToken * TaskCompletionSource<'t> voption -> unit
20+
abstract TailCall: CancellationToken * TaskCompletionSource<'t> -> unit
2121
abstract GetAwaiter: unit -> TaskAwaiter<'t>
2222

2323
module Async2Implementation =
@@ -50,32 +50,35 @@ module Async2Implementation =
5050
| Bounce of DynamicState
5151
| Immediate of DynamicState
5252

53-
module BindContext =
54-
let bindCount = new ThreadLocal<int>()
55-
56-
[<Literal>]
57-
let bindLimit = 100
58-
59-
let IncrementBindCount () =
60-
bindCount.Value <- bindCount.Value + 1
61-
bindCount.Value >= bindLimit
62-
63-
let Reset () = bindCount.Value <- 0
64-
6553
type Trampoline private () =
6654

6755
let ownerThreadId = Thread.CurrentThread.ManagedThreadId
6856

6957
static let holder = new ThreadLocal<_>(fun () -> Trampoline())
7058

59+
let mutable depth = 0
60+
61+
[<Literal>]
62+
let MaxDepth = 50
63+
64+
let insufficientStack () =
65+
depth <- depth + 1
66+
depth % MaxDepth = 0
67+
// if current.Value % MaxDepth = 0 then
68+
//#if NETSTANDARD2_0
69+
// try RuntimeHelpers.EnsureSufficientExecutionStack(); true with _ -> false
70+
//#else
71+
// RuntimeHelpers.TryEnsureSufficientExecutionStack()
72+
//#endif
73+
// else
74+
// true
75+
7176
let mutable pending: Action voption = ValueNone
7277
let mutable running = false
7378

74-
let start (action: Action) =
79+
let start () =
7580
try
76-
BindContext.Reset()
7781
running <- true
78-
action.Invoke()
7982

8083
while pending.IsSome do
8184
let next = pending.Value
@@ -88,19 +91,25 @@ module Async2Implementation =
8891
assert (Thread.CurrentThread.ManagedThreadId = ownerThreadId) // "Trampoline used from wrong thread"
8992
assert pending.IsNone // "Trampoline set while already pending"
9093

91-
BindContext.Reset()
94+
pending <- ValueSome action
9295

93-
if running then
94-
pending <- ValueSome action
95-
else
96-
start action
96+
if not running then
97+
start ()
9798

9899
interface ICriticalNotifyCompletion with
99100
member _.OnCompleted continuation = set continuation
100101
member _.UnsafeOnCompleted continuation = set continuation
101102

102103
member this.Ref: ICriticalNotifyCompletion ref = ref this
103104

105+
member _.Running = running
106+
107+
member _.IsStackSufficient() =
108+
depth <- depth + 1
109+
depth % MaxDepth <> 0
110+
111+
member _.ShouldBounce = pending.IsNone && insufficientStack ()
112+
104113
static member Current = holder.Value
105114

106115
module ExceptionCache =
@@ -139,32 +148,28 @@ module Async2Implementation =
139148
val mutable MethodBuilder: AsyncTaskMethodBuilder<'t>
140149

141150
[<DefaultValue(false)>]
142-
val mutable TailCallSource: TaskCompletionSource<'t> voption
151+
val mutable TailCallSource: TaskCompletionSource<'t> option
143152

144153
[<DefaultValue(false)>]
145154
val mutable CancellationToken: CancellationToken
146155

147-
[<DefaultValue(false)>]
148-
val mutable IsBound: bool
149-
150156
type Async2StateMachine<'TOverall> = ResumableStateMachine<Async2Data<'TOverall>>
151157
type IAsync2StateMachine<'TOverall> = IResumableStateMachine<Async2Data<'TOverall>>
152158
type Async2ResumptionFunc<'TOverall> = ResumptionFunc<Async2Data<'TOverall>>
153159
type Async2ResumptionDynamicInfo<'TOverall> = ResumptionDynamicInfo<Async2Data<'TOverall>>
154160

155161
type Async2Code<'TOverall, 'T> = ResumableCode<Async2Data<'TOverall>, 'T>
156162

157-
[<Struct; NoComparison>]
158-
type Async2<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>> =
163+
[<NoComparison>]
164+
type Async2<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>() =
159165
[<DefaultValue(false)>]
160166
val mutable StateMachine: 'm
161167

162-
member ts.Start(ct, tailCallSource, isBound) =
168+
member ts.Start(ct, ?tailCallSource) =
163169
let mutable copy = ts
164170
let mutable data = Async2Data()
165171
data.CancellationToken <- ct
166172
data.TailCallSource <- tailCallSource
167-
data.IsBound <- isBound
168173
data.MethodBuilder <- AsyncTaskMethodBuilder<'t>.Create()
169174
copy.StateMachine.Data <- data
170175
copy.StateMachine.Data.MethodBuilder.Start(&copy.StateMachine)
@@ -177,15 +182,15 @@ module Async2Implementation =
177182
ts.StateMachine.Data.MethodBuilder.Task.GetAwaiter()
178183

179184
interface Async2<'t> with
180-
member ts.StartImmediate ct = ts.Start(ct, ValueNone, false)
185+
member ts.StartImmediate ct = ts.Start(ct)
181186

182-
member ts.StartBound ct =
183-
ts.Start(ct, ValueNone, true).GetAwaiter()
187+
member ts.StartBound ct = ts.Start(ct).GetAwaiter()
184188

185-
member ts.TailCall(ct, tc) = ts.Start(ct, tc, true) |> ignore
189+
member ts.TailCall(ct, tc) =
190+
ts.Start(ct, tailCallSource = tc) |> ignore
186191

187192
member ts.GetAwaiter() =
188-
ts.Start(CancellationToken.None, ValueNone, true).GetAwaiter()
193+
ts.Start(CancellationToken.None).GetAwaiter()
189194

190195
type Async2Dynamic<'t, 'm when 'm :> IAsyncStateMachine and 'm :> IAsync2StateMachine<'t>>(getCopy: bool -> 'm) =
191196
member ts.GetCopy isBound =
@@ -196,8 +201,7 @@ module Async2Implementation =
196201

197202
member ts.StartBound ct = ts.GetCopy(true).StartBound(ct)
198203

199-
member ts.TailCall(ct, tc) =
200-
ts.GetCopy(true).TailCall(ct, tc) |> ignore
204+
member ts.TailCall(ct, tc) = ts.GetCopy(true).TailCall(ct, tc)
201205

202206
member ts.GetAwaiter() = ts.GetCopy(true).GetAwaiter()
203207

@@ -220,7 +224,7 @@ module Async2Implementation =
220224

221225
let inline yieldOnBindLimit () =
222226
Async2Code(fun sm ->
223-
if BindContext.IncrementBindCount() then
227+
if Trampoline.Current.ShouldBounce then
224228
let __stack_yield_fin = ResumableCode.Yield().Invoke(&sm)
225229

226230
if not __stack_yield_fin then
@@ -311,7 +315,7 @@ module Async2Implementation =
311315
let initialResumptionFunc = Async2ResumptionFunc<'T>(fun sm -> code.Invoke &sm)
312316

313317
let maybeBounce state =
314-
if BindContext.IncrementBindCount() then
318+
if Trampoline.Current.ShouldBounce then
315319
Bounce state
316320
else
317321
Immediate state
@@ -353,11 +357,11 @@ module Async2Implementation =
353357
sm.Data.MethodBuilder.AwaitOnCompleted(Trampoline.Current.Ref, &sm)
354358
| SetResult ->
355359
match sm.Data.TailCallSource with
356-
| ValueSome tcs -> tcs.SetResult sm.Data.Result
360+
| Some tcs -> tcs.SetResult sm.Data.Result
357361
| _ -> sm.Data.MethodBuilder.SetResult sm.Data.Result
358362
| SetException edi ->
359363
match sm.Data.TailCallSource with
360-
| ValueSome tcs -> tcs.TrySetException(edi.SourceException) |> ignore
364+
| Some tcs -> tcs.TrySetException(edi.SourceException) |> ignore
361365
| _ -> sm.Data.MethodBuilder.SetException(edi.SourceException)
362366

363367
member _.SetStateMachine(sm, state) =
@@ -375,7 +379,7 @@ module Async2Implementation =
375379

376380
let mutable error = ValueNone
377381

378-
let __stack_go1 = not sm.Data.IsBound || yieldOnBindLimit().Invoke(&sm)
382+
let __stack_go1 = yieldOnBindLimit().Invoke(&sm)
379383

380384
if __stack_go1 then
381385
try
@@ -386,7 +390,7 @@ module Async2Implementation =
386390

387391
if __stack_go2 then
388392
match sm.Data.TailCallSource with
389-
| ValueSome tcs -> tcs.SetResult sm.Data.Result
393+
| Some tcs -> tcs.SetResult sm.Data.Result
390394
| _ -> sm.Data.MethodBuilder.SetResult(sm.Data.Result)
391395
with exn ->
392396
error <- ValueSome(ExceptionCache.CaptureOrRetrieve exn)
@@ -396,7 +400,7 @@ module Async2Implementation =
396400

397401
if __stack_go2 then
398402
match sm.Data.TailCallSource with
399-
| ValueSome tcs -> tcs.SetException(error.Value.SourceException)
403+
| Some tcs -> tcs.SetException(error.Value.SourceException)
400404
| _ -> sm.Data.MethodBuilder.SetException(error.Value.SourceException)))
401405

402406
(SetStateMachineMethodImpl<_>(fun sm state -> sm.Data.MethodBuilder.SetStateMachine state))
@@ -451,15 +455,15 @@ module HighPriority =
451455
member inline this.ReturnFromFinal(code: Async2<'T>) =
452456
Async2Code(fun sm ->
453457
match sm.Data.TailCallSource with
454-
| ValueNone ->
458+
| None ->
455459
// This is the start of a tail call chain. we need to return here when the entire chain is done.
456460
let __stack_tcs = TaskCompletionSource<_>()
457-
code.TailCall(sm.Data.CancellationToken, ValueSome __stack_tcs)
461+
code.TailCall(sm.Data.CancellationToken, __stack_tcs)
458462
this.Bind(__stack_tcs.Task, this.Return).Invoke(&sm)
459-
| ValueSome tcs ->
463+
| Some tcs ->
460464
// We are already in a tail call chain.
461465
let __stack_ct = sm.Data.CancellationToken
462-
code.TailCall(__stack_ct, ValueSome tcs)
466+
code.TailCall(__stack_ct, tcs)
463467
false // Return false to abandon this state machine and continue on the next one.
464468
)
465469

0 commit comments

Comments
 (0)