Skip to content

Commit 8d04d7f

Browse files
committed
address comments
1 parent 0a9f63d commit 8d04d7f

File tree

13 files changed

+84
-45
lines changed

13 files changed

+84
-45
lines changed

api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ public String formatString(Map<String, String> kwargs) {
7777
});
7878
}
7979

80+
/**
81+
* Formats messages with the given default role and arguments.
82+
*
83+
* <p><strong>Note:</strong> This method is primarily designed for Python interoperability.
84+
* Python code may pass arguments of various types (integers, floats, booleans, etc.) which are
85+
* automatically converted to strings for template processing.
86+
*
87+
* <p>For Java code, prefer using {@link #formatMessages(MessageRole, Map)} directly with a
88+
* {@code Map<String, String>} for better type safety.
89+
*
90+
* @param defaultRoleValue The default message role as a string value
91+
* @param kwargs A map of template arguments where values can be of any type and will be
92+
* converted to strings using {@link Object#toString()}
93+
* @return A list of formatted chat messages
94+
*/
8095
public List<ChatMessage> formatMessages(String defaultRoleValue, Map<String, Object> kwargs) {
8196
Map<String, String> arguments = new HashMap<>();
8297
for (Map.Entry<String, Object> entry : kwargs.entrySet()) {

api/src/main/java/org/apache/flink/agents/api/resource/python/PythonChatModelConnection.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ public class PythonChatModelConnection extends BaseChatModelConnection
3838
implements PythonResourceWrapper {
3939
private PyObject chatModel;
4040

41+
/**
42+
* Creates a new PythonChatModelConnection.
43+
*
44+
* @param adapter The Python resource adapter (required by PythonResourceProvider's
45+
* reflection-based instantiation but not used directly in this implementation)
46+
* @param chatModel The Python chat model object
47+
* @param descriptor The resource descriptor
48+
* @param getResource Function to retrieve resources by name and type
49+
*/
4150
public PythonChatModelConnection(
4251
PythonResourceAdapter adapter,
4352
PyObject chatModel,

api/src/main/java/org/apache/flink/agents/api/resource/python/PythonChatModelSetup.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.flink.agents.api.resource.python;
1919

2020
import org.apache.flink.agents.api.chat.messages.ChatMessage;
21+
import org.apache.flink.agents.api.chat.messages.MessageRole;
2122
import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
2223
import org.apache.flink.agents.api.resource.Resource;
2324
import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -36,6 +37,11 @@
3637
* compatibility while delegating actual chat operations to the underlying Python implementation.
3738
*/
3839
public class PythonChatModelSetup extends BaseChatModelSetup implements PythonResourceWrapper {
40+
private static final String FROM_JAVA_CHAT_MESSAGE = "python_java_utils.from_java_chat_message";
41+
42+
private static final String UPDATE_JAVA_CHAT_MESSAGE =
43+
"python_java_utils.update_java_chat_message";
44+
3945
private final PyObject chatModelSetup;
4046
private final PythonResourceAdapter adapter;
4147

@@ -55,13 +61,40 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
5561

5662
List<Object> pythonMessages = new ArrayList<>();
5763
for (ChatMessage message : messages) {
58-
pythonMessages.add(adapter.toPythonChatMessage(message));
64+
pythonMessages.add(toPythonChatMessage(message));
5965
}
6066

6167
kwargs.put("messages", pythonMessages);
6268

6369
Object pythonMessageResponse = adapter.callMethod(chatModelSetup, "chat", kwargs);
64-
return adapter.fromPythonChatMessage(pythonMessageResponse);
70+
return fromPythonChatMessage(pythonMessageResponse);
71+
}
72+
73+
/**
74+
* Converts a Java ChatMessage object to its Python equivalent.
75+
*
76+
* @param message the Java ChatMessage to convert
77+
* @return the Python representation of the chat message
78+
*/
79+
private Object toPythonChatMessage(ChatMessage message) {
80+
return adapter.invoke(FROM_JAVA_CHAT_MESSAGE, message);
81+
}
82+
83+
/**
84+
* Converts a Python chat message object back to a Java ChatMessage.
85+
*
86+
* @param pythonChatMessage the Python chat message object to convert
87+
* @return the Java ChatMessage representation
88+
*/
89+
private ChatMessage fromPythonChatMessage(Object pythonChatMessage) {
90+
ChatMessage message = new ChatMessage();
91+
92+
String roleValue =
93+
(String) adapter.invoke(UPDATE_JAVA_CHAT_MESSAGE, pythonChatMessage, message);
94+
95+
message.setRole(MessageRole.fromValue(roleValue));
96+
97+
return message;
6598
}
6699

67100
@Override

api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
package org.apache.flink.agents.api.resource.python;
2020

21-
import org.apache.flink.agents.api.chat.messages.ChatMessage;
2221
import pemja.core.object.PyObject;
2322

2423
import java.util.Map;
@@ -60,18 +59,11 @@ public interface PythonResourceAdapter {
6059
Object callMethod(Object obj, String methodName, Map<String, Object> kwargs);
6160

6261
/**
63-
* Converts a Java ChatMessage object to its Python equivalent.
62+
* Invokes a method with the specified name and arguments.
6463
*
65-
* @param message the Java ChatMessage to convert
66-
* @return the Python representation of the chat message
67-
*/
68-
Object toPythonChatMessage(ChatMessage message);
69-
70-
/**
71-
* Converts a Python chat message object back to a Java ChatMessage.
72-
*
73-
* @param pythonChatMessage the Python chat message object to convert
74-
* @return the Java ChatMessage representation
64+
* @param name the name of the method to invoke
65+
* @param args the arguments to pass to the method
66+
* @return the result of the method invocation
7567
*/
76-
ChatMessage fromPythonChatMessage(Object pythonChatMessage);
68+
Object invoke(String name, Object... args);
7769
}

plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public class AgentPlan implements Serializable {
7979

8080
private AgentConfiguration config;
8181

82-
private PythonResourceAdapter pythonResourceAdapter;
82+
private transient PythonResourceAdapter pythonResourceAdapter;
8383

8484
/** Cache for instantiated resources. */
8585
private transient Map<ResourceType, Map<String, Resource>> resourceCache;

plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ public PythonResourceProvider(String name, ResourceType type, ResourceDescriptor
6868
throw new IllegalArgumentException("clazz should not be null or empty.");
6969
}
7070
this.kwargs = new HashMap<>(descriptor.getInitialArguments());
71+
this.kwargs.remove("module");
72+
this.kwargs.remove("clazz");
7173
this.descriptor = descriptor;
7274
}
7375

@@ -92,7 +94,6 @@ public Resource provide(BiFunction<String, ResourceType, Resource> getResource)
9294
throws Exception {
9395
if (pythonResourceAdapter != null) {
9496
Class<?> clazz = Class.forName(descriptor.getClazz());
95-
Map<String, Object> kwargs = new HashMap<>(descriptor.getInitialArguments());
9697
PyObject pyResource =
9798
pythonResourceAdapter.initPythonResource(this.module, this.clazz, kwargs);
9899
Constructor<?> constructor =

plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ private PythonResourceProvider deserializePythonResourceProvider(JsonNode node)
7777
String name = node.get("name").asText();
7878
String type = node.get("type").asText();
7979
try {
80-
ResourceDescriptor descriptor =
81-
mapper.treeToValue(node.get("descriptor"), ResourceDescriptor.class);
82-
if (descriptor != null) {
80+
if (node.has("descriptor")) {
81+
ResourceDescriptor descriptor =
82+
mapper.treeToValue(node.get("descriptor"), ResourceDescriptor.class);
8383
return new PythonResourceProvider(name, ResourceType.fromValue(type), descriptor);
8484
}
8585
} catch (JsonProcessingException e) {

plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,13 @@ private void serializePythonResourceProvider(JsonGenerator gen, PythonResourcePr
6868
throws IOException {
6969
gen.writeStringField("name", provider.getName());
7070
gen.writeStringField("type", provider.getType().getValue());
71-
gen.writeObjectField("descriptor", provider.getDescriptor());
7271
gen.writeStringField("module", provider.getModule());
7372
gen.writeStringField("clazz", provider.getClazz());
7473

74+
if (provider.getDescriptor() != null) {
75+
gen.writeObjectField("descriptor", provider.getDescriptor());
76+
}
77+
7578
gen.writeFieldName("kwargs");
7679
gen.writeStartObject();
7780
provider.getKwargs()

plan/src/test/resources/resource_providers/python_resource_provider.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"type" : "chat_model",
44
"module" : "flink_agents.plan.tests.test_resource_provider",
55
"clazz" : "MockChatModelImpl",
6-
"descriptor" :null,
76
"kwargs" : {
87
"host" : "8.8.8.8",
98
"desc" : "mock chat model"

python/flink_agents/api/prompts/prompt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222

2323
from flink_agents.api.chat_message import ChatMessage, MessageRole
2424
from flink_agents.api.prompts.utils import format_string
25-
from flink_agents.api.resource import Resource, ResourceType, SerializableResource
25+
from flink_agents.api.resource import ResourceType, SerializableResource
2626

2727

28-
class Prompt(Resource, ABC):
28+
class Prompt(SerializableResource, ABC):
2929
"""Base prompt abstract."""
3030

3131
@staticmethod
@@ -55,7 +55,7 @@ def resource_type(cls) -> ResourceType:
5555
return ResourceType.PROMPT
5656

5757

58-
class LocalPrompt(Prompt, SerializableResource):
58+
class LocalPrompt(Prompt):
5959
"""Prompt for a language model.
6060
6161
Attributes:

0 commit comments

Comments
 (0)