3
3
using System . Reflection ;
4
4
using System . Runtime . CompilerServices ;
5
5
using System . Runtime . ExceptionServices ;
6
+ using AngleSharp . Dom ;
6
7
using Microsoft . Extensions . Logging ;
7
8
8
9
namespace Bunit . Rendering ;
@@ -26,6 +27,7 @@ public sealed class BunitRenderer : Renderer
26
27
27
28
private readonly HashSet < int > returnedRenderedComponentIds = new ( ) ;
28
29
private readonly List < BunitRootComponent > rootComponents = new ( ) ;
30
+ private readonly Dictionary < string , int > elementReferenceToComponentId = new ( ) ;
29
31
private readonly ILogger < BunitRenderer > logger ;
30
32
private bool disposed ;
31
33
private TaskCompletionSource < Exception > unhandledExceptionTsc = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
@@ -453,6 +455,7 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch)
453
455
var id = renderBatch . DisposedComponentIDs . Array [ i ] ;
454
456
disposedComponentIds . Add ( id ) ;
455
457
returnedRenderedComponentIds . Remove ( id ) ;
458
+ RemoveElementReferencesForComponent ( id ) ;
456
459
}
457
460
458
461
for ( var i = 0 ; i < renderBatch . UpdatedComponents . Count ; i ++ )
@@ -467,6 +470,8 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch)
467
470
var componentState = GetComponentState ( diff . ComponentId ) ;
468
471
var renderedComponent = ( IRenderedComponent ) componentState ;
469
472
473
+ TrackElementReferencesForComponent ( diff . ComponentId ) ;
474
+
470
475
if ( returnedRenderedComponentIds . Contains ( diff . ComponentId ) )
471
476
{
472
477
renderedComponent . UpdateState ( hasRendered : true , isMarkupGenerationRequired : diff . Edits . Count > 0 ) ;
@@ -519,6 +524,101 @@ static bool IsParentComponentAlreadyUpdated(int componentId, in RenderBatch rend
519
524
}
520
525
}
521
526
527
+ private void TrackElementReferencesForComponent ( int componentId )
528
+ {
529
+ var frames = GetCurrentRenderTreeFrames ( componentId ) ;
530
+ TrackElementReferencesInFrames ( frames , componentId ) ;
531
+ }
532
+
533
+ private void TrackElementReferencesInFrames ( ArrayRange < RenderTreeFrame > frames , int componentId )
534
+ {
535
+ for ( var i = 0 ; i < frames . Count ; i ++ )
536
+ {
537
+ ref var frame = ref frames . Array [ i ] ;
538
+
539
+ if ( frame . FrameType == RenderTreeFrameType . ElementReferenceCapture )
540
+ {
541
+ var elementReferenceId = frame . ElementReferenceCaptureId ;
542
+ if ( elementReferenceId != null )
543
+ {
544
+ elementReferenceToComponentId [ elementReferenceId ] = componentId ;
545
+ }
546
+ }
547
+ else if ( frame . FrameType == RenderTreeFrameType . Component )
548
+ {
549
+ TrackElementReferencesForComponent ( frame . ComponentId ) ;
550
+ }
551
+ }
552
+ }
553
+
554
+ private void RemoveElementReferencesForComponent ( int componentId )
555
+ {
556
+ var keysToRemove = elementReferenceToComponentId
557
+ . Where ( kvp => kvp . Value == componentId )
558
+ . Select ( kvp => kvp . Key )
559
+ . ToList ( ) ;
560
+
561
+ foreach ( var key in keysToRemove )
562
+ {
563
+ elementReferenceToComponentId . Remove ( key ) ;
564
+ }
565
+ }
566
+
567
+ internal IRenderedComponent < TComponent > ? FindComponentForElement < TComponent > ( IElement element )
568
+ where TComponent : IComponent
569
+ {
570
+ var elementReferenceId = element . GetAttribute ( "blazor:elementReference" ) ;
571
+ if ( elementReferenceId is not null && elementReferenceToComponentId . TryGetValue ( elementReferenceId , out var componentId ) )
572
+ {
573
+ return GetRenderedComponent < TComponent > ( componentId ) ;
574
+ }
575
+
576
+ return FindComponentByElementContainment < TComponent > ( element ) ;
577
+ }
578
+
579
+ private IRenderedComponent < TComponent > ? FindComponentByElementContainment < TComponent > ( IElement element )
580
+ where TComponent : IComponent
581
+ {
582
+ List < int > renderedComponentIdsWhenStarted = [ ..returnedRenderedComponentIds ] ;
583
+ var components = new List < IRenderedComponent < TComponent > > ( returnedRenderedComponentIds . Count ) ;
584
+
585
+ foreach ( var parentComponent in renderedComponentIdsWhenStarted . Select ( GetRenderedComponent < IComponent > ) )
586
+ {
587
+ components . AddRange ( FindComponents < TComponent > ( parentComponent ) ) ;
588
+ }
589
+
590
+ return components . FirstOrDefault ( component => ComponentContainsElement ( component , element ) ) ;
591
+ }
592
+
593
+ private static bool ComponentContainsElement < TComponent > ( IRenderedComponent < TComponent > component , IElement element )
594
+ where TComponent : IComponent
595
+ {
596
+ foreach ( var node in component . Nodes )
597
+ {
598
+ if ( node is IElement nodeElement && nodeElement . Equals ( element ) )
599
+ {
600
+ return true ;
601
+ }
602
+ if ( IsDescendantOf ( element , node ) )
603
+ {
604
+ return true ;
605
+ }
606
+ }
607
+ return false ;
608
+ }
609
+
610
+ private static bool IsDescendantOf ( IElement element , INode potentialAncestor )
611
+ {
612
+ var current = element . Parent ;
613
+ while ( current is not null )
614
+ {
615
+ if ( current == potentialAncestor )
616
+ return true ;
617
+ current = current . Parent ;
618
+ }
619
+ return false ;
620
+ }
621
+
522
622
/// <inheritdoc/>
523
623
internal new ArrayRange < RenderTreeFrame > GetCurrentRenderTreeFrames ( int componentId )
524
624
=> base . GetCurrentRenderTreeFrames ( componentId ) ;
0 commit comments