diff --git a/pom.xml b/pom.xml index b1d17888e9..063e5eca96 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 4.0.0-SNAPSHOT + 4.0.x-GH-3339-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index d7c0c9dd96..b9f80b2202 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; -import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.Generated; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument; @@ -39,7 +38,6 @@ import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.FieldSpec; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.TypeName; @@ -63,7 +61,9 @@ class AotRepositoryBuilder { private @Nullable Consumer constructorCustomizer; private @Nullable MethodContributorFactory methodContributorFactory; + private @Nullable String targetClassName; private Consumer classCustomizer; + private final RepositoryConstructorBuilder constructorBuilder; private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String moduleName, ProjectionFactory projectionFactory) { @@ -72,13 +72,9 @@ private AotRepositoryBuilder(RepositoryInformation repositoryInformation, String this.moduleName = moduleName; this.projectionFactory = projectionFactory; - this.generationMetadata = new AotRepositoryFragmentMetadata(className()); - this.generationMetadata.addField(FieldSpec - .builder(TypeName.get(Log.class), "logger", Modifier.PRIVATE, Modifier.STATIC, Modifier.FINAL) - .initializer("$T.getLog($T.class)", TypeName.get(LogFactory.class), this.generationMetadata.getTargetTypeName()) - .build()); - + this.generationMetadata = new AotRepositoryFragmentMetadata(); this.classCustomizer = (builder) -> {}; + this.constructorBuilder = new RepositoryConstructorBuilder(generationMetadata); } /** @@ -131,15 +127,24 @@ public AotRepositoryBuilder withQueryMethodContributor(MethodContributorFactory return this; } - public AotBundle build() { + /** + * Configure the {@link Class#getSimpleName() simple class name} of the generated repository implementation. + * + * @param className the class name to use for the generated repository implementation. Defaults to the simple + * {@link RepositoryInformation#getRepositoryInterface()} class name suffixed with {@code Impl} + * @return {@code this}. + */ + public AotRepositoryBuilder withClassName(@Nullable String className) { + this.targetClassName = className; + return this; + } + + public AotBundle build(TypeSpec.Builder builder) { List methodMetadata = new ArrayList<>(); RepositoryComposition repositoryComposition = repositoryInformation.getRepositoryComposition(); - // start creating the type - TypeSpec.Builder builder = TypeSpec.classBuilder(this.generationMetadata.getTargetTypeName()) // - .addModifiers(Modifier.PUBLIC) // - .addAnnotation(Generated.class) // + builder.addModifiers(Modifier.PUBLIC) // .addJavadoc("AOT generated $L repository implementation for {@link $T}.\n", moduleName, repositoryInformation.getRepositoryInterface()); @@ -177,10 +182,15 @@ public AotBundle build() { return new AotBundle(javaFile, metadata); } - private MethodSpec buildConstructor() { + public AotBundle build() { + return build(TypeSpec.classBuilder(getClassName()).addAnnotation(Generated.class)); + } - RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder( - generationMetadata); + public ClassName getClassName() { + return ClassName.get(packageName(), targetClassName != null ? targetClassName : typeName()); + } + + private MethodSpec buildConstructor() { if (constructorCustomizer != null) { constructorCustomizer.accept(constructorBuilder); @@ -252,15 +262,11 @@ public AotRepositoryFragmentMetadata getGenerationMetadata() { return generationMetadata; } - private ClassName className() { - return new ClassNameGenerator(ClassName.get(packageName(), typeName())).generateClassName("Aot", null); - } - - private String packageName() { + public String packageName() { return repositoryInformation.getRepositoryInterface().getPackageName(); } - private String typeName() { + public String typeName() { return "%sImpl".formatted(repositoryInformation.getRepositoryInterface().getSimpleName()); } @@ -280,7 +286,6 @@ public ProjectionFactory getProjectionFactory() { return projectionFactory; } - /** * Customizer interface to customize the AOT repository fragment constructor through * {@link AotRepositoryConstructorBuilder}. diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java index c8df5e6c01..f6423463c7 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java @@ -37,12 +37,10 @@ */ public class AotRepositoryFragmentMetadata { - private final ClassName className; private final Map fields = new HashMap<>(3); private final Map constructorArguments = new LinkedHashMap<>(3); - public AotRepositoryFragmentMetadata(ClassName className) { - this.className = className; + public AotRepositoryFragmentMetadata() { } /** @@ -65,10 +63,6 @@ public String fieldNameOf(Class type) { return null; } - public ClassName getTargetTypeName() { - return className; - } - /** * Add a field to the repository fragment. * diff --git a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java index 7a27ef2018..7b0399b0f7 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java @@ -21,15 +21,17 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; +import org.springframework.aot.generate.GeneratedClass; +import org.springframework.aot.generate.GeneratedTypeReference; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.TypeReference; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.aot.generate.AotRepositoryBuilder.AotBundle; import org.springframework.data.repository.config.AotRepositoryContext; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.query.QueryMethod; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.TypeName; /** @@ -42,8 +44,10 @@ public class RepositoryContributor { private static final Log logger = LogFactory.getLog(RepositoryContributor.class); + private static final String FEATURE_NAME = "AotRepository"; private final AotRepositoryBuilder builder; + private @Nullable TypeReference contributedTypeName; /** * Create a new {@code RepositoryContributor} for the given {@link AotRepositoryContext}. @@ -51,7 +55,8 @@ public class RepositoryContributor { * @param repositoryContext */ public RepositoryContributor(AotRepositoryContext repositoryContext) { - this.builder = AotRepositoryBuilder.forRepository(repositoryContext.getRepositoryInformation(), + + builder = AotRepositoryBuilder.forRepository(repositoryContext.getRepositoryInformation(), repositoryContext.getModuleName(), createProjectionFactory()); } @@ -77,8 +82,8 @@ protected RepositoryInformation getRepositoryInformation() { return builder.getRepositoryInformation(); } - public String getContributedTypeName() { - return builder.getGenerationMetadata().getTargetTypeName().toString(); + public @Nullable TypeReference getContributedTypeName() { + return this.contributedTypeName; } public java.util.Map requiredArgs() { @@ -87,44 +92,53 @@ public java.util.Map requiredArgs() { public void contribute(GenerationContext generationContext) { - AotRepositoryBuilder.AotBundle aotBundle = builder.withClassCustomizer(this::customizeClass) // + builder.withClassCustomizer(this::customizeClass) // .withConstructorCustomizer(this::customizeConstructor) // - .withQueryMethodContributor(this::contributeQueryMethod) // - .build(); - - Class repositoryInterface = getRepositoryInformation().getRepositoryInterface(); - String repositoryJsonFileName = getRepositoryJsonFileName(repositoryInterface); - - JavaFile javaFile = aotBundle.javaFile(); - String typeName = "%s.%s".formatted(javaFile.packageName(), javaFile.typeSpec().name()); - String repositoryJson; - - try { - repositoryJson = aotBundle.metadata().toJson().toString(2); - } catch (JSONException e) { - throw new RuntimeException(e); - } - - if (logger.isTraceEnabled()) { - logger.trace(""" - ------ AOT Repository.json: %s ------ - %s - ------------------- - """.formatted(repositoryJsonFileName, repositoryJson)); - - logger.trace(""" - ------ AOT Generated Repository: %s ------ - %s - ------------------- - """.formatted(typeName, javaFile)); - } - - // generate the files - generationContext.getGeneratedFiles().addSourceFile(javaFile); - generationContext.getGeneratedFiles().addResourceFile(repositoryJsonFileName, repositoryJson); + .withQueryMethodContributor(this::contributeQueryMethod); // + + // TODO: temporary fix until we have a better representation of constructor arguments + // decouple the description of arguments from the actual code used in the constructor initialization, super calls, + // etc. + RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder(builder.getGenerationMetadata()); + customizeConstructor(constructorBuilder); + + GeneratedClass generatedClass = generationContext.getGeneratedClasses().getOrAddForFeatureComponent(FEATURE_NAME, + builder.getClassName(), targetTypeSpec -> { + + // capture the actual type name early on so that we can use it in the constructor. + builder.withClassName(targetTypeSpec.build().name()); + + AotBundle aotBundle = builder.build(targetTypeSpec); + Class repositoryInterface = getRepositoryInformation().getRepositoryInterface(); + String repositoryJsonFileName = getRepositoryJsonFileName(repositoryInterface); + String repositoryJson; + try { + repositoryJson = aotBundle.metadata().toJson().toString(2); + } catch (JSONException e) { + throw new RuntimeException(e); + } + + if (logger.isTraceEnabled()) { + logger.trace(""" + ------ AOT Repository.json: %s ------ + %s + ------------------- + """.formatted(repositoryJsonFileName, repositoryJson)); + + logger.trace(""" + ------ AOT Generated Repository: %s ------ + %s + ------------------- + """.formatted(null, aotBundle.javaFile())); + } + + generationContext.getGeneratedFiles().addResourceFile(repositoryJsonFileName, repositoryJson); + }); + + this.contributedTypeName = GeneratedTypeReference.of(generatedClass.getName()); // generate native runtime hints - needed cause we're using the repository proxy - generationContext.getRuntimeHints().reflection().registerType(TypeReference.of(typeName), + generationContext.getRuntimeHints().reflection().registerType(this.contributedTypeName, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS); } diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java index d25e0f1cb3..23e90fe744 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java @@ -24,6 +24,7 @@ import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.TypeName; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** @@ -51,6 +52,8 @@ public AotRepositoryBeanDefinitionPropertiesDecorator(Supplier inheri */ public CodeBlock decorate() { + Assert.notNull(repositoryContributor.getContributedTypeName(), "Contributed type name must not be null"); + CodeBlock.Builder builder = CodeBlock.builder(); // bring in properties as usual builder.add(inheritedProperties.get()); @@ -78,7 +81,7 @@ public CodeBlock decorate() { } builder.addStatement("return RepositoryComposition.RepositoryFragments.just(new $L($L))", - repositoryContributor.getContributedTypeName(), + repositoryContributor.getContributedTypeName().getCanonicalName(), StringUtils.collectionToDelimitedString(repositoryContributor.requiredArgs().keySet(), ", ")); builder.unindent(); builder.add("}\n"); diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java index 1bb20248f3..f184620889 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.aot.hint.TypeReference; import org.springframework.data.geo.Metric; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.querydsl.QuerydslPredicateExecutor; @@ -66,8 +67,8 @@ void writesClassSkeleton() { assertThat(repoBuilder.build().javaFile().toString()) .contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo .contains("@Generated") // marked as generated source - .contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name - .contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire + .contains("public class %sImpl".formatted(UserRepository.class.getSimpleName())) // target name + .contains("public UserRepositoryImpl"); // default constructor if not arguments to wire } @Test // GH-3279 @@ -80,11 +81,11 @@ void appliesCtorArguments() { ctor.addParameter("param2", String.class); ctor.addParameter("ctorScoped", TypeName.get(Object.class), false); }); - assertThat(repoBuilder.build().javaFile().toString()) // + assertThat(repoBuilder.withClassName(null).build().javaFile().toString()) // .contains("private final Metric param1;") // .contains("private final String param2;") // .doesNotContain("private final Object ctorScoped;") // - .contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") // + .contains("public UserRepositoryImpl(Metric param1, String param2, Object ctorScoped)") // .contains("this.param1 = param1") // .contains("this.param2 = param2") // .doesNotContain("this.ctorScoped = ctorScoped"); @@ -100,8 +101,9 @@ void appliesCtorCodeBlock() { code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error"); }); }); + repoBuilder.withClassName(null); assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces( - "UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }"); + "UserRepositoryImpl() { throw new IllegalStateException(\"initialization error\"); }"); } @Test // GH-3279 @@ -180,6 +182,24 @@ void shouldContributeFragmentImplementationMetadata() { assertThat(method.fragment().implementation()).isEqualTo(DummyQuerydslPredicateExecutor.class.getName()); } + @Test // GH-3339 + void usesTargetTypeName() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, "Commons", + new SpelAwareProxyProjectionFactory()); + repoBuilder.withConstructorCustomizer(ctor -> { + ctor.addParameter("param1", Metric.class); + ctor.addParameter("param2", String.class); + ctor.addParameter("ctorScoped", TypeName.get(Object.class), false); + }); + + TypeReference targetType = TypeReference.of("%s__AotPostfix".formatted(UserRepository.class.getCanonicalName())); + + assertThat(repoBuilder.withClassName(targetType.getSimpleName()).build().javaFile().toString()) // + .contains("class %s".formatted(targetType.getSimpleName())) // + .contains("public %s(Metric param1, String param2, Object ctorScoped)".formatted(targetType.getSimpleName())); + } + interface UserRepository extends org.springframework.data.repository.Repository { String someMethod(); diff --git a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java index 8640c1eada..402898784d 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java @@ -87,7 +87,7 @@ public Map serialize() { repositoryContributor.contribute(generationContext); generationContext.writeGeneratedContent(); - String expectedTypeName = "example.UserRepositoryImpl__Aot"; + String expectedTypeName = "example.UserRepositoryImpl__AotRepository"; TestCompiler.forSystem().with(generationContext).compile(compiled -> { assertThat(compiled.getAllCompiledClasses()).map(Class::getName).contains(expectedTypeName); @@ -132,7 +132,7 @@ public Map serialize() { repositoryContributor.contribute(generationContext); generationContext.writeGeneratedContent(); - String expectedTypeName = "example.UserRepositoryImpl__Aot"; + String expectedTypeName = "example.UserRepositoryImpl__AotRepository"; TestCompiler.forSystem().with(generationContext).compile(compiled -> { String content = compiled.getResourceFile().getContent(); @@ -154,7 +154,9 @@ void callsMethodContributionForQueryMethod() { when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); - contributor.contribute(new TestGenerationContext(UserRepository.class)); + TestGenerationContext generationContext = new TestGenerationContext(UserRepository.class); + contributor.contribute(generationContext); + generationContext.writeGeneratedContent(); contributor.verifyContributionFor("findByFirstname"); } @@ -174,8 +176,11 @@ void doesNotContributeBaseClassMethods() { .thenReturn(true); when(repositoryInformation.isQueryMethod(argThat(it -> !it.getName().equals("findByFirstname")))).thenReturn(true); + TestGenerationContext testGenerationContext = new TestGenerationContext(UserRepository.class); MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); - contributor.contribute(new TestGenerationContext(UserRepository.class)); + contributor.contribute(testGenerationContext); + testGenerationContext.writeGeneratedContent(); + contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findByFirstname"); } @@ -200,7 +205,9 @@ void doesNotContributeFragmentMethod() { when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); - contributor.contribute(new TestGenerationContext(UserRepository.class)); + TestGenerationContext generationContext = new TestGenerationContext(UserRepository.class); + contributor.contribute(generationContext); + generationContext.writeGeneratedContent(); contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findUserByExtensionMethod"); } @@ -221,7 +228,9 @@ void contributesBaseClassMethodIfQueryMethod() { when(repositoryInformation.isQueryMethod(any())).thenReturn(true); MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); - contributor.contribute(new TestGenerationContext(UserRepository.class)); + TestGenerationContext generationContext = new TestGenerationContext(UserRepository.class); + contributor.contribute(generationContext); + generationContext.writeGeneratedContent(); contributor.verifyContributedMethods().containsKey("findByFirstname").hasSizeGreaterThan(1); }