diff --git a/Sources/Spyable/Spyable.swift b/Sources/Spyable/Spyable.swift index c5feb99..7d7619a 100644 --- a/Sources/Spyable/Spyable.swift +++ b/Sources/Spyable/Spyable.swift @@ -134,7 +134,11 @@ /// - The generated spy class name is suffixed with `Spy` (e.g., `ServiceProtocolSpy`). /// @attached(peer, names: suffixed(Spy)) -public macro Spyable(behindPreprocessorFlag: String? = nil, accessLevel: SpyAccessLevel? = nil) = +public macro Spyable( + behindPreprocessorFlag: String? = nil, + accessLevel: SpyAccessLevel? = nil, + inheritedType: String? = nil +) = #externalMacro( module: "SpyableMacro", type: "SpyableMacro" diff --git a/Sources/SpyableMacro/Diagnostics/SpyableDiagnostic.swift b/Sources/SpyableMacro/Diagnostics/SpyableDiagnostic.swift index daade36..b986426 100644 --- a/Sources/SpyableMacro/Diagnostics/SpyableDiagnostic.swift +++ b/Sources/SpyableMacro/Diagnostics/SpyableDiagnostic.swift @@ -14,6 +14,7 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error { case behindPreprocessorFlagArgumentRequiresStaticStringLiteral case accessLevelArgumentRequiresMemberAccessExpression case accessLevelArgumentUnsupportedAccessLevel + case inheritedTypeArgumentRequiresStaticStringLiteral /// Provides a human-readable diagnostic message for each diagnostic case. var message: String { @@ -30,6 +31,8 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error { "The `accessLevel` argument requires a member access expression" case .accessLevelArgumentUnsupportedAccessLevel: "The `accessLevel` argument does not support the specified access level" + case .inheritedTypeArgumentRequiresStaticStringLiteral: + "The `inheritedType` argument requires a static string literal" } } @@ -41,7 +44,8 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error { .variableDeclInProtocolWithNotIdentifierPattern, .behindPreprocessorFlagArgumentRequiresStaticStringLiteral, .accessLevelArgumentRequiresMemberAccessExpression, - .accessLevelArgumentUnsupportedAccessLevel: + .accessLevelArgumentUnsupportedAccessLevel, + .inheritedTypeArgumentRequiresStaticStringLiteral: .error } } diff --git a/Sources/SpyableMacro/Extractors/Extractor.swift b/Sources/SpyableMacro/Extractors/Extractor.swift index 48aeefb..5f94b66 100644 --- a/Sources/SpyableMacro/Extractors/Extractor.swift +++ b/Sources/SpyableMacro/Extractors/Extractor.swift @@ -106,6 +106,9 @@ struct Extractor { let accessLevelText = memberAccess.declName.baseName.text switch accessLevelText { + case "open": + return DeclModifierSyntax(name: .keyword(.open)) + case "public": return DeclModifierSyntax(name: .keyword(.public)) @@ -145,6 +148,59 @@ struct Extractor { func extractAccessLevel(from protocolDeclSyntax: ProtocolDeclSyntax) -> DeclModifierSyntax? { protocolDeclSyntax.modifiers.first(where: \.name.isAccessLevelSupportedInProtocol) } + + /// Extracts an inherited type value from an attribute if present and valid. + /// + /// This method searches for an argument labeled `inheritedType` within the + /// given attribute. If the argument is found, its value is validated to ensure it is + /// a static string literal. + /// + /// - Parameters: + /// - attribute: The attribute syntax to analyze. + /// - context: The macro expansion context in which the operation is performed. + /// - Returns: The static string literal value of the `inheritedType` argument, + /// or `nil` if the argument is missing or invalid. + /// - Note: Diagnoses an error if the argument value is not a static string literal. + func extractInheritedType( + from attribute: AttributeSyntax, + in context: some MacroExpansionContext + ) -> String? { + guard case let .argumentList(argumentList) = attribute.arguments else { + // No arguments are present in the attribute. + return nil + } + + let inheritedTypeArgument = argumentList.first { argument in + argument.label?.text == "inheritedType" + } + + guard let inheritedTypeArgument else { + // The `inheritedType` argument is missing. + return nil + } + + // Check if it's a string literal expression + let segments = inheritedTypeArgument.expression + .as(StringLiteralExprSyntax.self)? + .segments + + guard let segments, + segments.count == 1, + case let .stringSegment(literalSegment)? = segments.first + else { + // The `inheritedType` argument's value is not a valid string literal. + context.diagnose( + Diagnostic( + node: attribute, + message: SpyableDiagnostic.inheritedTypeArgumentRequiresStaticStringLiteral, + highlights: [Syntax(inheritedTypeArgument.expression)] + ) + ) + return nil + } + + return literalSegment.content.text + } } extension TokenSyntax { diff --git a/Sources/SpyableMacro/Factories/SpyFactory.swift b/Sources/SpyableMacro/Factories/SpyFactory.swift index 2da670b..c05c2f2 100644 --- a/Sources/SpyableMacro/Factories/SpyFactory.swift +++ b/Sources/SpyableMacro/Factories/SpyFactory.swift @@ -92,7 +92,10 @@ struct SpyFactory { private let closureFactory = ClosureFactory() private let functionImplementationFactory = FunctionImplementationFactory() - func classDeclaration(for protocolDeclaration: ProtocolDeclSyntax) throws -> ClassDeclSyntax { + func classDeclaration( + for protocolDeclaration: ProtocolDeclSyntax, + inheritedType: String? = nil + ) throws -> ClassDeclSyntax { let identifier = TokenSyntax.identifier(protocolDeclaration.name.text + "Spy") let assosciatedtypeDeclarations = protocolDeclaration.memberBlock.members.compactMap { @@ -117,6 +120,14 @@ struct SpyFactory { name: identifier, genericParameterClause: genericParameterClause, inheritanceClause: InheritanceClauseSyntax { + // Add inherited type first if present + if let inheritedType { + InheritedTypeSyntax( + type: TypeSyntax(stringLiteral: inheritedType) + ) + } + + // Add the main protocol InheritedTypeSyntax( type: TypeSyntax(stringLiteral: protocolDeclaration.name.text) ) @@ -125,7 +136,10 @@ struct SpyFactory { ) }, memberBlockBuilder: { + let initOverrideKeyword: DeclModifierListSyntax = inheritedType != nil ? [DeclModifierSyntax(name: .keyword(.override))] : [] + InitializerDeclSyntax( + modifiers: initOverrideKeyword, signature: FunctionSignatureSyntax( parameterClause: FunctionParameterClauseSyntax(parameters: []) ), diff --git a/Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift b/Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift index b2552f8..32ebd8d 100644 --- a/Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift +++ b/Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift @@ -17,8 +17,19 @@ final class AccessLevelModifierRewriter: SyntaxRewriter { return node } - return DeclModifierListSyntax { - newAccessLevel + // Always preserve existing modifiers (like override, convenience, etc.) + var modifiers = Array(node) + + // Special case: if accessLevel is open and this is an initializer, use public instead + if newAccessLevel.name.text == TokenSyntax.keyword(.open).text, + let parent = node.parent, + parent.is(InitializerDeclSyntax.self) { + modifiers.append(DeclModifierSyntax(name: .keyword(.public))) + } else { + // Add the access level modifier for all other cases + modifiers.append(newAccessLevel) } + + return DeclModifierListSyntax(modifiers) } } diff --git a/Sources/SpyableMacro/Macro/SpyableMacro.swift b/Sources/SpyableMacro/Macro/SpyableMacro.swift index bf4a083..17667cb 100644 --- a/Sources/SpyableMacro/Macro/SpyableMacro.swift +++ b/Sources/SpyableMacro/Macro/SpyableMacro.swift @@ -13,8 +13,14 @@ public enum SpyableMacro: PeerMacro { // Extract the protocol declaration let protocolDeclaration = try extractor.extractProtocolDeclaration(from: declaration) - // Generate the initial spy class declaration - var spyClassDeclaration = try spyFactory.classDeclaration(for: protocolDeclaration) + // Extract inherited type from the attribute + let inheritedType = extractor.extractInheritedType(from: node, in: context) + + // Generate the initial spy class declaration with inherited type + var spyClassDeclaration = try spyFactory.classDeclaration( + for: protocolDeclaration, + inheritedType: inheritedType + ) // Apply access level modifiers if needed if let accessLevel = determineAccessLevel( diff --git a/Tests/SpyableMacroTests/Extractors/UT_Extractor.swift b/Tests/SpyableMacroTests/Extractors/UT_Extractor.swift index 179f444..62bd6f0 100644 --- a/Tests/SpyableMacroTests/Extractors/UT_Extractor.swift +++ b/Tests/SpyableMacroTests/Extractors/UT_Extractor.swift @@ -1,9 +1,18 @@ import SwiftSyntax +import SwiftSyntaxMacros +import SwiftDiagnostics import XCTest @testable import SpyableMacro final class UT_Extractor: XCTestCase { + private var mockContext: MockMacroExpansionContext! + + override func setUp() { + super.setUp() + mockContext = MockMacroExpansionContext() + } + func testExtractProtocolDeclarationSuccessfully() throws { let declaration = DeclSyntax( """ @@ -29,4 +38,138 @@ final class UT_Extractor: XCTestCase { let unwrappedReceivedError = try XCTUnwrap(receivedError as? SpyableDiagnostic) XCTAssertEqual(unwrappedReceivedError, .onlyApplicableToProtocol) } + + // MARK: - extractInheritedType Tests + + func test_extractInheritedType_withValidStringLiteral_returnsValue() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable(inheritedType: "BaseClass") + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertEqual(result, "BaseClass") + XCTAssertTrue(mockContext.diagnostics.isEmpty) + } + + func test_extractInheritedType_withNoArguments_returnsNil() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertNil(result) + XCTAssertTrue(mockContext.diagnostics.isEmpty) + } + + func test_extractInheritedType_withMissingInheritedTypeArgument_returnsNil() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable(accessLevel: .public) + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertNil(result) + XCTAssertTrue(mockContext.diagnostics.isEmpty) + } + + func test_extractInheritedType_withNonStringLiteral_returnsNilAndDiagnoses() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable(inheritedType: someVariable) + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertNil(result) + XCTAssertEqual(mockContext.diagnostics.count, 1) + XCTAssertEqual( + mockContext.diagnostics.first?.message, + SpyableDiagnostic.inheritedTypeArgumentRequiresStaticStringLiteral.message + ) + } + + func test_extractInheritedType_withEmptyString_returnsEmptyString() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable(inheritedType: "") + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertEqual(result, "") + XCTAssertTrue(mockContext.diagnostics.isEmpty) + } + + func test_extractInheritedType_withComplexClassName_returnsValue() { + // Given + let attribute = AttributeSyntax( + """ + @Spyable(inheritedType: "MyModule.BaseClass") + """ + ) + + // When + let result = Extractor().extractInheritedType(from: attribute, in: mockContext) + + // Then + XCTAssertEqual(result, "MyModule.BaseClass") + XCTAssertTrue(mockContext.diagnostics.isEmpty) + } +} + +// MARK: - Mock Context + +private class MockMacroExpansionContext: MacroExpansionContext { + var diagnostics: [Diagnostic] = [] + + func diagnose(_ diagnostic: Diagnostic) { + diagnostics.append(diagnostic) + } + + func location( + of node: Node, + at position: PositionInSyntaxNode, + filePathMode: SourceLocationFilePathMode + ) -> AbstractSourceLocation? { + return nil + } + + func location( + for token: TokenSyntax, + at position: PositionInSyntaxNode, + filePathMode: SourceLocationFilePathMode + ) -> AbstractSourceLocation? { + return nil + } + + var lexicalContext: [Syntax] = [] + + func makeUniqueName(_ providedName: String) -> TokenSyntax { + return TokenSyntax.identifier(providedName) + } } diff --git a/Tests/SpyableMacroTests/Macro/UT_AccessLevelModifierRewriter.swift b/Tests/SpyableMacroTests/Macro/UT_AccessLevelModifierRewriter.swift new file mode 100644 index 0000000..30c5f62 --- /dev/null +++ b/Tests/SpyableMacroTests/Macro/UT_AccessLevelModifierRewriter.swift @@ -0,0 +1,261 @@ +import SwiftSyntax +import SwiftSyntaxBuilder +import XCTest + +@testable import SpyableMacro + +final class UT_AccessLevelModifierRewriter: XCTestCase { + + // MARK: - Normal Access Level Rewriting Tests + + func test_visit_withPublicAccessLevel_shouldAddPublicModifier() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.public)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalFunction = FunctionDeclSyntax( + modifiers: [], + name: .identifier("testFunction"), + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenFunction = rewriter.visit(originalFunction).as(FunctionDeclSyntax.self)! + + // Then + let modifiers = rewrittenFunction.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "public") + } + + func test_visit_withInternalAccessLevel_shouldAddInternalModifier() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.internal)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalVariable = VariableDeclSyntax( + modifiers: [], + bindingSpecifier: .keyword(.var) + ) { + PatternBindingSyntax( + pattern: IdentifierPatternSyntax(identifier: .identifier("testVar")), + typeAnnotation: TypeAnnotationSyntax(type: TypeSyntax("String")) + ) + } + + // When + let rewrittenVariable = rewriter.visit(originalVariable).as(VariableDeclSyntax.self)! + + // Then + let modifiers = rewrittenVariable.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "internal") + } + + func test_visit_withPrivateAccessLevel_shouldConvertToFileprivate() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.private)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalFunction = FunctionDeclSyntax( + modifiers: [], + name: .identifier("testFunction"), + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenFunction = rewriter.visit(originalFunction).as(FunctionDeclSyntax.self)! + + // Then + let modifiers = rewrittenFunction.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "fileprivate") + } + + // MARK: - Initializer Special Case Tests + + func test_visit_withOpenAccessLevel_onInitializer_shouldUsePublicInstead() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.open)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalInit = InitializerDeclSyntax( + modifiers: [], + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenInit = rewriter.visit(originalInit).as(InitializerDeclSyntax.self)! + + // Then + let modifiers = rewrittenInit.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "public") + } + + func test_visit_withOpenAccessLevel_onInitializerWithOverride_shouldPreserveOverrideAndAddPublic() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.open)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalInit = InitializerDeclSyntax( + modifiers: [DeclModifierSyntax(name: .keyword(.override))], + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenInit = rewriter.visit(originalInit).as(InitializerDeclSyntax.self)! + + // Then + let modifiers = rewrittenInit.modifiers + XCTAssertEqual(modifiers.count, 2) + + let modifierTexts = modifiers.map { $0.name.text } + XCTAssertTrue(modifierTexts.contains("override")) + XCTAssertTrue(modifierTexts.contains("public")) + } + + func test_visit_withOpenAccessLevel_onInitializerWithConvenience_shouldPreserveConvenienceAndAddPublic() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.open)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalInit = InitializerDeclSyntax( + modifiers: [DeclModifierSyntax(name: .keyword(.convenience))], + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenInit = rewriter.visit(originalInit).as(InitializerDeclSyntax.self)! + + // Then + let modifiers = rewrittenInit.modifiers + XCTAssertEqual(modifiers.count, 2) + + let modifierTexts = modifiers.map { $0.name.text } + XCTAssertTrue(modifierTexts.contains("convenience")) + XCTAssertTrue(modifierTexts.contains("public")) + } + + func test_visit_withOpenAccessLevel_onNonInitializer_shouldUseOpen() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.open)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalFunction = FunctionDeclSyntax( + modifiers: [], + name: .identifier("testFunction"), + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenFunction = rewriter.visit(originalFunction).as(FunctionDeclSyntax.self)! + + // Then + let modifiers = rewrittenFunction.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "open") + } + + // MARK: - Preserve Existing Modifiers Tests + + func test_visit_withExistingModifiers_shouldPreserveAndAddAccessLevel() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.public)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalFunction = FunctionDeclSyntax( + modifiers: [ + DeclModifierSyntax(name: .keyword(.static)), + DeclModifierSyntax(name: .keyword(.final)) + ], + name: .identifier("testFunction"), + signature: FunctionSignatureSyntax( + parameterClause: FunctionParameterClauseSyntax(parameters: []) + ) + ) {} + + // When + let rewrittenFunction = rewriter.visit(originalFunction).as(FunctionDeclSyntax.self)! + + // Then + let modifiers = rewrittenFunction.modifiers + XCTAssertEqual(modifiers.count, 3) + + let modifierTexts = modifiers.map { $0.name.text } + XCTAssertTrue(modifierTexts.contains("static")) + XCTAssertTrue(modifierTexts.contains("final")) + XCTAssertTrue(modifierTexts.contains("public")) + } + + // MARK: - Function Parameter Tests + + func test_visit_onFunctionParameter_shouldNotModify() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.public)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let parameter = FunctionParameterSyntax( + modifiers: [], + firstName: .identifier("param"), + type: TypeSyntax("String") + ) + + // When + let rewrittenParameter = rewriter.visit(parameter).as(FunctionParameterSyntax.self)! + + // Then + XCTAssertEqual(rewrittenParameter.modifiers.count, 0) + } + + // MARK: - Edge Cases Tests + + func test_visit_withEmptyModifierList_shouldAddAccessLevel() { + // Given + let accessLevel = DeclModifierSyntax(name: .keyword(.internal)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + let originalClass = ClassDeclSyntax( + modifiers: [], + name: .identifier("TestClass") + ) {} + + // When + let rewrittenClass = rewriter.visit(originalClass).as(ClassDeclSyntax.self)! + + // Then + let modifiers = rewrittenClass.modifiers + XCTAssertEqual(modifiers.count, 1) + XCTAssertEqual(modifiers.first?.name.text, "internal") + } + + func test_init_withPrivateAccessLevel_shouldConvertToFileprivateInInit() { + // Given & When + let accessLevel = DeclModifierSyntax(name: .keyword(.private)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + // Then + XCTAssertEqual(rewriter.newAccessLevel.name.text, "fileprivate") + } + + func test_init_withNonPrivateAccessLevel_shouldKeepOriginal() { + // Given & When + let accessLevel = DeclModifierSyntax(name: .keyword(.public)) + let rewriter = AccessLevelModifierRewriter(newAccessLevel: accessLevel) + + // Then + XCTAssertEqual(rewriter.newAccessLevel.name.text, "public") + } +} diff --git a/Tests/SpyableMacroTests/Macro/UT_SpyableMacro+Inheritance.swift b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro+Inheritance.swift new file mode 100644 index 0000000..c5c9f9d --- /dev/null +++ b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro+Inheritance.swift @@ -0,0 +1,291 @@ +import SwiftSyntaxMacros +import SwiftSyntaxMacrosTestSupport +import XCTest + +@testable import SpyableMacro + +final class UT_SpyableMacroInheritance: XCTestCase { + private let sut = ["Spyable": SpyableMacro.self] + + func testMacroWithOpenAccessLevel() { + let protocolDeclaration = """ + protocol ServiceProtocol { + var removed: (() -> Void)? { get set } + + func fetchUsername(context: String, completion: @escaping (String) -> Void) + } + """ + + assertMacroExpansion( + """ + @Spyable(accessLevel: .open) + \(protocolDeclaration) + """, + expandedSource: """ + + \(protocolDeclaration) + + open class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + public init() { + } + open + var removed: (() -> Void)? + open var fetchUsernameContextCompletionCallsCount = 0 + open var fetchUsernameContextCompletionCalled: Bool { + return fetchUsernameContextCompletionCallsCount > 0 + } + open var fetchUsernameContextCompletionReceivedArguments: (context: String, completion: (String) -> Void)? + open var fetchUsernameContextCompletionReceivedInvocations: [(context: String, completion: (String) -> Void)] = [] + open var fetchUsernameContextCompletionClosure: ((String, @escaping (String) -> Void) -> Void)? + open + + func fetchUsername(context: String, completion: @escaping (String) -> Void) { + fetchUsernameContextCompletionCallsCount += 1 + fetchUsernameContextCompletionReceivedArguments = (context, completion) + fetchUsernameContextCompletionReceivedInvocations.append((context, completion)) + fetchUsernameContextCompletionClosure?(context, completion) + } + } + """, + macros: sut + ) + } + + func testMacroWithOpenAccessLevelAndInheritedType() { + let protocolDeclaration = """ + protocol ServiceProtocol { + var removed: (() -> Void)? { get set } + + func fetchUsername(context: String, completion: @escaping (String) -> Void) + } + """ + + assertMacroExpansion( + """ + @Spyable(accessLevel: .open, inheritedType: "BaseServiceSpy") + \(protocolDeclaration) + """, + expandedSource: """ + + \(protocolDeclaration) + + open class ServiceProtocolSpy: BaseServiceSpy, ServiceProtocol, @unchecked Sendable { + override public init() { + } + open + var removed: (() -> Void)? + open var fetchUsernameContextCompletionCallsCount = 0 + open var fetchUsernameContextCompletionCalled: Bool { + return fetchUsernameContextCompletionCallsCount > 0 + } + open var fetchUsernameContextCompletionReceivedArguments: (context: String, completion: (String) -> Void)? + open var fetchUsernameContextCompletionReceivedInvocations: [(context: String, completion: (String) -> Void)] = [] + open var fetchUsernameContextCompletionClosure: ((String, @escaping (String) -> Void) -> Void)? + open + + func fetchUsername(context: String, completion: @escaping (String) -> Void) { + fetchUsernameContextCompletionCallsCount += 1 + fetchUsernameContextCompletionReceivedArguments = (context, completion) + fetchUsernameContextCompletionReceivedInvocations.append((context, completion)) + fetchUsernameContextCompletionClosure?(context, completion) + } + } + """, + macros: sut + ) + } + + func testMacroWithOpenAccessLevelAndBehindPreprocessorFlag() { + let protocolDeclaration = """ + protocol ServiceProtocol { + func doSomething() + } + """ + + assertMacroExpansion( + """ + @Spyable(accessLevel: .open, behindPreprocessorFlag: "DEBUG") + \(protocolDeclaration) + """, + expandedSource: """ + + \(protocolDeclaration) + + #if DEBUG + open class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + public init() { + } + open var doSomethingCallsCount = 0 + open var doSomethingCalled: Bool { + return doSomethingCallsCount > 0 + } + open var doSomethingClosure: (() -> Void)? + open + func doSomething() { + doSomethingCallsCount += 1 + doSomethingClosure?() + } + } + #endif + """, + macros: sut + ) + } + + func testMacroWithOpenAccessLevelComplexProtocol() { + let protocolDeclaration = """ + protocol ComplexProtocol { + var name: String { get set } + func initialize(name: String, secondName: String?) + func fetchConfig() async throws -> [String: String] + } + """ + + assertMacroExpansion( + """ + @Spyable(accessLevel: .open) + \(protocolDeclaration) + """, + expandedSource: """ + + \(protocolDeclaration) + + open class ComplexProtocolSpy: ComplexProtocol, @unchecked Sendable { + public init() { + } + open var name: String { + get { + underlyingName + } + set { + underlyingName = newValue + } + } + open var underlyingName: (String)! + open var initializeNameSecondNameCallsCount = 0 + open var initializeNameSecondNameCalled: Bool { + return initializeNameSecondNameCallsCount > 0 + } + open var initializeNameSecondNameReceivedArguments: (name: String, secondName: String?)? + open var initializeNameSecondNameReceivedInvocations: [(name: String, secondName: String?)] = [] + open var initializeNameSecondNameClosure: ((String, String?) -> Void)? + open + func initialize(name: String, secondName: String?) { + initializeNameSecondNameCallsCount += 1 + initializeNameSecondNameReceivedArguments = (name, secondName) + initializeNameSecondNameReceivedInvocations.append((name, secondName)) + initializeNameSecondNameClosure?(name, secondName) + } + open var fetchConfigCallsCount = 0 + open var fetchConfigCalled: Bool { + return fetchConfigCallsCount > 0 + } + open var fetchConfigThrowableError: (any Error)? + open var fetchConfigReturnValue: [String: String]! + open var fetchConfigClosure: (() async throws -> [String: String])? + open + func fetchConfig() async throws -> [String: String] { + fetchConfigCallsCount += 1 + if let fetchConfigThrowableError { + throw fetchConfigThrowableError + } + if fetchConfigClosure != nil { + return try await fetchConfigClosure!() + } else { + return fetchConfigReturnValue + } + } + } + """, + macros: sut + ) + } + + func testOpenAccessLevelFromProtocolDeclaration() { + // Test that open access level is NOT automatically detected from protocol declaration + let protocolDefinition = """ + open protocol ServiceProtocol { + var removed: (() -> Void)? { get set } + + func fetchUsername(context: String, completion: @escaping (String) -> Void) + } + """ + + assertMacroExpansion( + """ + @Spyable + \(protocolDefinition) + """, + expandedSource: """ + + \(protocolDefinition) + + class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + init() { + } + var removed: (() -> Void)? + var fetchUsernameContextCompletionCallsCount = 0 + var fetchUsernameContextCompletionCalled: Bool { + return fetchUsernameContextCompletionCallsCount > 0 + } + var fetchUsernameContextCompletionReceivedArguments: (context: String, completion: (String) -> Void)? + var fetchUsernameContextCompletionReceivedInvocations: [(context: String, completion: (String) -> Void)] = [] + var fetchUsernameContextCompletionClosure: ((String, @escaping (String) -> Void) -> Void)? + + func fetchUsername(context: String, completion: @escaping (String) -> Void) { + fetchUsernameContextCompletionCallsCount += 1 + fetchUsernameContextCompletionReceivedArguments = (context, completion) + fetchUsernameContextCompletionReceivedInvocations.append((context, completion)) + fetchUsernameContextCompletionClosure?(context, completion) + } + } + """, + macros: sut + ) + } + + func testMacroWithOpenAccessLevelOverridingProtocolLevel() { + // Test that accessLevel: .open argument overrides the protocol's access level + let protocolDefinition = """ + internal protocol ServiceProtocol { + var removed: (() -> Void)? { get set } + + func fetchUsername(context: String, completion: @escaping (String) -> Void) + } + """ + + assertMacroExpansion( + """ + @Spyable(accessLevel: .open) + \(protocolDefinition) + """, + expandedSource: """ + + \(protocolDefinition) + + open class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + public init() { + } + open + var removed: (() -> Void)? + open var fetchUsernameContextCompletionCallsCount = 0 + open var fetchUsernameContextCompletionCalled: Bool { + return fetchUsernameContextCompletionCallsCount > 0 + } + open var fetchUsernameContextCompletionReceivedArguments: (context: String, completion: (String) -> Void)? + open var fetchUsernameContextCompletionReceivedInvocations: [(context: String, completion: (String) -> Void)] = [] + open var fetchUsernameContextCompletionClosure: ((String, @escaping (String) -> Void) -> Void)? + open + + func fetchUsername(context: String, completion: @escaping (String) -> Void) { + fetchUsernameContextCompletionCallsCount += 1 + fetchUsernameContextCompletionReceivedArguments = (context, completion) + fetchUsernameContextCompletionReceivedInvocations.append((context, completion)) + fetchUsernameContextCompletionClosure?(context, completion) + } + } + """, + macros: sut + ) + } +}