1515 */
1616package org .springframework .batch .core .repository .dao .mongodb ;
1717
18- import java .util .ArrayList ;
1918import java .util .Collection ;
20- import java .util .Comparator ;
2119import java .util .List ;
22- import java .util .Optional ;
2320
2421import org .springframework .batch .core .job .JobExecution ;
2522import org .springframework .batch .core .job .JobInstance ;
2623import org .springframework .batch .core .step .StepExecution ;
2724import org .springframework .batch .core .repository .dao .StepExecutionDao ;
2825import org .springframework .batch .core .repository .persistence .converter .JobExecutionConverter ;
2926import org .springframework .batch .core .repository .persistence .converter .StepExecutionConverter ;
27+ import org .springframework .data .domain .Sort ;
3028import org .springframework .data .mongodb .core .MongoOperations ;
3129import org .springframework .data .mongodb .core .query .Query ;
3230import org .springframework .jdbc .support .incrementer .DataFieldMaxValueIncrementer ;
3634
3735/**
3836 * @author Mahmoud Ben Hassine
37+ * @author Jinwoo Bae
3938 * @since 5.2.0
4039 */
4140public class MongoStepExecutionDao implements StepExecutionDao {
@@ -66,7 +65,7 @@ public void setStepExecutionIncrementer(DataFieldMaxValueIncrementer stepExecuti
6665 @ Override
6766 public void saveStepExecution (StepExecution stepExecution ) {
6867 org .springframework .batch .core .repository .persistence .StepExecution stepExecutionToSave = this .stepExecutionConverter
69- .fromStepExecution (stepExecution );
68+ .fromStepExecution (stepExecution );
7069 long stepExecutionId = this .stepExecutionIncrementer .nextLongValue ();
7170 stepExecutionToSave .setStepExecutionId (stepExecutionId );
7271 this .mongoOperations .insert (stepExecutionToSave , STEP_EXECUTIONS_COLLECTION_NAME );
@@ -84,82 +83,91 @@ public void saveStepExecutions(Collection<StepExecution> stepExecutions) {
8483 public void updateStepExecution (StepExecution stepExecution ) {
8584 Query query = query (where ("stepExecutionId" ).is (stepExecution .getId ()));
8685 org .springframework .batch .core .repository .persistence .StepExecution stepExecutionToUpdate = this .stepExecutionConverter
87- .fromStepExecution (stepExecution );
86+ .fromStepExecution (stepExecution );
8887 this .mongoOperations .findAndReplace (query , stepExecutionToUpdate , STEP_EXECUTIONS_COLLECTION_NAME );
8988 }
9089
9190 @ Override
9291 public StepExecution getStepExecution (JobExecution jobExecution , Long stepExecutionId ) {
9392 Query query = query (where ("stepExecutionId" ).is (stepExecutionId ));
9493 org .springframework .batch .core .repository .persistence .StepExecution stepExecution = this .mongoOperations
95- .findOne (query , org .springframework .batch .core .repository .persistence .StepExecution .class ,
96- STEP_EXECUTIONS_COLLECTION_NAME );
94+ .findOne (query , org .springframework .batch .core .repository .persistence .StepExecution .class ,
95+ STEP_EXECUTIONS_COLLECTION_NAME );
9796 return stepExecution != null ? this .stepExecutionConverter .toStepExecution (stepExecution , jobExecution ) : null ;
9897 }
9998
10099 @ Override
101100 public StepExecution getLastStepExecution (JobInstance jobInstance , String stepName ) {
102101 // TODO optimize the query
103- // get all step executions
104- List <org .springframework .batch .core .repository .persistence .StepExecution > stepExecutions = new ArrayList <>();
105- Query query = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
102+ Query jobExecutionQuery = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
106103 List <org .springframework .batch .core .repository .persistence .JobExecution > jobExecutions = this .mongoOperations
107- .find (query , org .springframework .batch .core .repository .persistence .JobExecution .class ,
108- JOB_EXECUTIONS_COLLECTION_NAME );
109- for (org .springframework .batch .core .repository .persistence .JobExecution jobExecution : jobExecutions ) {
110- stepExecutions .addAll (jobExecution .getStepExecutions ());
104+ .find (jobExecutionQuery , org .springframework .batch .core .repository .persistence .JobExecution .class ,
105+ JOB_EXECUTIONS_COLLECTION_NAME );
106+
107+ if (jobExecutions .isEmpty ()) {
108+ return null ;
109+ }
110+
111+ List <Long > jobExecutionIds = jobExecutions .stream ()
112+ .map (org .springframework .batch .core .repository .persistence .JobExecution ::getJobExecutionId )
113+ .toList ();
114+
115+ Query stepExecutionQuery = query (where ("name" ).is (stepName ).and ("jobExecutionId" ).in (jobExecutionIds ))
116+ .with (Sort .by (Sort .Direction .DESC , "createTime" , "stepExecutionId" ))
117+ .limit (1 );
118+
119+ org .springframework .batch .core .repository .persistence .StepExecution stepExecution = this .mongoOperations
120+ .findOne (stepExecutionQuery , org .springframework .batch .core .repository .persistence .StepExecution .class ,
121+ STEP_EXECUTIONS_COLLECTION_NAME );
122+
123+ if (stepExecution == null ) {
124+ return null ;
111125 }
112- // sort step executions by creation date then id (see contract) and return the
113- // first one
114- Optional <org .springframework .batch .core .repository .persistence .StepExecution > lastStepExecution = stepExecutions
115- .stream ()
116- .filter (stepExecution -> stepExecution .getName ().equals (stepName ))
117- .min (Comparator
118- .comparing (org .springframework .batch .core .repository .persistence .StepExecution ::getCreateTime )
119- .thenComparing (org .springframework .batch .core .repository .persistence .StepExecution ::getId ));
120- if (lastStepExecution .isPresent ()) {
121- org .springframework .batch .core .repository .persistence .StepExecution stepExecution = lastStepExecution .get ();
122- JobExecution jobExecution = this .jobExecutionConverter .toJobExecution (jobExecutions .stream ()
126+
127+ org .springframework .batch .core .repository .persistence .JobExecution jobExecution = jobExecutions .stream ()
123128 .filter (execution -> execution .getJobExecutionId ().equals (stepExecution .getJobExecutionId ()))
124129 .findFirst ()
125- .get (), jobInstance );
126- return this . stepExecutionConverter . toStepExecution ( stepExecution , jobExecution );
127- }
128- else {
129- return null ;
130+ .orElse ( null );
131+
132+ if ( jobExecution != null ) {
133+ JobExecution jobExecutionDomain = this . jobExecutionConverter . toJobExecution ( jobExecution , jobInstance );
134+ return this . stepExecutionConverter . toStepExecution ( stepExecution , jobExecutionDomain ) ;
130135 }
136+
137+ return null ;
131138 }
132139
133140 @ Override
134141 public void addStepExecutions (JobExecution jobExecution ) {
135142 Query query = query (where ("jobExecutionId" ).is (jobExecution .getId ()));
136143 List <StepExecution > stepExecutions = this .mongoOperations
137- .find (query , org .springframework .batch .core .repository .persistence .StepExecution .class ,
138- STEP_EXECUTIONS_COLLECTION_NAME )
139- .stream ()
140- .map (stepExecution -> this .stepExecutionConverter .toStepExecution (stepExecution , jobExecution ))
141- .toList ();
144+ .find (query , org .springframework .batch .core .repository .persistence .StepExecution .class ,
145+ STEP_EXECUTIONS_COLLECTION_NAME )
146+ .stream ()
147+ .map (stepExecution -> this .stepExecutionConverter .toStepExecution (stepExecution , jobExecution ))
148+ .toList ();
142149 jobExecution .addStepExecutions (stepExecutions );
143150 }
144151
145152 @ Override
146153 public long countStepExecutions (JobInstance jobInstance , String stepName ) {
147- long count = 0 ;
148- // TODO optimize the count query
149- Query query = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
150- List <org .springframework .batch .core .repository .persistence .JobExecution > jobExecutions = this .mongoOperations
151- .find (query , org .springframework .batch .core .repository .persistence .JobExecution .class ,
152- JOB_EXECUTIONS_COLLECTION_NAME );
153- for (org .springframework .batch .core .repository .persistence .JobExecution jobExecution : jobExecutions ) {
154- List <org .springframework .batch .core .repository .persistence .StepExecution > stepExecutions = jobExecution
155- .getStepExecutions ();
156- for (org .springframework .batch .core .repository .persistence .StepExecution stepExecution : stepExecutions ) {
157- if (stepExecution .getName ().equals (stepName )) {
158- count ++;
159- }
160- }
154+ Query jobExecutionQuery = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
155+ List <Long > jobExecutionIds = this .mongoOperations
156+ .find (jobExecutionQuery , org .springframework .batch .core .repository .persistence .JobExecution .class ,
157+ JOB_EXECUTIONS_COLLECTION_NAME )
158+ .stream ()
159+ .map (org .springframework .batch .core .repository .persistence .JobExecution ::getJobExecutionId )
160+ .toList ();
161+
162+ if (jobExecutionIds .isEmpty ()) {
163+ return 0 ;
161164 }
162- return count ;
165+
166+ // Count step executions directly from BATCH_STEP_EXECUTION collection
167+ Query stepQuery = query (where ("name" ).is (stepName ).and ("jobExecutionId" ).in (jobExecutionIds ));
168+ return this .mongoOperations .count (stepQuery ,
169+ org .springframework .batch .core .repository .persistence .StepExecution .class ,
170+ STEP_EXECUTIONS_COLLECTION_NAME );
163171 }
164172
165173}
0 commit comments