Skip to content

Use generated classname for writing aot repository content #3345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-commons</artifactId>
<version>4.0.0-SNAPSHOT</version>
<version>4.0.x-GH-3339-SNAPSHOT</version>

<name>Spring Data Core</name>
<description>Core Spring concepts underpinning every Spring Data module.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.Generated;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument;
Expand All @@ -39,7 +38,6 @@
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.FieldSpec;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
Expand All @@ -63,7 +61,9 @@ class AotRepositoryBuilder {

private @Nullable Consumer<AotRepositoryConstructorBuilder> constructorCustomizer;
private @Nullable MethodContributorFactory methodContributorFactory;
private @Nullable String targetClassName;
private Consumer<AotRepositoryClassBuilder> classCustomizer;
private final RepositoryConstructorBuilder constructorBuilder;

private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String moduleName,
ProjectionFactory projectionFactory) {
Expand All @@ -72,13 +72,9 @@ private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String
this.moduleName = moduleName;
this.projectionFactory = projectionFactory;

this.generationMetadata = new AotRepositoryFragmentMetadata(className());
this.generationMetadata.addField(FieldSpec
.builder(TypeName.get(Log.class), "logger", Modifier.PRIVATE, Modifier.STATIC, Modifier.FINAL)
.initializer("$T.getLog($T.class)", TypeName.get(LogFactory.class), this.generationMetadata.getTargetTypeName())
.build());

this.generationMetadata = new AotRepositoryFragmentMetadata();
this.classCustomizer = (builder) -> {};
this.constructorBuilder = new RepositoryConstructorBuilder(generationMetadata);
}

/**
Expand Down Expand Up @@ -131,15 +127,24 @@ public AotRepositoryBuilder withQueryMethodContributor(MethodContributorFactory
return this;
}

public AotBundle build() {
/**
* Configure the {@link Class#getSimpleName() simple class name} of the generated repository implementation.
*
* @param className the class name to use for the generated repository implementation. Defaults to the simple
* {@link RepositoryInformation#getRepositoryInterface()} class name suffixed with {@code Impl}
* @return {@code this}.
*/
public AotRepositoryBuilder withClassName(@Nullable String className) {
this.targetClassName = className;
return this;
}

