1
+ using System . Collections . Generic ;
2
+ using System . Threading . Tasks ;
3
+ using FluentAssertions . Analyzers . Utilities ;
4
+ using Microsoft . CodeAnalysis ;
5
+ using Microsoft . CodeAnalysis . CodeActions ;
6
+ using Microsoft . CodeAnalysis . CodeFixes ;
7
+ using Microsoft . CodeAnalysis . CSharp . Syntax ;
8
+ using Microsoft . CodeAnalysis . Operations ;
9
+ using CreateChangedDocument = System . Func < System . Threading . CancellationToken , System . Threading . Tasks . Task < Microsoft . CodeAnalysis . Document > > ;
10
+
11
+ namespace FluentAssertions . Analyzers ;
12
+
13
+ public abstract class TestingFrameworkCodeFixProvider : CodeFixProvider
14
+ {
15
+ protected const string Title = "Replace with FluentAssertions" ;
16
+
17
+ public override FixAllProvider GetFixAllProvider ( ) => WellKnownFixAllProviders . BatchFixer ;
18
+
19
+ public override async Task RegisterCodeFixesAsync ( CodeFixContext context )
20
+ {
21
+ var root = await context . Document . GetSyntaxRootAsync ( context . CancellationToken ) ;
22
+ var semanticModel = await context . Document . GetSemanticModelAsync ( context . CancellationToken ) ;
23
+
24
+ var testContext = new TestingFrameworkCodeFixContext ( semanticModel . Compilation ) ;
25
+ foreach ( var diagnostic in context . Diagnostics )
26
+ {
27
+ var node = root . FindNode ( diagnostic . Location . SourceSpan ) ;
28
+ if ( node is not InvocationExpressionSyntax invocationExpression )
29
+ {
30
+ continue ;
31
+ }
32
+
33
+ var operation = semanticModel . GetOperation ( invocationExpression , context . CancellationToken ) ;
34
+ if ( operation is not IInvocationOperation invocation )
35
+ {
36
+ continue ;
37
+ }
38
+
39
+ var fix = TryComputeFix ( invocation , context , testContext , diagnostic ) ;
40
+ if ( fix is not null )
41
+ {
42
+ context . RegisterCodeFix ( CodeAction . Create ( Title , fix , equivalenceKey : Title ) , diagnostic ) ;
43
+ }
44
+ }
45
+ }
46
+
47
+ protected abstract CreateChangedDocument TryComputeFix ( IInvocationOperation invocation , CodeFixContext context , TestingFrameworkCodeFixContext t , Diagnostic diagnostic ) ;
48
+
49
+ protected static bool ArgumentsAreTypeOf ( IInvocationOperation invocation , params ITypeSymbol [ ] types ) => ArgumentsAreTypeOf ( invocation , 0 , types ) ;
50
+ protected static bool ArgumentsAreTypeOf ( IInvocationOperation invocation , int startFromIndex , params ITypeSymbol [ ] types )
51
+ {
52
+ if ( invocation . TargetMethod . Parameters . Length != types . Length + startFromIndex )
53
+ {
54
+ return false ;
55
+ }
56
+
57
+ for ( int i = startFromIndex ; i < types . Length ; i ++ )
58
+ {
59
+ if ( ! invocation . TargetMethod . Parameters [ i ] . Type . EqualsSymbol ( types [ i ] ) )
60
+ {
61
+ return false ;
62
+ }
63
+ }
64
+
65
+ return true ;
66
+ }
67
+
68
+ protected static bool ArgumentsAreGenericTypeOf ( IInvocationOperation invocation , params ITypeSymbol [ ] types )
69
+ {
70
+ const int generics = 1 ;
71
+ if ( invocation . TargetMethod . Parameters . Length != types . Length )
72
+ {
73
+ return false ;
74
+ }
75
+
76
+ if ( invocation . TargetMethod . TypeArguments . Length != generics )
77
+ {
78
+ return false ;
79
+ }
80
+
81
+ var genericType = invocation . TargetMethod . TypeArguments [ 0 ] ;
82
+
83
+ for ( int i = 0 ; i < types . Length ; i ++ )
84
+ {
85
+ if ( invocation . TargetMethod . Parameters [ i ] . Type is not INamedTypeSymbol parameterType )
86
+ {
87
+ return false ;
88
+ }
89
+
90
+ if ( parameterType . TypeArguments . IsEmpty && parameterType . EqualsSymbol ( genericType ) )
91
+ {
92
+ continue ;
93
+ }
94
+
95
+ if ( parameterType . TypeArguments . Length != generics
96
+ || ! ( parameterType . TypeArguments [ 0 ] . EqualsSymbol ( genericType ) && parameterType . OriginalDefinition . EqualsSymbol ( types [ i ] ) ) )
97
+ {
98
+ return false ;
99
+ }
100
+ }
101
+
102
+ return true ;
103
+ }
104
+
105
+ protected static bool ArgumentsCount ( IInvocationOperation invocation , int arguments )
106
+ {
107
+ return invocation . TargetMethod . Parameters . Length == arguments ;
108
+ }
109
+
110
+ protected sealed class TestingFrameworkCodeFixContext ( Compilation compilation )
111
+ {
112
+ public INamedTypeSymbol Object { get ; } = compilation . ObjectType ;
113
+ public INamedTypeSymbol String { get ; } = compilation . GetTypeByMetadataName ( "System.String" ) ;
114
+ public INamedTypeSymbol Int32 { get ; } = compilation . GetTypeByMetadataName ( "System.Int32" ) ;
115
+ public INamedTypeSymbol Float { get ; } = compilation . GetTypeByMetadataName ( "System.Single" ) ;
116
+ public INamedTypeSymbol Double { get ; } = compilation . GetTypeByMetadataName ( "System.Double" ) ;
117
+ public INamedTypeSymbol Decimal { get ; } = compilation . GetTypeByMetadataName ( "System.Decimal" ) ;
118
+ public INamedTypeSymbol Boolean { get ; } = compilation . GetTypeByMetadataName ( "System.Boolean" ) ;
119
+ public INamedTypeSymbol Action { get ; } = compilation . GetTypeByMetadataName ( "System.Action" ) ;
120
+ public INamedTypeSymbol Type { get ; } = compilation . GetTypeByMetadataName ( "System.Type" ) ;
121
+ public INamedTypeSymbol DateTime { get ; } = compilation . GetTypeByMetadataName ( "System.DateTime" ) ;
122
+ public INamedTypeSymbol TimeSpan { get ; } = compilation . GetTypeByMetadataName ( "System.TimeSpan" ) ;
123
+ public INamedTypeSymbol FuncOfObject { get ; } = compilation . GetTypeByMetadataName ( "System.Func`1" ) . Construct ( compilation . ObjectType ) ;
124
+ public INamedTypeSymbol FuncOfTask { get ; } = compilation . GetTypeByMetadataName ( "System.Func`1" ) . Construct ( compilation . GetTypeByMetadataName ( "System.Threading.Tasks.Task" ) ) ;
125
+ public IArrayTypeSymbol ObjectArray { get ; } = compilation . CreateArrayTypeSymbol ( compilation . ObjectType ) ;
126
+ public INamedTypeSymbol CultureInfo { get ; } = compilation . GetTypeByMetadataName ( "System.Globalization.CultureInfo" ) ;
127
+ public INamedTypeSymbol StringComparison { get ; } = compilation . GetTypeByMetadataName ( "System.StringComparison" ) ;
128
+ public INamedTypeSymbol Regex { get ; } = compilation . GetTypeByMetadataName ( "System.Text.RegularExpressions.Regex" ) ;
129
+ public INamedTypeSymbol ICollection { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.ICollection" ) ;
130
+ public INamedTypeSymbol IComparer { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.IComparer" ) ;
131
+ public INamedTypeSymbol IEqualityComparerOfT1 { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEqualityComparer`1" ) ;
132
+ public INamedTypeSymbol IEnumerableOfT1 { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEnumerable`1" ) ;
133
+
134
+ public INamedTypeSymbol Identity { get ; } = null ;
135
+ }
136
+ }
0 commit comments