Skip to content

Improving API to allow strong typing #683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.serverlessworkflow.api.types.ForTask;
import io.serverlessworkflow.api.types.Workflow;
import io.serverlessworkflow.api.types.func.ForTaskFunction;
import io.serverlessworkflow.api.types.func.TypedFunction;
import io.serverlessworkflow.impl.WorkflowApplication;
import io.serverlessworkflow.impl.WorkflowFilter;
import io.serverlessworkflow.impl.WorkflowPosition;
Expand All @@ -44,7 +45,7 @@ protected JavaForExecutorBuilder(
protected Optional<WorkflowFilter> buildWhileFilter() {
if (task instanceof ForTaskFunction taskFunctions) {
final LoopPredicateIndex whilePred = taskFunctions.getWhilePredicate();
Optional<Class<?>> modelClass = taskFunctions.getModelClass();
Optional<Class<?>> whileClass = taskFunctions.getWhileClass();
String varName = task.getFor().getEach();
String indexName = task.getFor().getAt();
if (whilePred != null) {
Expand All @@ -55,7 +56,7 @@ protected Optional<WorkflowFilter> buildWhileFilter() {
.modelFactory()
.from(
whilePred.test(
JavaFuncUtils.convert(n, modelClass),
JavaFuncUtils.convert(n, whileClass),
item,
(Integer) safeObject(t.variables().get(indexName))));
});
Expand All @@ -66,7 +67,15 @@ protected Optional<WorkflowFilter> buildWhileFilter() {

protected WorkflowFilter buildCollectionFilter() {
return task instanceof ForTaskFunction taskFunctions
? WorkflowUtils.buildWorkflowFilter(application, null, taskFunctions.getCollection())
? WorkflowUtils.buildWorkflowFilter(
application, null, collectionFilterObject(taskFunctions))
: super.buildCollectionFilter();
}

private Object collectionFilterObject(ForTaskFunction taskFunctions) {
return taskFunctions.getForClass().isPresent()
? new TypedFunction(
taskFunctions.getCollection(), taskFunctions.getForClass().orElseThrow())
: taskFunctions.getCollection();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import io.serverlessworkflow.api.types.SwitchTask;
import io.serverlessworkflow.api.types.Workflow;
import io.serverlessworkflow.api.types.func.SwitchCaseFunction;
import io.serverlessworkflow.api.types.func.TypedPredicate;
import io.serverlessworkflow.impl.WorkflowApplication;
import io.serverlessworkflow.impl.WorkflowFilter;
import io.serverlessworkflow.impl.WorkflowPosition;
import io.serverlessworkflow.impl.WorkflowUtils;
import io.serverlessworkflow.impl.executors.SwitchExecutor.SwitchExecutorBuilder;
import io.serverlessworkflow.impl.resources.ResourceLoader;
import java.util.Optional;
import java.util.function.Predicate;

public class JavaSwitchExecutorBuilder extends SwitchExecutorBuilder {

Expand All @@ -42,7 +44,14 @@ protected JavaSwitchExecutorBuilder(
@Override
protected Optional<WorkflowFilter> buildFilter(SwitchCase switchCase) {
return switchCase instanceof SwitchCaseFunction function
? Optional.of(WorkflowUtils.buildWorkflowFilter(application, null, function.predicate()))
? Optional.of(
WorkflowUtils.buildWorkflowFilter(
application, null, predObject(function.predicate(), function.predicateClass())))
: super.buildFilter(switchCase);
}

@SuppressWarnings({"unchecked", "rawtypes"})
private Object predObject(Predicate<?> pred, Optional<Class<?>> predClass) {
return predClass.isPresent() ? new TypedPredicate(pred, predClass.orElseThrow()) : pred;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import io.serverlessworkflow.api.types.TaskBase;
import io.serverlessworkflow.api.types.TaskMetadata;
import io.serverlessworkflow.api.types.func.TypedFunction;
import io.serverlessworkflow.api.types.func.TypedPredicate;
import io.serverlessworkflow.impl.TaskContext;
import io.serverlessworkflow.impl.WorkflowContext;
import io.serverlessworkflow.impl.WorkflowFilter;
Expand Down Expand Up @@ -52,8 +54,12 @@ public Expression buildExpression(String expression) {
public WorkflowFilter buildFilter(String expr, Object value) {
if (value instanceof Function func) {
return (w, t, n) -> modelFactory.fromAny(func.apply(n.asJavaObject()));
} else if (value instanceof TypedFunction func) {
return (w, t, n) -> modelFactory.fromAny(func.function().apply(n.as(func.argClass())));
} else if (value instanceof Predicate pred) {
return fromPredicate(pred);
} else if (value instanceof TypedPredicate pred) {
return fromPredicate(pred);
} else if (value instanceof BiPredicate pred) {
return (w, t, n) -> modelFactory.from(pred.test(w, t));
} else if (value instanceof BiFunction func) {
Expand All @@ -70,14 +76,23 @@ private WorkflowFilter fromPredicate(Predicate pred) {
return (w, t, n) -> modelFactory.from(pred.test(n.asJavaObject()));
}

@SuppressWarnings({"rawtypes", "unchecked"})
private WorkflowFilter fromPredicate(TypedPredicate pred) {
return (w, t, n) -> modelFactory.from(pred.pred().test(n.as(pred.argClass())));
}

@Override
public Optional<WorkflowFilter> buildIfFilter(TaskBase task) {
TaskMetadata metadata = task.getMetadata();
return metadata != null
&& metadata.getAdditionalProperties().get(TaskMetadataKeys.IF_PREDICATE)
instanceof Predicate pred
? Optional.of(fromPredicate(pred))
: ExpressionFactory.super.buildIfFilter(task);
if (metadata != null) {
Object obj = metadata.getAdditionalProperties().get(TaskMetadataKeys.IF_PREDICATE);
if (obj instanceof Predicate pred) {
return Optional.of(fromPredicate(pred));
} else if (obj instanceof TypedPredicate pred) {
return Optional.of(fromPredicate(pred));
}
}
return ExpressionFactory.super.buildIfFilter(task);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.serverlessworkflow.api.types.func;

import io.serverlessworkflow.api.types.ExportAs;
import java.util.Objects;
import java.util.function.Function;

public class ExportAsFunction extends ExportAs {
Expand All @@ -24,4 +25,10 @@ public <T, V> ExportAs withFunction(Function<T, V> value) {
setObject(value);
return this;
}

public <T, V> ExportAs withFunction(Function<T, V> value, Class<T> argClass) {
Objects.requireNonNull(argClass);
setObject(new TypedFunction<>(value, argClass));
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ public class ForTaskFunction extends ForTask {

private static final long serialVersionUID = 1L;
private LoopPredicateIndex<?, ?> whilePredicate;
private Optional<Class<?>> modelClass;
private Optional<Class<?>> whileClass;
private Optional<Class<?>> itemClass;
private Optional<Class<?>> forClass;
private Function<?, Collection<?>> collection;

public <T, V> ForTaskFunction withWhile(LoopPredicate<T, V> whilePredicate) {
Expand All @@ -53,35 +54,45 @@ public <T, V> ForTaskFunction withWhile(LoopPredicateIndex<T, V> whilePredicate)

public <T, V> ForTaskFunction withWhile(
LoopPredicateIndex<T, V> whilePredicate, Class<T> modelClass) {
return withWhile(whilePredicate, Optional.of(modelClass), Optional.empty());
return withWhile(whilePredicate, Optional.ofNullable(modelClass), Optional.empty());
}

public <T, V> ForTaskFunction withWhile(
LoopPredicateIndex<T, V> whilePredicate, Class<T> modelClass, Class<V> itemClass) {
return withWhile(whilePredicate, Optional.of(modelClass), Optional.of(itemClass));
return withWhile(whilePredicate, Optional.ofNullable(modelClass), Optional.of(itemClass));
}

private <T, V> ForTaskFunction withWhile(
LoopPredicateIndex<T, V> whilePredicate,
Optional<Class<?>> modelClass,
Optional<Class<?>> itemClass) {
this.whilePredicate = whilePredicate;
this.modelClass = modelClass;
this.whileClass = modelClass;
this.itemClass = itemClass;
return this;
}

public <T> ForTaskFunction withCollection(Function<T, Collection<?>> collection) {
return withCollection(collection, null);
}

public <T> ForTaskFunction withCollection(
Function<T, Collection<?>> collection, Class<T> colArgClass) {
this.collection = collection;
this.forClass = Optional.ofNullable(colArgClass);
return this;
}

public LoopPredicateIndex<?, ?> getWhilePredicate() {
return whilePredicate;
}

public Optional<Class<?>> getModelClass() {
return modelClass;
public Optional<Class<?>> getWhileClass() {
return whileClass;
}

public Optional<Class<?>> getForClass() {
return forClass;
}

public Optional<Class<?>> getItemClass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.serverlessworkflow.api.types.func;

import io.serverlessworkflow.api.types.InputFrom;
import java.util.Objects;
import java.util.function.Function;

public class InputFromFunction extends InputFrom {
Expand All @@ -24,4 +25,10 @@ public <T, V> InputFrom withFunction(Function<T, V> value) {
setObject(value);
return this;
}

public <T, V> InputFrom withFunction(Function<T, V> value, Class<T> argClass) {
Objects.requireNonNull(argClass);
setObject(new TypedFunction<>(value, argClass));
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.serverlessworkflow.api.types.func;

import io.serverlessworkflow.api.types.OutputAs;
import java.util.Objects;
import java.util.function.Function;

public class OutputAsFunction extends OutputAs {
Expand All @@ -24,4 +25,10 @@ public <T, V> OutputAs withFunction(Function<T, V> value) {
setObject(value);
return this;
}

public <T, V> OutputAs withFunction(Function<T, V> value, Class<T> argClass) {
Objects.requireNonNull(argClass);
setObject(new TypedFunction<>(value, argClass));
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,32 @@
package io.serverlessworkflow.api.types.func;

import io.serverlessworkflow.api.types.SwitchCase;
import java.util.Optional;
import java.util.function.Predicate;

public class SwitchCaseFunction extends SwitchCase {

private static final long serialVersionUID = 1L;
private Predicate<?> predicate;
private Optional<Class<?>> predicateClass;

public <T> SwitchCaseFunction withPredicate(Predicate<T> predicate) {
this.predicate = predicate;
this.predicateClass = Optional.empty();
return this;
}

public <T> void setPredicate(Predicate<T> predicate) {
public <T> SwitchCaseFunction withPredicate(Predicate<T> predicate, Class<T> predicateClass) {
this.predicate = predicate;
this.predicateClass = Optional.ofNullable(predicateClass);
return this;
}

public Predicate<?> predicate() {
return predicate;
}

public Optional<Class<?>> predicateClass() {
return predicateClass;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.serverlessworkflow.api.types.func;

import java.util.function.Function;

public record TypedFunction<T, V>(Function<T, V> function, Class<T> argClass) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.serverlessworkflow.api.types.func;

import java.util.function.Predicate;

public record TypedPredicate<T>(Predicate<T> pred, Class<T> argClass) {}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static Function<Cognisphere, Object> toFunction(AgentExecutor exec) {
return exec::invoke;
}

public static LoopPredicateIndex<Object, Object> toWhile(Predicate<Cognisphere> exit) {
return (model, item, idx) -> !exit.test((Cognisphere) model);
public static LoopPredicateIndex<Cognisphere, Object> toWhile(Predicate<Cognisphere> exit) {
return (model, item, idx) -> !exit.test(model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public LoopAgentsBuilder maxIterations(int maxIterations) {
}

public LoopAgentsBuilder exitCondition(Predicate<Cognisphere> exitCondition) {
this.forTask.withWhile(AgentAdapters.toWhile(exitCondition));
this.forTask.withWhile(AgentAdapters.toWhile(exitCondition), Cognisphere.class);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ public static final class SwitchCaseFunctionBuilder {
}

public <T> SwitchCaseFunctionBuilder when(Predicate<T> when) {
this.switchCase.setPredicate(when);
this.switchCase.withPredicate(when);
return this;
}

public <T> SwitchCaseFunctionBuilder when(Predicate<T> when, Class<T> whenClass) {
this.switchCase.withPredicate(when, whenClass);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@
package io.serverlessworkflow.fluent.func.spi;

import io.serverlessworkflow.api.types.TaskBase;
import io.serverlessworkflow.api.types.TaskMetadata;
import io.serverlessworkflow.impl.expressions.TaskMetadataKeys;
import io.serverlessworkflow.api.types.func.TypedPredicate;
import java.util.Objects;
import java.util.function.Predicate;

public interface ConditionalTaskBuilder<SELF> {

TaskBase getTask();

default SELF when(Predicate<?> predicate) {
if (getTask().getMetadata() == null) {
getTask().setMetadata(new TaskMetadata());
}
getTask().getMetadata().setAdditionalProperty(TaskMetadataKeys.IF_PREDICATE, predicate);
ConditionalTaskBuilderHelper.setMetadata(getTask(), predicate);
return (SELF) this;
}

default <T> SELF when(Predicate<T> predicate, Class<T> argClass) {
Objects.requireNonNull(argClass);
ConditionalTaskBuilderHelper.setMetadata(getTask(), new TypedPredicate<>(predicate, argClass));
return (SELF) this;
}
}
Loading
Loading