public AotBundle build(TypeSpec.Builder builder) {

List<AotRepositoryMethod> methodMetadata = new ArrayList<>();
RepositoryComposition repositoryComposition = repositoryInformation.getRepositoryComposition();

// start creating the type
TypeSpec.Builder builder = TypeSpec.classBuilder(this.generationMetadata.getTargetTypeName()) //
.addModifiers(Modifier.PUBLIC) //
.addAnnotation(Generated.class) //
builder.addModifiers(Modifier.PUBLIC) //
.addJavadoc("AOT generated $L repository implementation for {@link $T}.\n", moduleName,
repositoryInformation.getRepositoryInterface());

Expand Down Expand Up @@ -177,10 +182,15 @@ public AotBundle build() {
return new AotBundle(javaFile, metadata);
}

private MethodSpec buildConstructor() {
public AotBundle build() {
return build(TypeSpec.classBuilder(getClassName()).addAnnotation(Generated.class));
}

RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder(
generationMetadata);
public ClassName getClassName() {
return ClassName.get(packageName(), targetClassName != null ? targetClassName : typeName());
}

private MethodSpec buildConstructor() {

if (constructorCustomizer != null) {
constructorCustomizer.accept(constructorBuilder);
Expand Down Expand Up @@ -252,15 +262,11 @@ public AotRepositoryFragmentMetadata getGenerationMetadata() {
return generationMetadata;
}

private ClassName className() {
return new ClassNameGenerator(ClassName.get(packageName(), typeName())).generateClassName("Aot", null);
}

private String packageName() {
public String packageName() {
return repositoryInformation.getRepositoryInterface().getPackageName();
}

private String typeName() {
public String typeName() {
return "%sImpl".formatted(repositoryInformation.getRepositoryInterface().getSimpleName());
}

Expand All @@ -280,7 +286,6 @@ public ProjectionFactory getProjectionFactory() {
return projectionFactory;
}


/**
* Customizer interface to customize the AOT repository fragment constructor through
* {@link AotRepositoryConstructorBuilder}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@
*/
public class AotRepositoryFragmentMetadata {

private final ClassName className;
private final Map<String, FieldSpec> fields = new HashMap<>(3);
private final Map<String, ConstructorArgument> constructorArguments = new LinkedHashMap<>(3);

public AotRepositoryFragmentMetadata(ClassName className) {
this.className = className;
public AotRepositoryFragmentMetadata() {
}

/**
Expand All @@ -65,10 +63,6 @@ public String fieldNameOf(Class<?> type) {
return null;
}

public ClassName getTargetTypeName() {
return className;
}

/**
* Add a field to the repository fragment.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedTypeReference;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.TypeReference;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.aot.generate.AotRepositoryBuilder.AotBundle;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.TypeName;

/**
Expand All @@ -42,16 +44,19 @@
public class RepositoryContributor {

private static final Log logger = LogFactory.getLog(RepositoryContributor.class);
private static final String FEATURE_NAME = "AotRepository";

private final AotRepositoryBuilder builder;
private @Nullable TypeReference contributedTypeName;

/**
* Create a new {@code RepositoryContributor} for the given {@link AotRepositoryContext}.
*
* @param repositoryContext
*/
public RepositoryContributor(AotRepositoryContext repositoryContext) {
this.builder = AotRepositoryBuilder.forRepository(repositoryContext.getRepositoryInformation(),

builder = AotRepositoryBuilder.forRepository(repositoryContext.getRepositoryInformation(),
repositoryContext.getModuleName(), createProjectionFactory());
}

Expand All @@ -77,8 +82,8 @@ protected RepositoryInformation getRepositoryInformation() {
return builder.getRepositoryInformation();
}

public String getContributedTypeName() {
return builder.getGenerationMetadata().getTargetTypeName().toString();
public @Nullable TypeReference getContributedTypeName() {
return this.contributedTypeName;
}

public java.util.Map<String, TypeName> requiredArgs() {
Expand All @@ -87,44 +92,53 @@ public java.util.Map<String, TypeName> requiredArgs() {

public void contribute(GenerationContext generationContext) {

AotRepositoryBuilder.AotBundle aotBundle = builder.withClassCustomizer(this::customizeClass) //
builder.withClassCustomizer(this::customizeClass) //
.withConstructorCustomizer(this::customizeConstructor) //
.withQueryMethodContributor(this::contributeQueryMethod) //
.build();

Class<?> repositoryInterface = getRepositoryInformation().getRepositoryInterface();
String repositoryJsonFileName = getRepositoryJsonFileName(repositoryInterface);

JavaFile javaFile = aotBundle.javaFile();
String typeName = "%s.%s".formatted(javaFile.packageName(), javaFile.typeSpec().name());
String repositoryJson;

try {
repositoryJson = aotBundle.metadata().toJson().toString(2);
} catch (JSONException e) {
throw new RuntimeException(e);
}

if (logger.isTraceEnabled()) {
logger.trace("""
------ AOT Repository.json: %s ------
%s
-------------------
""".formatted(repositoryJsonFileName, repositoryJson));

logger.trace("""
------ AOT Generated Repository: %s ------
%s
-------------------
""".formatted(typeName, javaFile));
}

// generate the files
generationContext.getGeneratedFiles().addSourceFile(javaFile);
generationContext.getGeneratedFiles().addResourceFile(repositoryJsonFileName, repositoryJson);
.withQueryMethodContributor(this::contributeQueryMethod); //

// TODO: temporary fix until we have a better representation of constructor arguments
// decouple the description of arguments from the actual code used in the constructor initialization, super calls,
// etc.
RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder(builder.getGenerationMetadata());
customizeConstructor(constructorBuilder);

GeneratedClass generatedClass = generationContext.getGeneratedClasses().getOrAddForFeatureComponent(FEATURE_NAME,
builder.getClassName(), targetTypeSpec -> {

// capture the actual type name early on so that we can use it in the constructor.
builder.withClassName(targetTypeSpec.build().name());

AotBundle aotBundle = builder.build(targetTypeSpec);
Class<?> repositoryInterface = getRepositoryInformation().getRepositoryInterface();
String repositoryJsonFileName = getRepositoryJsonFileName(repositoryInterface);
String repositoryJson;
try {
repositoryJson = aotBundle.metadata().toJson().toString(2);
} catch (JSONException e) {
throw new RuntimeException(e);
}

if (logger.isTraceEnabled()) {
logger.trace("""
------ AOT Repository.json: %s ------
%s
-------------------
""".formatted(repositoryJsonFileName, repositoryJson));

logger.trace("""
------ AOT Generated Repository: %s ------
%s
-------------------
""".formatted(null, aotBundle.javaFile()));
}

generationContext.getGeneratedFiles().addResourceFile(repositoryJsonFileName, repositoryJson);
});

this.contributedTypeName = GeneratedTypeReference.of(generatedClass.getName());

// generate native runtime hints - needed cause we're using the repository proxy
generationContext.getRuntimeHints().reflection().registerType(TypeReference.of(typeName),
generationContext.getRuntimeHints().reflection().registerType(this.contributedTypeName,
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
Expand Down Expand Up @@ -51,6 +52,8 @@ public AotRepositoryBeanDefinitionPropertiesDecorator(Supplier<CodeBlock> inheri
*/
public CodeBlock decorate() {

Assert.notNull(repositoryContributor.getContributedTypeName(), "Contributed type name must not be null");

CodeBlock.Builder builder = CodeBlock.builder();
// bring in properties as usual
builder.add(inheritedProperties.get());
Expand Down Expand Up @@ -78,7 +81,7 @@ public CodeBlock decorate() {
}

builder.addStatement("return RepositoryComposition.RepositoryFragments.just(new $L($L))",
repositoryContributor.getContributedTypeName(),
repositoryContributor.getContributedTypeName().getCanonicalName(),
StringUtils.collectionToDelimitedString(repositoryContributor.requiredArgs().keySet(), ", "));
builder.unindent();
builder.add("}\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.aot.hint.TypeReference;
import org.springframework.data.geo.Metric;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.querydsl.QuerydslPredicateExecutor;
Expand Down Expand Up @@ -66,8 +67,8 @@ void writesClassSkeleton() {
assertThat(repoBuilder.build().javaFile().toString())
.contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo
.contains("@Generated") // marked as generated source
.contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name
.contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire
.contains("public class %sImpl".formatted(UserRepository.class.getSimpleName())) // target name
.contains("public UserRepositoryImpl"); // default constructor if not arguments to wire
}

@Test // GH-3279
Expand All @@ -80,11 +81,11 @@ void appliesCtorArguments() {
ctor.addParameter("param2", String.class);
ctor.addParameter("ctorScoped", TypeName.get(Object.class), false);
});
assertThat(repoBuilder.build().javaFile().toString()) //
assertThat(repoBuilder.withClassName(null).build().javaFile().toString()) //
.contains("private final Metric param1;") //
.contains("private final String param2;") //
.doesNotContain("private final Object ctorScoped;") //
.contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") //
.contains("public UserRepositoryImpl(Metric param1, String param2, Object ctorScoped)") //
.contains("this.param1 = param1") //
.contains("this.param2 = param2") //
.doesNotContain("this.ctorScoped = ctorScoped");
Expand All @@ -100,8 +101,9 @@ void appliesCtorCodeBlock() {
code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error");
});
});
repoBuilder.withClassName(null);
assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces(
"UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }");
"UserRepositoryImpl() { throw new IllegalStateException(\"initialization error\"); }");
}

@Test // GH-3279
Expand Down Expand Up @@ -180,6 +182,24 @@ void shouldContributeFragmentImplementationMetadata() {
assertThat(method.fragment().implementation()).isEqualTo(DummyQuerydslPredicateExecutor.class.getName());
}

@Test // GH-3339
void usesTargetTypeName() {

AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons",
new SpelAwareProxyProjectionFactory());
repoBuilder.withConstructorCustomizer(ctor -> {
ctor.addParameter("param1", Metric.class);
ctor.addParameter("param2", String.class);
ctor.addParameter("ctorScoped", TypeName.get(Object.class), false);
});

TypeReference targetType = TypeReference.of("%s__AotPostfix".formatted(UserRepository.class.getCanonicalName()));

assertThat(repoBuilder.withClassName(targetType.getSimpleName()).build().javaFile().toString()) //
.contains("class %s".formatted(targetType.getSimpleName())) //
.contains("public %s(Metric param1, String param2, Object ctorScoped)".formatted(targetType.getSimpleName()));
}

interface UserRepository extends org.springframework.data.repository.Repository<User, String> {

String someMethod();
Expand Down
Loading