Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Sources/Spyable/Spyable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@
/// - The generated spy class name is suffixed with `Spy` (e.g., `ServiceProtocolSpy`).
///
@attached(peer, names: suffixed(Spy))
public macro Spyable(behindPreprocessorFlag: String? = nil, accessLevel: SpyAccessLevel? = nil) =
public macro Spyable(
behindPreprocessorFlag: String? = nil,
accessLevel: SpyAccessLevel? = nil,
inheritedType: String? = nil
) =
#externalMacro(
module: "SpyableMacro",
type: "SpyableMacro"
Expand Down
6 changes: 5 additions & 1 deletion Sources/SpyableMacro/Diagnostics/SpyableDiagnostic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error {
case behindPreprocessorFlagArgumentRequiresStaticStringLiteral
case accessLevelArgumentRequiresMemberAccessExpression
case accessLevelArgumentUnsupportedAccessLevel
case inheritedTypeArgumentRequiresStaticStringLiteral

/// Provides a human-readable diagnostic message for each diagnostic case.
var message: String {
Expand All @@ -30,6 +31,8 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error {
"The `accessLevel` argument requires a member access expression"
case .accessLevelArgumentUnsupportedAccessLevel:
"The `accessLevel` argument does not support the specified access level"
case .inheritedTypeArgumentRequiresStaticStringLiteral:
"The `inheritedType` argument requires a static string literal"
}
}

Expand All @@ -41,7 +44,8 @@ enum SpyableDiagnostic: String, DiagnosticMessage, Error {
.variableDeclInProtocolWithNotIdentifierPattern,
.behindPreprocessorFlagArgumentRequiresStaticStringLiteral,
.accessLevelArgumentRequiresMemberAccessExpression,
.accessLevelArgumentUnsupportedAccessLevel:
.accessLevelArgumentUnsupportedAccessLevel,
.inheritedTypeArgumentRequiresStaticStringLiteral:
.error
}
}
Expand Down
56 changes: 56 additions & 0 deletions Sources/SpyableMacro/Extractors/Extractor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ struct Extractor {
let accessLevelText = memberAccess.declName.baseName.text

switch accessLevelText {
case "open":
return DeclModifierSyntax(name: .keyword(.open))

case "public":
return DeclModifierSyntax(name: .keyword(.public))

Expand Down Expand Up @@ -145,6 +148,59 @@ struct Extractor {
func extractAccessLevel(from protocolDeclSyntax: ProtocolDeclSyntax) -> DeclModifierSyntax? {
protocolDeclSyntax.modifiers.first(where: \.name.isAccessLevelSupportedInProtocol)
}

/// Extracts an inherited type value from an attribute if present and valid.
///
/// This method searches for an argument labeled `inheritedType` within the
/// given attribute. If the argument is found, its value is validated to ensure it is
/// a static string literal.
///
/// - Parameters:
/// - attribute: The attribute syntax to analyze.
/// - context: The macro expansion context in which the operation is performed.
/// - Returns: The static string literal value of the `inheritedType` argument,
/// or `nil` if the argument is missing or invalid.
/// - Note: Diagnoses an error if the argument value is not a static string literal.
func extractInheritedType(
from attribute: AttributeSyntax,
in context: some MacroExpansionContext
) -> String? {
guard case let .argumentList(argumentList) = attribute.arguments else {
// No arguments are present in the attribute.
return nil
}

let inheritedTypeArgument = argumentList.first { argument in
argument.label?.text == "inheritedType"
}

guard let inheritedTypeArgument else {
// The `inheritedType` argument is missing.
return nil
}

// Check if it's a string literal expression
let segments = inheritedTypeArgument.expression
.as(StringLiteralExprSyntax.self)?
.segments

guard let segments,
segments.count == 1,
case let .stringSegment(literalSegment)? = segments.first
else {
// The `inheritedType` argument's value is not a valid string literal.
context.diagnose(
Diagnostic(
node: attribute,
message: SpyableDiagnostic.inheritedTypeArgumentRequiresStaticStringLiteral,
highlights: [Syntax(inheritedTypeArgument.expression)]
)
)
return nil
}

return literalSegment.content.text
}
}

extension TokenSyntax {
Expand Down
16 changes: 15 additions & 1 deletion Sources/SpyableMacro/Factories/SpyFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ struct SpyFactory {
private let closureFactory = ClosureFactory()
private let functionImplementationFactory = FunctionImplementationFactory()

func classDeclaration(for protocolDeclaration: ProtocolDeclSyntax) throws -> ClassDeclSyntax {
func classDeclaration(
for protocolDeclaration: ProtocolDeclSyntax,
inheritedType: String? = nil
) throws -> ClassDeclSyntax {
let identifier = TokenSyntax.identifier(protocolDeclaration.name.text + "Spy")

let assosciatedtypeDeclarations = protocolDeclaration.memberBlock.members.compactMap {
Expand All @@ -117,6 +120,14 @@ struct SpyFactory {
name: identifier,
genericParameterClause: genericParameterClause,
inheritanceClause: InheritanceClauseSyntax {
// Add inherited type first if present
if let inheritedType {
InheritedTypeSyntax(
type: TypeSyntax(stringLiteral: inheritedType)
)
}

// Add the main protocol
InheritedTypeSyntax(
type: TypeSyntax(stringLiteral: protocolDeclaration.name.text)
)
Expand All @@ -125,7 +136,10 @@ struct SpyFactory {
)
},
memberBlockBuilder: {
let initOverrideKeyword: DeclModifierListSyntax = inheritedType != nil ? [DeclModifierSyntax(name: .keyword(.override))] : []

InitializerDeclSyntax(
modifiers: initOverrideKeyword,
signature: FunctionSignatureSyntax(
parameterClause: FunctionParameterClauseSyntax(parameters: [])
),
Expand Down
15 changes: 13 additions & 2 deletions Sources/SpyableMacro/Macro/AccessLevelModifierRewriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,19 @@ final class AccessLevelModifierRewriter: SyntaxRewriter {
return node
}

return DeclModifierListSyntax {
newAccessLevel
// Always preserve existing modifiers (like override, convenience, etc.)
var modifiers = Array(node)

// Special case: if accessLevel is open and this is an initializer, use public instead
if newAccessLevel.name.text == TokenSyntax.keyword(.open).text,
let parent = node.parent,
parent.is(InitializerDeclSyntax.self) {
modifiers.append(DeclModifierSyntax(name: .keyword(.public)))
} else {
// Add the access level modifier for all other cases
modifiers.append(newAccessLevel)
}

return DeclModifierListSyntax(modifiers)
}
}
10 changes: 8 additions & 2 deletions Sources/SpyableMacro/Macro/SpyableMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ public enum SpyableMacro: PeerMacro {
// Extract the protocol declaration
let protocolDeclaration = try extractor.extractProtocolDeclaration(from: declaration)

// Generate the initial spy class declaration
var spyClassDeclaration = try spyFactory.classDeclaration(for: protocolDeclaration)
// Extract inherited type from the attribute
let inheritedType = extractor.extractInheritedType(from: node, in: context)

// Generate the initial spy class declaration with inherited type
var spyClassDeclaration = try spyFactory.classDeclaration(
for: protocolDeclaration,
inheritedType: inheritedType
)

// Apply access level modifiers if needed
if let accessLevel = determineAccessLevel(
Expand Down
143 changes: 143 additions & 0 deletions Tests/SpyableMacroTests/Extractors/UT_Extractor.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import SwiftSyntax
import SwiftSyntaxMacros
import SwiftDiagnostics
import XCTest

@testable import SpyableMacro

final class UT_Extractor: XCTestCase {
private var mockContext: MockMacroExpansionContext!

override func setUp() {
super.setUp()
mockContext = MockMacroExpansionContext()
}

func testExtractProtocolDeclarationSuccessfully() throws {
let declaration = DeclSyntax(
"""
Expand All @@ -29,4 +38,138 @@ final class UT_Extractor: XCTestCase {
let unwrappedReceivedError = try XCTUnwrap(receivedError as? SpyableDiagnostic)
XCTAssertEqual(unwrappedReceivedError, .onlyApplicableToProtocol)
}

// MARK: - extractInheritedType Tests

func test_extractInheritedType_withValidStringLiteral_returnsValue() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable(inheritedType: "BaseClass")
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertEqual(result, "BaseClass")
XCTAssertTrue(mockContext.diagnostics.isEmpty)
}

func test_extractInheritedType_withNoArguments_returnsNil() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertNil(result)
XCTAssertTrue(mockContext.diagnostics.isEmpty)
}

func test_extractInheritedType_withMissingInheritedTypeArgument_returnsNil() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable(accessLevel: .public)
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertNil(result)
XCTAssertTrue(mockContext.diagnostics.isEmpty)
}

func test_extractInheritedType_withNonStringLiteral_returnsNilAndDiagnoses() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable(inheritedType: someVariable)
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertNil(result)
XCTAssertEqual(mockContext.diagnostics.count, 1)
XCTAssertEqual(
mockContext.diagnostics.first?.message,
SpyableDiagnostic.inheritedTypeArgumentRequiresStaticStringLiteral.message
)
}

func test_extractInheritedType_withEmptyString_returnsEmptyString() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable(inheritedType: "")
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertEqual(result, "")
XCTAssertTrue(mockContext.diagnostics.isEmpty)
}

func test_extractInheritedType_withComplexClassName_returnsValue() {
// Given
let attribute = AttributeSyntax(
"""
@Spyable(inheritedType: "MyModule.BaseClass<T>")
"""
)

// When
let result = Extractor().extractInheritedType(from: attribute, in: mockContext)

// Then
XCTAssertEqual(result, "MyModule.BaseClass<T>")
XCTAssertTrue(mockContext.diagnostics.isEmpty)
}
}

// MARK: - Mock Context

private class MockMacroExpansionContext: MacroExpansionContext {
var diagnostics: [Diagnostic] = []

func diagnose(_ diagnostic: Diagnostic) {
diagnostics.append(diagnostic)
}

func location<Node: SyntaxProtocol>(
of node: Node,
at position: PositionInSyntaxNode,
filePathMode: SourceLocationFilePathMode
) -> AbstractSourceLocation? {
return nil
}

func location(
for token: TokenSyntax,
at position: PositionInSyntaxNode,
filePathMode: SourceLocationFilePathMode
) -> AbstractSourceLocation? {
return nil
}

var lexicalContext: [Syntax] = []

func makeUniqueName(_ providedName: String) -> TokenSyntax {
return TokenSyntax.identifier(providedName)
}
}
Loading