@@ -365,7 +365,7 @@ public void AwaitOnCompleted<TAwaiter, TStateMachine>(
365365 {
366366 try
367367 {
368- awaiter . OnCompleted ( GetMoveNextDelegate ( ref stateMachine ) ) ;
368+ awaiter . OnCompleted ( GetStateMachineBox ( ref stateMachine ) . MoveNextAction ) ;
369369 }
370370 catch ( Exception e )
371371 {
@@ -384,10 +384,107 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
384384 ref TAwaiter awaiter , ref TStateMachine stateMachine )
385385 where TAwaiter : ICriticalNotifyCompletion
386386 where TStateMachine : IAsyncStateMachine
387+ {
388+ IAsyncStateMachineBox box = GetStateMachineBox ( ref stateMachine ) ;
389+
390+ // TODO https://github.com/dotnet/coreclr/issues/12877:
391+ // Once the JIT is able to recognize "awaiter is ITaskAwaiter" and "awaiter is IConfiguredTaskAwaiter",
392+ // use those in order to a) consolidate a lot of this code, and b) handle all Task/Task<T> and not just
393+ // the few types special-cased here. For now, handle common {Configured}TaskAwaiter. Having the types
394+ // explicitly listed here allows the JIT to generate the best code for them; otherwise we'll fall through
395+ // to the later workaround.
396+ if ( typeof ( TAwaiter ) == typeof ( TaskAwaiter ) ||
397+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < object > ) ||
398+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < string > ) ||
399+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < byte [ ] > ) ||
400+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < bool > ) ||
401+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < byte > ) ||
402+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < int > ) ||
403+ typeof ( TAwaiter ) == typeof ( TaskAwaiter < long > ) )
404+ {
405+ ref TaskAwaiter ta = ref Unsafe . As < TAwaiter , TaskAwaiter > ( ref awaiter ) ; // relies on TaskAwaiter/TaskAwaiter<T> having the same layout
406+ TaskAwaiter . UnsafeOnCompletedInternal ( ta . m_task , box , continueOnCapturedContext : true ) ;
407+ }
408+ else if (
409+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable . ConfiguredTaskAwaiter ) ||
410+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < object > . ConfiguredTaskAwaiter ) ||
411+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < string > . ConfiguredTaskAwaiter ) ||
412+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < byte [ ] > . ConfiguredTaskAwaiter ) ||
413+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < bool > . ConfiguredTaskAwaiter ) ||
414+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < byte > . ConfiguredTaskAwaiter ) ||
415+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < int > . ConfiguredTaskAwaiter ) ||
416+ typeof ( TAwaiter ) == typeof ( ConfiguredTaskAwaitable < long > . ConfiguredTaskAwaiter ) )
417+ {
418+ ref ConfiguredTaskAwaitable . ConfiguredTaskAwaiter ta = ref Unsafe . As < TAwaiter , ConfiguredTaskAwaitable . ConfiguredTaskAwaiter > ( ref awaiter ) ;
419+ TaskAwaiter . UnsafeOnCompletedInternal ( ta . m_task , box , ta . m_continueOnCapturedContext ) ;
420+ }
421+
422+ // Handle common {Configured}ValueTaskAwaiter<T> types. Unfortunately these need to be special-cased
423+ // individually, as we don't have good way to extract the task from a ValueTaskAwaiter<T> when we don't
424+ // know what the T is; we could make ValueTaskAwaiter<T> implement an IValueTaskAwaiter interface, but
425+ // calling a GetTask method on that would end up boxing the awaiter. This hard-coded list here is
426+ // somewhat arbitrary and is based on types currently in use with ValueTask<T> in coreclr/corefx.
427+ else if ( typeof ( TAwaiter ) == typeof ( ValueTaskAwaiter < int > ) )
428+ {
429+ var vta = ( ValueTaskAwaiter < int > ) ( object ) awaiter ;
430+ TaskAwaiter . UnsafeOnCompletedInternal ( vta . AsTask ( ) , box , continueOnCapturedContext : true ) ;
431+ }
432+ else if ( typeof ( TAwaiter ) == typeof ( ConfiguredValueTaskAwaitable < int > . ConfiguredValueTaskAwaiter ) )
433+ {
434+ var vta = ( ConfiguredValueTaskAwaitable < int > . ConfiguredValueTaskAwaiter ) ( object ) awaiter ;
435+ TaskAwaiter . UnsafeOnCompletedInternal ( vta . AsTask ( ) , box , vta . _continueOnCapturedContext ) ;
436+ }
437+ else if ( typeof ( TAwaiter ) == typeof ( ConfiguredValueTaskAwaitable < System . IO . Stream > . ConfiguredValueTaskAwaiter ) )
438+ {
439+ var vta = ( ConfiguredValueTaskAwaitable < System . IO . Stream > . ConfiguredValueTaskAwaiter ) ( object ) awaiter ;
440+ TaskAwaiter . UnsafeOnCompletedInternal ( vta . AsTask ( ) , box , vta . _continueOnCapturedContext ) ;
441+ }
442+ else if ( typeof ( TAwaiter ) == typeof ( ConfiguredValueTaskAwaitable < ArraySegment < byte > > . ConfiguredValueTaskAwaiter ) )
443+ {
444+ var vta = ( ConfiguredValueTaskAwaitable < ArraySegment < byte > > . ConfiguredValueTaskAwaiter ) ( object ) awaiter ;
445+ TaskAwaiter . UnsafeOnCompletedInternal ( vta . AsTask ( ) , box , vta . _continueOnCapturedContext ) ;
446+ }
447+ else if ( typeof ( TAwaiter ) == typeof ( ConfiguredValueTaskAwaitable < object > . ConfiguredValueTaskAwaiter ) )
448+ {
449+ var vta = ( ConfiguredValueTaskAwaitable < object > . ConfiguredValueTaskAwaiter ) ( object ) awaiter ;
450+ TaskAwaiter . UnsafeOnCompletedInternal ( vta . AsTask ( ) , box , vta . _continueOnCapturedContext ) ;
451+ }
452+
453+ // To catch all Task/Task<T> awaits, do the currently more expensive interface checks.
454+ // Eventually these and the above Task/Task<T> checks should be replaced by "is" checks,
455+ // once that's recognized and optimized by the JIT. We do these after all of the hardcoded
456+ // checks above so that they don't incur the costs of these checks.
457+ else if ( InterfaceIsCheckWorkaround < TAwaiter > . IsITaskAwaiter )
458+ {
459+ ref TaskAwaiter ta = ref Unsafe . As < TAwaiter , TaskAwaiter > ( ref awaiter ) ;
460+ TaskAwaiter . UnsafeOnCompletedInternal ( ta . m_task , box , continueOnCapturedContext : true ) ;
461+ }
462+ else if ( InterfaceIsCheckWorkaround < TAwaiter > . IsIConfiguredTaskAwaiter )
463+ {
464+ ref ConfiguredTaskAwaitable . ConfiguredTaskAwaiter ta = ref Unsafe . As < TAwaiter , ConfiguredTaskAwaitable . ConfiguredTaskAwaiter > ( ref awaiter ) ;
465+ TaskAwaiter . UnsafeOnCompletedInternal ( ta . m_task , box , ta . m_continueOnCapturedContext ) ;
466+ }
467+
468+ // The awaiter isn't specially known. Fall back to doing a normal await.
469+ else
470+ {
471+ // TODO https://github.com/dotnet/coreclr/issues/14177:
472+ // Move the code back into this method once the JIT is able to
473+ // elide it successfully when one of the previous branches is hit.
474+ AwaitArbitraryAwaiterUnsafeOnCompleted ( ref awaiter , box ) ;
475+ }
476+ }
477+
478+ /// <summary>Schedules the specified state machine to be pushed forward when the specified awaiter completes.</summary>
479+ /// <typeparam name="TAwaiter">Specifies the type of the awaiter.</typeparam>
480+ /// <param name="awaiter">The awaiter.</param>
481+ /// <param name="box">The state machine box.</param>
482+ private static void AwaitArbitraryAwaiterUnsafeOnCompleted < TAwaiter > ( ref TAwaiter awaiter , IAsyncStateMachineBox box )
483+ where TAwaiter : ICriticalNotifyCompletion
387484 {
388485 try
389486 {
390- awaiter . UnsafeOnCompleted ( GetMoveNextDelegate ( ref stateMachine ) ) ;
487+ awaiter . UnsafeOnCompleted ( box . MoveNextAction ) ;
391488 }
392489 catch ( Exception e )
393490 {
@@ -399,7 +496,7 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
399496 /// <typeparam name="TStateMachine">Specifies the type of the async state machine.</typeparam>
400497 /// <param name="stateMachine">The state machine.</param>
401498 /// <returns>The "boxed" state machine.</returns>
402- private Action GetMoveNextDelegate < TStateMachine > (
499+ private IAsyncStateMachineBox GetStateMachineBox < TStateMachine > (
403500 ref TStateMachine stateMachine )
404501 where TStateMachine : IAsyncStateMachine
405502 {
@@ -416,7 +513,7 @@ private Action GetMoveNextDelegate<TStateMachine>(
416513 {
417514 stronglyTypedBox . Context = currentContext ;
418515 }
419- return stronglyTypedBox . MoveNextAction ;
516+ return stronglyTypedBox ;
420517 }
421518
422519 // The least common case: we have a weakly-typed boxed. This results if the debugger
@@ -440,7 +537,7 @@ private Action GetMoveNextDelegate<TStateMachine>(
440537 // Update the context. This only happens with a debugger, so no need to spend
441538 // extra IL checking for equality before doing the assignment.
442539 weaklyTypedBox . Context = currentContext ;
443- return weaklyTypedBox . MoveNextAction ;
540+ return weaklyTypedBox ;
444541 }
445542
446543 // Alert a listening debugger that we can't make forward progress unless it slips threads.
@@ -462,34 +559,33 @@ private Action GetMoveNextDelegate<TStateMachine>(
462559 m_task = box ; // important: this must be done before storing stateMachine into box.StateMachine!
463560 box . StateMachine = stateMachine ;
464561 box . Context = currentContext ;
465- return box . MoveNextAction ;
562+ return box ;
466563 }
467564
468565 /// <summary>A strongly-typed box for Task-based async state machines.</summary>
469566 /// <typeparam name="TStateMachine">Specifies the type of the state machine.</typeparam>
470567 /// <typeparam name="TResult">Specifies the type of the Task's result.</typeparam>
471568 private sealed class AsyncStateMachineBox < TStateMachine > :
472- Task < TResult > , IDebuggingAsyncStateMachineAccessor
569+ Task < TResult > , IAsyncStateMachineBox
473570 where TStateMachine : IAsyncStateMachine
474571 {
475572 /// <summary>Delegate used to invoke on an ExecutionContext when passed an instance of this box type.</summary>
476573 private static readonly ContextCallback s_callback = s => ( ( AsyncStateMachineBox < TStateMachine > ) s ) . StateMachine . MoveNext ( ) ;
477574
478575 /// <summary>A delegate to the <see cref="MoveNext"/> method.</summary>
479- public readonly Action MoveNextAction ;
576+ private Action _moveNextAction ;
480577 /// <summary>The state machine itself.</summary>
481578 public TStateMachine StateMachine ; // mutable struct; do not make this readonly
482579 /// <summary>Captured ExecutionContext with which to invoke <see cref="MoveNextAction"/>; may be null.</summary>
483580 public ExecutionContext Context ;
484581
485- public AsyncStateMachineBox ( )
486- {
487- var mn = new Action ( MoveNext ) ;
488- MoveNextAction = AsyncCausalityTracer . LoggingOn ? AsyncMethodBuilderCore . OutputAsyncCausalityEvents ( this , mn ) : mn ;
489- }
582+ /// <summary>A delegate to the <see cref="MoveNext"/> method.</summary>
583+ public Action MoveNextAction =>
584+ _moveNextAction ??
585+ ( _moveNextAction = AsyncCausalityTracer . LoggingOn ? AsyncMethodBuilderCore . OutputAsyncCausalityEvents ( this , new Action ( MoveNext ) ) : new Action ( MoveNext ) ) ;
490586
491- /// <summary>Call MoveNext on <see cref="StateMachine"/>. </summary>
492- private void MoveNext ( )
587+ /// <summary>Calls MoveNext on <see cref="StateMachine"/></summary>
588+ public void MoveNext ( )
493589 {
494590 if ( Context == null )
495591 {
@@ -501,8 +597,19 @@ private void MoveNext()
501597 }
502598 }
503599
600+ /// <summary>
601+ /// Calls MoveNext on <see cref="StateMachine"/>. Implements ITaskCompletionAction.Invoke so
602+ /// that the state machine object may be queued directly as a continuation into a Task's
603+ /// continuation slot/list.
604+ /// </summary>
605+ /// <param name="completedTask">The completing task that caused this method to be invoked, if there was one.</param>
606+ void ITaskCompletionAction . Invoke ( Task completedTask ) => MoveNext ( ) ;
607+
608+ /// <summary>Signals to Task's continuation logic that <see cref="Invoke"/> runs arbitrary user code via MoveNext.</summary>
609+ bool ITaskCompletionAction . InvokeMayRunArbitraryCode => true ;
610+
504611 /// <summary>Gets the state machine as a boxed object. This should only be used for debugging purposes.</summary>
505- IAsyncStateMachine IDebuggingAsyncStateMachineAccessor . GetStateMachineObject ( ) => StateMachine ; // likely boxes, only use for debugging
612+ IAsyncStateMachine IAsyncStateMachineBox . GetStateMachineObject ( ) => StateMachine ; // likely boxes, only use for debugging
506613 }
507614
508615 /// <summary>Gets the <see cref="System.Threading.Tasks.Task{TResult}"/> for this builder.</summary>
@@ -815,12 +922,24 @@ internal static Task<TResult> CreateCacheableTask<TResult>(TResult result) =>
815922 new Task < TResult > ( false , result , ( TaskCreationOptions ) InternalTaskOptions . DoNotDispose , default ( CancellationToken ) ) ;
816923 }
817924
925+ /// <summary>Temporary workaround for https://github.com/dotnet/coreclr/issues/12877.</summary>
926+ internal static class InterfaceIsCheckWorkaround < TAwaiter >
927+ {
928+ internal static readonly bool IsITaskAwaiter = typeof ( TAwaiter ) . GetInterface ( "ITaskAwaiter" ) != null ;
929+ internal static readonly bool IsIConfiguredTaskAwaiter = typeof ( TAwaiter ) . GetInterface ( "IConfiguredTaskAwaiter" ) != null ;
930+ }
931+
818932 /// <summary>
819- /// An interface implemented by <see cref="AsyncStateMachineBox{TStateMachine, TResult}"/> to allow access
820- /// non-generically to state associated with a builder and state machine.
933+ /// An interface implemented by all <see cref="AsyncStateMachineBox{TStateMachine, TResult}"/> instances, regardless of generics.
821934 /// </summary>
822- interface IDebuggingAsyncStateMachineAccessor
935+ interface IAsyncStateMachineBox : ITaskCompletionAction
823936 {
937+ /// <summary>
938+ /// Gets an action for moving forward the contained state machine.
939+ /// This will lazily-allocate the delegate as needed.
940+ /// </summary>
941+ Action MoveNextAction { get ; }
942+
824943 /// <summary>Gets the state machine as a boxed object. This should only be used for debugging purposes.</summary>
825944 IAsyncStateMachine GetStateMachineObject ( ) ;
826945 }
@@ -843,7 +962,7 @@ internal static Action TryGetStateMachineForDebugger(Action action) // debugger
843962 {
844963 object target = action . Target ;
845964 return
846- target is IDebuggingAsyncStateMachineAccessor sm ? sm . GetStateMachineObject ( ) . MoveNext :
965+ target is IAsyncStateMachineBox sm ? sm . GetStateMachineObject ( ) . MoveNext :
847966 target is ContinuationWrapper cw ? TryGetStateMachineForDebugger ( cw . _continuation ) :
848967 action ;
849968 }
0 commit comments