From b89e056402c4a9582e65837b2e125f87685c8f13 Mon Sep 17 00:00:00 2001 From: youjin Date: Tue, 2 Dec 2025 17:04:49 +0800 Subject: [PATCH] [Feature][runtime] Support the use of Python ChatModel in Java --- .github/workflows/ci.yml | 17 +- api/pom.xml | 6 + .../agents/api/chat/messages/ChatMessage.java | 6 - .../python/PythonChatModelConnection.java | 74 +++++++ .../model/python/PythonChatModelSetup.java | 109 ++++++++++ .../python/PythonResourceAdapter.java | 69 +++++++ .../python/PythonResourceWrapper.java | 32 +++ .../python/PythonChatModelConnectionTest.java | 125 ++++++++++++ .../python/PythonChatModelSetupTest.java | 162 +++++++++++++++ .../pom.xml | 5 + .../test/ChatModelIntegrationAgent.java | 20 ++ .../test/ChatModelIntegrationTest.java | 5 +- .../apache/flink/agents/plan/AgentPlan.java | 34 +++- .../PythonResourceProvider.java | 51 ++++- .../ResourceProviderJsonDeserializer.java | 9 + .../ResourceProviderJsonSerializer.java | 4 + .../flink/agents/plan/AgentPlanTest.java | 107 +++++++++- python/flink_agents/api/tools/utils.py | 14 ++ python/flink_agents/runtime/java/__init__.py | 17 ++ .../runtime/java/java_resource_wrapper.py | 75 +++++++ .../flink_agents/runtime/python_java_utils.py | 155 +++++++++++++- .../operator/ActionExecutionOperator.java | 77 ++++++- .../python/utils/PythonActionExecutor.java | 18 +- .../utils/PythonResourceAdapterImpl.java | 104 ++++++++++ .../utils/PythonResourceAdapterImplTest.java | 192 ++++++++++++++++++ 25 files changed, 1439 insertions(+), 48 deletions(-) create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java create mode 100644 python/flink_agents/runtime/java/__init__.py create mode 100644 python/flink_agents/runtime/java/java_resource_wrapper.py create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e41b9895..8f38e5c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,6 +128,7 @@ jobs: fail-fast: false matrix: os: [ 'ubuntu-latest' ] + python-version: [ '3.11' ] steps: - uses: actions/checkout@v4 - name: Install java @@ -135,12 +136,22 @@ jobs: with: java-version: '11' distribution: 'adopt' - - name: Install flink-agents Java - run: bash tools/build.sh -j + - name: Install python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + - name: Install flink-agents + run: bash tools/build.sh - name: Install ollama run: bash tools/start_ollama_server.sh - name: Run Java IT - run: tools/ut.sh -j -e + run: | + export PYTHONPATH="${{ github.workspace }}/python/.venv/lib/python${{ matrix.python-version }}/site-packages:$PYTHONPATH" + tools/ut.sh -j -e cross_language_tests: name: cross-language [${{ matrix.os }}] [${{ matrix.python-version}}] diff --git a/api/pom.xml b/api/pom.xml index fb4123e8..a540361f 100644 --- a/api/pom.xml +++ b/api/pom.xml @@ -57,6 +57,12 @@ under the License. ${flink.version} provided + + com.alibaba + pemja + ${pemja.version} + provided + \ No newline at end of file diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java index 7ebc7870..6e844e74 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java @@ -32,9 +32,6 @@ */ public class ChatMessage { - /** The key for the message type in the metadata. */ - public static final String MESSAGE_TYPE = "messageType"; - private MessageRole role; private String content; private List> toolCalls; @@ -68,7 +65,6 @@ public ChatMessage( this.content = content != null ? content : ""; this.toolCalls = toolCalls != null ? toolCalls : new ArrayList<>(); this.extraArgs = extraArgs != null ? new HashMap<>(extraArgs) : new HashMap<>(); - this.extraArgs.put(MESSAGE_TYPE, this.role); } public MessageRole getRole() { @@ -77,7 +73,6 @@ public MessageRole getRole() { public void setRole(MessageRole role) { this.role = role; - this.extraArgs.put(MESSAGE_TYPE, this.role); } public String getContent() { @@ -102,7 +97,6 @@ public Map getExtraArgs() { public void setExtraArgs(Map extraArgs) { this.extraArgs = extraArgs != null ? extraArgs : new HashMap<>(); - this.extraArgs.put(MESSAGE_TYPE, this.role); } @JsonIgnore diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java new file mode 100644 index 00000000..fc191111 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.chat.model.python; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.tools.Tool; +import pemja.core.object.PyObject; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Python-based implementation of ChatModelConnection that wraps a Python chat model object. This + * class serves as a bridge between Java and Python chat model environments, but unlike {@link + * PythonChatModelSetup}, it does not provide direct chat functionality in Java. + */ +public class PythonChatModelConnection extends BaseChatModelConnection + implements PythonResourceWrapper { + private PyObject chatModel; + + /** + * Creates a new PythonChatModelConnection. + * + * @param adapter The Python resource adapter (required by PythonResourceProvider's + * reflection-based instantiation but not used directly in this implementation) + * @param chatModel The Python chat model object + * @param descriptor The resource descriptor + * @param getResource Function to retrieve resources by name and type + */ + public PythonChatModelConnection( + PythonResourceAdapter adapter, + PyObject chatModel, + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + this.chatModel = chatModel; + } + + @Override + public Object getPythonResource() { + return chatModel; + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + throw new UnsupportedOperationException( + "Chat method of PythonChatModelConnection cannot be called directly from Java runtime. " + + "This connection serves as a Python resource wrapper only. " + + "Chat operations should be performed on the Python side using the underlying Python chat model object."); + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java new file mode 100644 index 00000000..cd2626cf --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.chat.model.python; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import pemja.core.object.PyObject; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Python-based implementation of ChatModelSetup that bridges Java and Python chat model + * functionality. This class wraps a Python chat model setup object and provides Java interface + * compatibility while delegating actual chat operations to the underlying Python implementation. + */ +public class PythonChatModelSetup extends BaseChatModelSetup implements PythonResourceWrapper { + static final String FROM_JAVA_CHAT_MESSAGE = "python_java_utils.from_java_chat_message"; + + static final String TO_JAVA_CHAT_MESSAGE = "python_java_utils.to_java_chat_message"; + + private final PyObject chatModelSetup; + private final PythonResourceAdapter adapter; + + public PythonChatModelSetup( + PythonResourceAdapter adapter, + PyObject chatModelSetup, + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + this.chatModelSetup = chatModelSetup; + this.adapter = adapter; + } + + @Override + public ChatMessage chat(List messages, Map parameters) { + if (chatModelSetup == null) { + throw new IllegalStateException( + "ChatModelSetup is not initialized. Cannot perform chat operation."); + } + + Map kwargs = new HashMap<>(parameters); + + List pythonMessages = new ArrayList<>(); + for (ChatMessage message : messages) { + pythonMessages.add(toPythonChatMessage(message)); + } + + kwargs.put("messages", pythonMessages); + + Object pythonMessageResponse = adapter.callMethod(chatModelSetup, "chat", kwargs); + return fromPythonChatMessage(pythonMessageResponse); + } + + /** + * Converts a Java ChatMessage object to its Python equivalent. + * + * @param message the Java ChatMessage to convert + * @return the Python representation of the chat message + */ + private Object toPythonChatMessage(ChatMessage message) { + return adapter.invoke(FROM_JAVA_CHAT_MESSAGE, message); + } + + /** + * Converts a Python chat message object back to a Java ChatMessage. + * + * @param pythonChatMessage the Python chat message object to convert + * @return the Java ChatMessage representation + */ + private ChatMessage fromPythonChatMessage(Object pythonChatMessage) { + ChatMessage message = (ChatMessage) adapter.invoke(TO_JAVA_CHAT_MESSAGE, pythonChatMessage); + + return message; + } + + @Override + public Object getPythonResource() { + return chatModelSetup; + } + + @Override + public Map getParameters() { + return Map.of(); + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java new file mode 100644 index 00000000..44e64081 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.resource.python; + +import pemja.core.object.PyObject; + +import java.util.Map; + +/** + * Adapter interface for managing Python resources and facilitating Java-Python interoperability. + * This interface provides methods to interact with Python objects, invoke Python methods, and + * handle data conversion between Java and Python environments. + */ +public interface PythonResourceAdapter { + + /** + * Retrieves a Python resource by name and type. + * + * @param resourceName the name of the resource to retrieve + * @param resourceType the type of the resource + * @return the retrieved resource object + */ + Object getResource(String resourceName, String resourceType); + + /** + * Initializes a Python resource instance from the specified module and class. + * + * @param module the Python module containing the target class + * @param clazz the Python class name to instantiate + * @param kwargs keyword arguments to pass to the Python class constructor + * @return a PyObject representing the initialized Python resource + */ + PyObject initPythonResource(String module, String clazz, Map kwargs); + + /** + * Invokes a method on a Python object with the specified parameters. + * + * @param obj the Python object on which to call the method + * @param methodName the name of the method to invoke + * @param kwargs keyword arguments to pass to the method + * @return the result of the method invocation + */ + Object callMethod(Object obj, String methodName, Map kwargs); + + /** + * Invokes a method with the specified name and arguments. + * + * @param name the name of the method to invoke + * @param args the arguments to pass to the method + * @return the result of the method invocation + */ + Object invoke(String name, Object... args); +} diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java new file mode 100644 index 00000000..c69cf59b --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceWrapper.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.resource.python; + +/** + * Wrapper interface for Python resource objects. This interface provides a unified way to access + * the underlying Python resource from Java objects that encapsulate Python functionality. + */ +public interface PythonResourceWrapper { + + /** + * Retrieves the underlying Python resource object. + * + * @return the wrapped Python resource object + */ + Object getPythonResource(); +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java new file mode 100644 index 00000000..fa1f2370 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.chat.model.python; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.tools.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import pemja.core.object.PyObject; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +public class PythonChatModelConnectionTest { + @Mock private PythonResourceAdapter mockAdapter; + + @Mock private PyObject mockChatModel; + + @Mock private ResourceDescriptor mockDescriptor; + + @Mock private BiFunction mockGetResource; + + private PythonChatModelConnection pythonChatModelConnection; + private AutoCloseable mocks; + + @BeforeEach + void setUp() throws Exception { + mocks = MockitoAnnotations.openMocks(this); + pythonChatModelConnection = + new PythonChatModelConnection( + mockAdapter, mockChatModel, mockDescriptor, mockGetResource); + } + + @AfterEach + void tearDown() throws Exception { + if (mocks != null) { + mocks.close(); + } + } + + @Test + void testConstructor() { + assertThat(pythonChatModelConnection).isNotNull(); + assertThat(pythonChatModelConnection.getPythonResource()).isEqualTo(mockChatModel); + } + + @Test + void testGetPythonResourceWithNullChatModel() { + PythonChatModelConnection connectionWithNullModel = + new PythonChatModelConnection(mockAdapter, null, mockDescriptor, mockGetResource); + + Object result = connectionWithNullModel.getPythonResource(); + + assertThat(result).isNull(); + } + + @Test + void testChatThrowsUnsupportedOperationException() { + ChatMessage mockChatMessage = mock(ChatMessage.class); + Tool mockTool = mock(Tool.class); + List messages = Collections.singletonList(mockChatMessage); + List tools = Collections.singletonList(mockTool); + Map arguments = new HashMap<>(); + arguments.put("temperature", 0.7); + arguments.put("max_tokens", 100); + + assertThatThrownBy(() -> pythonChatModelConnection.chat(messages, tools, arguments)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining( + "Chat method of PythonChatModelConnection cannot be called directly from Java runtime") + .hasMessageContaining("This connection serves as a Python resource wrapper only") + .hasMessageContaining( + "Chat operations should be performed on the Python side using the underlying Python chat model object"); + } + + @Test + void testInheritanceFromBaseChatModelConnection() { + assertThat(pythonChatModelConnection).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + void testImplementsPythonResourceWrapper() { + assertThat(pythonChatModelConnection).isInstanceOf(PythonResourceWrapper.class); + } + + @Test + void testConstructorWithAllNullParameters() { + PythonChatModelConnection connectionWithNulls = + new PythonChatModelConnection(null, null, null, null); + + assertThat(connectionWithNulls).isNotNull(); + assertThat(connectionWithNulls.getPythonResource()).isNull(); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java new file mode 100644 index 00000000..01e0341f --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.chat.model.python; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import pemja.core.object.PyObject; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +public class PythonChatModelSetupTest { + @Mock private PythonResourceAdapter mockAdapter; + + @Mock private PyObject mockChatModelSetup; + + @Mock private ResourceDescriptor mockDescriptor; + + @Mock private BiFunction mockGetResource; + + private PythonChatModelSetup pythonChatModelSetup; + private AutoCloseable mocks; + + @BeforeEach + void setUp() throws Exception { + mocks = MockitoAnnotations.openMocks(this); + pythonChatModelSetup = + new PythonChatModelSetup( + mockAdapter, mockChatModelSetup, mockDescriptor, mockGetResource); + } + + @AfterEach + void tearDown() throws Exception { + if (mocks != null) { + mocks.close(); + } + } + + @Test + void testConstructor() { + assertThat(pythonChatModelSetup).isNotNull(); + assertThat(pythonChatModelSetup.getPythonResource()).isEqualTo(mockChatModelSetup); + } + + @Test + void testGetPythonResourceWithNullChatModelSetup() { + PythonChatModelSetup setupWithNullModel = + new PythonChatModelSetup(mockAdapter, null, mockDescriptor, mockGetResource); + + Object result = setupWithNullModel.getPythonResource(); + + assertThat(result).isNull(); + } + + @Test + void testGetParameters() { + Map result = pythonChatModelSetup.getParameters(); + + assertThat(result).isNotNull(); + assertThat(result).isEmpty(); + } + + @Test + void testChat() { + ChatMessage inputMessage = mock(ChatMessage.class); + ChatMessage outputMessage = mock(ChatMessage.class); + List messages = Collections.singletonList(inputMessage); + Map parameters = new HashMap<>(); + parameters.put("temperature", 0.7); + parameters.put("max_tokens", 100); + + Object pythonInputMessage = new Object(); + Object pythonOutputMessage = new Object(); + + when(mockAdapter.invoke(PythonChatModelSetup.FROM_JAVA_CHAT_MESSAGE, inputMessage)) + .thenReturn(pythonInputMessage); + when(mockAdapter.callMethod(eq(mockChatModelSetup), eq("chat"), any(Map.class))) + .thenReturn(pythonOutputMessage); + when(mockAdapter.invoke(PythonChatModelSetup.TO_JAVA_CHAT_MESSAGE, pythonOutputMessage)) + .thenReturn(outputMessage); + + ChatMessage result = pythonChatModelSetup.chat(messages, parameters); + + assertThat(result).isEqualTo(outputMessage); + verify(mockAdapter).invoke(PythonChatModelSetup.FROM_JAVA_CHAT_MESSAGE, inputMessage); + verify(mockAdapter) + .callMethod( + eq(mockChatModelSetup), + eq("chat"), + argThat( + kwargs -> { + assertThat(kwargs).containsKey("messages"); + assertThat(kwargs).containsKey("temperature"); + assertThat(kwargs).containsKey("max_tokens"); + assertThat(kwargs.get("temperature")).isEqualTo(0.7); + assertThat(kwargs.get("max_tokens")).isEqualTo(100); + List pythonMessages = (List) kwargs.get("messages"); + assertThat(pythonMessages).hasSize(1); + assertThat(pythonMessages.get(0)).isEqualTo(pythonInputMessage); + return true; + })); + verify(mockAdapter).invoke(PythonChatModelSetup.TO_JAVA_CHAT_MESSAGE, pythonOutputMessage); + } + + @Test + void testChatWithNullChatModelSetupThrowsException() { + PythonChatModelSetup setupWithNullModel = + new PythonChatModelSetup(mockAdapter, null, mockDescriptor, mockGetResource); + + ChatMessage inputMessage = mock(ChatMessage.class); + List messages = Collections.singletonList(inputMessage); + Map parameters = new HashMap<>(); + + assertThatThrownBy(() -> setupWithNullModel.chat(messages, parameters)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("ChatModelSetup is not initialized") + .hasMessageContaining("Cannot perform chat operation"); + } + + @Test + void testInheritanceFromBaseChatModelSetup() { + assertThat(pythonChatModelSetup) + .isInstanceOf(org.apache.flink.agents.api.chat.model.BaseChatModelSetup.class); + } + + @Test + void testImplementsPythonResourceWrapper() { + assertThat(pythonChatModelSetup) + .isInstanceOf( + org.apache.flink.agents.api.resource.python.PythonResourceWrapper.class); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml index 65c7d12e..67a8fb41 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml @@ -79,6 +79,11 @@ under the License. flink-agents-integrations-embedding-models-ollama ${project.version} + + org.apache.flink + flink-python + ${flink.version} + diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java index 9056f997..d114f883 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java @@ -28,6 +28,8 @@ import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.python.PythonChatModelConnection; +import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ChatRequestEvent; import org.apache.flink.agents.api.event.ChatResponseEvent; @@ -87,6 +89,12 @@ public static ResourceDescriptor chatModelConnection() { return ResourceDescriptor.Builder.newBuilder(OpenAIChatModelConnection.class.getName()) .addInitialArgument("api_key", apiKey) .build(); + } else if (provider.equals("PYTHON")) { + return ResourceDescriptor.Builder.newBuilder(PythonChatModelConnection.class.getName()) + .addInitialArgument( + "module", "flink_agents.integrations.chat_models.ollama_chat_model") + .addInitialArgument("clazz", "OllamaChatModelConnection") + .build(); } else { throw new RuntimeException(String.format("Unknown model provider %s", provider)); } @@ -120,6 +128,18 @@ public static ResourceDescriptor chatModel() { "tools", List.of("calculateBMI", "convertTemperature", "createRandomNumber")) .build(); + } else if (provider.equals("PYTHON")) { + return ResourceDescriptor.Builder.newBuilder(PythonChatModelSetup.class.getName()) + .addInitialArgument("connection", "chatModelConnection") + .addInitialArgument( + "module", "flink_agents.integrations.chat_models.ollama_chat_model") + .addInitialArgument("clazz", "OllamaChatModelSetup") + .addInitialArgument("model", OLLAMA_MODEL) + .addInitialArgument( + "tools", + List.of("calculateBMI", "convertTemperature", "createRandomNumber")) + .addInitialArgument("extract_reasoning", "true") + .build(); } else { throw new RuntimeException(String.format("Unknown model provider %s", provider)); } diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java index d1387b2d..711e0b32 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java @@ -44,6 +44,7 @@ public class ChatModelIntegrationTest extends OllamaPreparationUtils { private static final String API_KEY = "_API_KEY"; private static final String OLLAMA = "OLLAMA"; + private static final String PYTHON = "PYTHON"; private final boolean ollamaReady; @@ -52,10 +53,10 @@ public ChatModelIntegrationTest() throws IOException { } @ParameterizedTest() - @ValueSource(strings = {"OLLAMA", "AZURE", "OPENAI"}) + @ValueSource(strings = {"OLLAMA", "AZURE", "OPENAI", "PYTHON"}) public void testChatModeIntegration(String provider) throws Exception { Assumptions.assumeTrue( - (OLLAMA.equals(provider) && ollamaReady) + ((OLLAMA.equals(provider) || PYTHON.equals(provider)) && ollamaReady) || System.getenv().get(provider + API_KEY) != null, String.format( "Server or authentication information is not provided for %s", provider)); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 7cb81386..cc8b8a7d 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -30,12 +30,15 @@ import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.tools.ToolMetadata; import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.plan.actions.ChatModelAction; import org.apache.flink.agents.plan.actions.ToolCallAction; import org.apache.flink.agents.plan.resourceprovider.JavaResourceProvider; import org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; import org.apache.flink.agents.plan.serializer.AgentPlanJsonDeserializer; import org.apache.flink.agents.plan.serializer.AgentPlanJsonSerializer; @@ -76,6 +79,8 @@ public class AgentPlan implements Serializable { private AgentConfiguration config; + private transient PythonResourceAdapter pythonResourceAdapter; + /** Cache for instantiated resources. */ private transient Map> resourceCache; @@ -128,6 +133,10 @@ public AgentPlan(Agent agent, AgentConfiguration config) throws Exception { this.config = config; } + public void setPythonResourceAdapter(PythonResourceAdapter adapter) { + this.pythonResourceAdapter = adapter; + } + public Map getActions() { return actions; } @@ -174,6 +183,10 @@ public Resource getResource(String name, ResourceType type) throws Exception { ResourceProvider provider = resourceProviders.get(type).get(name); + if (pythonResourceAdapter != null && provider instanceof PythonResourceProvider) { + ((PythonResourceProvider) provider).setPythonResourceAdapter(pythonResourceAdapter); + } + // Create resource using provider Resource resource = provider.provide( @@ -279,8 +292,13 @@ private void extractActionsFromAgent(Agent agent) throws Exception { private void extractResource(ResourceType type, Method method) throws Exception { String name = method.getName(); + ResourceProvider provider; ResourceDescriptor descriptor = (ResourceDescriptor) method.invoke(null); - JavaResourceProvider provider = new JavaResourceProvider(name, type, descriptor); + if (PythonResourceWrapper.class.isAssignableFrom(Class.forName(descriptor.getClazz()))) { + provider = new PythonResourceProvider(name, type, descriptor); + } else { + provider = new JavaResourceProvider(name, type, descriptor); + } addResourceProvider(provider); } @@ -385,9 +403,17 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { ResourceType type = entry.getKey(); if (type == ResourceType.CHAT_MODEL || type == ResourceType.CHAT_MODEL_CONNECTION) { for (Map.Entry kv : entry.getValue().entrySet()) { - JavaResourceProvider provider = - new JavaResourceProvider( - kv.getKey(), type, (ResourceDescriptor) kv.getValue()); + ResourceProvider provider; + if (PythonResourceWrapper.class.isAssignableFrom( + Class.forName(((ResourceDescriptor) kv.getValue()).getClazz()))) { + provider = + new PythonResourceProvider( + kv.getKey(), type, (ResourceDescriptor) kv.getValue()); + } else { + provider = + new JavaResourceProvider( + kv.getKey(), type, (ResourceDescriptor) kv.getValue()); + } addResourceProvider(provider); } } else if (type == ResourceType.PROMPT) { diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java index 329a2db3..cdea068d 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java @@ -19,8 +19,13 @@ package org.apache.flink.agents.plan.resourceprovider; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import pemja.core.object.PyObject; +import java.lang.reflect.Constructor; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; @@ -35,6 +40,9 @@ public class PythonResourceProvider extends ResourceProvider { private final String module; private final String clazz; private final Map kwargs; + private final ResourceDescriptor descriptor; + + protected PythonResourceAdapter pythonResourceAdapter; public PythonResourceProvider( String name, @@ -46,6 +54,25 @@ public PythonResourceProvider( this.module = module; this.clazz = clazz; this.kwargs = kwargs; + this.descriptor = null; + } + + public PythonResourceProvider(String name, ResourceType type, ResourceDescriptor descriptor) { + super(name, type); + this.kwargs = new HashMap<>(descriptor.getInitialArguments()); + module = (String) kwargs.remove("module"); + if (module == null || module.isEmpty()) { + throw new IllegalArgumentException("module should not be null or empty."); + } + clazz = (String) kwargs.remove("clazz"); + if (clazz == null || clazz.isEmpty()) { + throw new IllegalArgumentException("clazz should not be null or empty."); + } + this.descriptor = descriptor; + } + + public void setPythonResourceAdapter(PythonResourceAdapter pythonResourceAdapter) { + this.pythonResourceAdapter = pythonResourceAdapter; } public String getModule() { @@ -63,11 +90,21 @@ public Map getKwargs() { @Override public Resource provide(BiFunction getResource) throws Exception { - // TODO: Implement Python resource creation logic - // This would typically involve calling into Python runtime to create the - // resource - throw new UnsupportedOperationException( - "Python resource creation not yet implemented in Java runtime"); + if (pythonResourceAdapter != null) { + Class clazz = Class.forName(descriptor.getClazz()); + PyObject pyResource = + pythonResourceAdapter.initPythonResource(this.module, this.clazz, kwargs); + Constructor constructor = + clazz.getConstructor( + PythonResourceAdapter.class, + PyObject.class, + ResourceDescriptor.class, + BiFunction.class); + return (Resource) + constructor.newInstance( + pythonResourceAdapter, pyResource, descriptor, getResource); + } + throw new UnsupportedOperationException("PythonResourceAdapter is not set."); } @Override @@ -92,4 +129,8 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.getName(), this.getType(), module, clazz, kwargs); } + + public ResourceDescriptor getDescriptor() { + return descriptor; + } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java index b548b4d3..7404bca1 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java @@ -76,6 +76,15 @@ public ResourceProvider deserialize( private PythonResourceProvider deserializePythonResourceProvider(JsonNode node) { String name = node.get("name").asText(); String type = node.get("type").asText(); + try { + if (node.has("descriptor")) { + ResourceDescriptor descriptor = + mapper.treeToValue(node.get("descriptor"), ResourceDescriptor.class); + return new PythonResourceProvider(name, ResourceType.fromValue(type), descriptor); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } String module = node.get("module").asText(); String clazz = node.get("clazz").asText(); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java index f11ff053..4b61c83a 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java @@ -71,6 +71,10 @@ private void serializePythonResourceProvider(JsonGenerator gen, PythonResourcePr gen.writeStringField("module", provider.getModule()); gen.writeStringField("clazz", provider.getClazz()); + if (provider.getDescriptor() != null) { + gen.writeObjectField("descriptor", provider.getDescriptor()); + } + gen.writeFieldName("kwargs"); gen.writeStartObject(); provider.getKwargs() diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java index e61dbb91..b42e176e 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java @@ -26,18 +26,25 @@ import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import pemja.core.object.PyObject; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Test for {@link AgentPlan} constructor that takes an Agent. */ public class AgentPlanTest { @@ -91,6 +98,27 @@ public ResourceType getResourceType() { } } + public static class TestPythonResource extends Resource implements PythonResourceWrapper { + + public TestPythonResource( + PythonResourceAdapter adapter, + PyObject chatModel, + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + } + + @Override + public ResourceType getResourceType() { + return ResourceType.CHAT_MODEL; + } + + @Override + public Object getPythonResource() { + return null; + } + } + /** Test agent class with annotated methods. */ public static class TestAgent extends Agent { @@ -120,6 +148,14 @@ public static class TestAgentWithResources extends Agent { private TestSerializableChatModel chatModel = new TestSerializableChatModel("defaultChatModel"); + @ChatModelSetup + public static ResourceDescriptor pythonChatModel() { + return ResourceDescriptor.Builder.newBuilder(TestPythonResource.class.getName()) + .addInitialArgument("module", "test.module") + .addInitialArgument("clazz", "TestClazz") + .build(); + } + @Tool private TestTool anotherTool = new TestTool("anotherTool"); @org.apache.flink.agents.api.annotation.Action(listenEvents = {InputEvent.class}) @@ -128,6 +164,39 @@ public void handleInputEvent(InputEvent event, RunnerContext context) { } } + /** Test agent class with illegal python resource. */ + public static class TestAgentWithIllegalPythonResource extends Agent { + @ChatModelSetup + public static ResourceDescriptor reviewAnalysisModel() { + return ResourceDescriptor.Builder.newBuilder(TestPythonResource.class.getName()) + .build(); + } + } + + public static class TestPythonResourceAdapter implements PythonResourceAdapter { + + @Override + public Object getResource(String resourceName, String resourceType) { + return null; + } + + @Override + public PyObject initPythonResource( + String module, String clazz, Map kwargs) { + return null; + } + + @Override + public Object callMethod(Object obj, String methodName, Map kwargs) { + return null; + } + + @Override + public Object invoke(String name, Object... args) { + return null; + } + } + @Test public void testConstructorWithAgent() throws Exception { // Create an agent instance @@ -303,8 +372,9 @@ public void testExtractResourceProvidersFromAgent() throws Exception { Map chatModelProviders = resourceProviders.get(ResourceType.CHAT_MODEL); assertThat(chatModelProviders).isNotNull(); - assertThat(chatModelProviders).hasSize(1); // defaultChatModel (field name used as default) + assertThat(chatModelProviders).hasSize(2); // defaultChatModel (field name used as default) assertThat(chatModelProviders).containsKey("chatModel"); + assertThat(chatModelProviders).containsKey("pythonChatModel"); // Verify that chat model provider is JavaSerializableResourceProvider // (serializable) @@ -319,6 +389,19 @@ public void testExtractResourceProvidersFromAgent() throws Exception { assertThat(serializableProvider.getModule()) .isEqualTo(TestAgentWithResources.class.getPackage().getName()); assertThat(serializableProvider.getClazz()).contains("TestSerializableChatModel"); + + // Verify that python chat model provider is PythonResourceProvider + // (serializable) + ResourceProvider pythonChatModelProvider = chatModelProviders.get("pythonChatModel"); + assertThat(pythonChatModelProvider).isInstanceOf(PythonResourceProvider.class); + assertThat(pythonChatModelProvider.getName()).isEqualTo("pythonChatModel"); + assertThat(pythonChatModelProvider.getType()).isEqualTo(ResourceType.CHAT_MODEL); + + // Test PythonResourceProvider specific methods + PythonResourceProvider pythonResourceProvider = + (PythonResourceProvider) pythonChatModelProvider; + assertThat(pythonResourceProvider.getClazz()).isEqualTo("TestClazz"); + assertThat(pythonResourceProvider.getModule()).isEqualTo("test.module"); } @Test @@ -339,8 +422,30 @@ public void testGetResourceFromResourceProvider() throws Exception { assertThat(chatModel).isInstanceOf(TestSerializableChatModel.class); assertThat(chatModel.getResourceType()).isEqualTo(ResourceType.CHAT_MODEL); + assertThatThrownBy(() -> agentPlan.getResource("pythonChatModel", ResourceType.CHAT_MODEL)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("PythonResourceAdapter is not set"); + + agentPlan.setPythonResourceAdapter(new TestPythonResourceAdapter()); + Resource pythonChatModel = + agentPlan.getResource("pythonChatModel", ResourceType.CHAT_MODEL); + assertThat(pythonChatModel).isNotNull(); + assertThat(pythonChatModel).isInstanceOf(TestPythonResource.class); + assertThat(pythonChatModel.getResourceType()).isEqualTo(ResourceType.CHAT_MODEL); + // Test that resources are cached (should be the same instance) Resource myToolAgain = agentPlan.getResource("myTool", ResourceType.TOOL); assertThat(myTool).isSameAs(myToolAgain); } + + @Test + public void testExtractIllegalResourceProviderFromAgent() throws Exception { + // Create an agent with resource annotations + TestAgentWithIllegalPythonResource agent = new TestAgentWithIllegalPythonResource(); + + // Expect IllegalArgumentException when creating AgentPlan with illegal resource + assertThatThrownBy(() -> new AgentPlan(agent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("module should not be null or empty"); + } } diff --git a/python/flink_agents/api/tools/utils.py b/python/flink_agents/api/tools/utils.py index 81784d48..976fcdec 100644 --- a/python/flink_agents/api/tools/utils.py +++ b/python/flink_agents/api/tools/utils.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +import json import typing from inspect import signature from typing import Any, Callable, Dict, Optional, Type, Union @@ -178,6 +179,19 @@ def resolve_field_type(field_schema: dict) -> type[typing.Any]: return create_model(name, **main_fields, __doc__=schema.get("description", "")) +def create_model_from_java_tool_schema_str(name: str, schema_str: str) -> type[BaseModel]: + """Create Pydantic model from a java tool input schema.""" + json_schema = json.loads(schema_str) + properties = json_schema["properties"] + + fields = {} + for param_name in properties: + description = properties[param_name]["description"] + if description is None: + description = f"Parameter: {param_name}" + type = TYPE_MAPPING.get(properties[param_name]["type"]) + fields[param_name] = (type, FieldInfo(description=description)) + return create_model(name, **fields) def extract_mcp_content_item(content_item: Any) -> Dict[str, Any] | str: """Extract and normalize a single MCP content item. diff --git a/python/flink_agents/runtime/java/__init__.py b/python/flink_agents/runtime/java/__init__.py new file mode 100644 index 00000000..e154fadd --- /dev/null +++ b/python/flink_agents/runtime/java/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py new file mode 100644 index 00000000..496aab33 --- /dev/null +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -0,0 +1,75 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any, List + +from pemja import findClass +from pydantic import Field +from typing_extensions import override + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.prompts.prompt import Prompt +from flink_agents.api.resource import Resource, ResourceType +from flink_agents.api.tools.tool import Tool, ToolType + + +class JavaTool(Tool): + """Java Tool that carries tool metadata and can be recognized by PythonChatModel.""" + + @classmethod + @override + def tool_type(cls) -> ToolType: + """Get the tool type.""" + return ToolType.REMOTE_FUNCTION + + @override + def call(self, *args: Any, **kwargs: Any) -> Any: + err_msg = "Java tool is defined in Java and needs to be executed through the Java runtime." + raise NotImplementedError(err_msg) + +class JavaPrompt(Prompt): + """Python wrapper for Java's Prompt.""" + + j_prompt: Any= Field(exclude=True) + + @override + def format_string(self, **kwargs: str) -> str: + return self.j_prompt.formatString(kwargs) + + @override + def format_messages( + self, role: MessageRole = MessageRole.SYSTEM, **kwargs: str + ) -> List[ChatMessage]: + j_MessageRole = findClass("org.apache.flink.agents.api.chat.messages.MessageRole") + j_chat_messages = self.j_prompt.formatMessages(j_MessageRole.fromValue(role.value), kwargs) + chatMessages = [ChatMessage(role=MessageRole(j_chat_message.getRole().getValue()), + content=j_chat_message.getContent(), + tool_calls= j_chat_message.getToolCalls(), + extra_args=j_chat_message.getExtraArgs()) for j_chat_message in j_chat_messages] + return chatMessages + +class JavaGetResourceWrapper: + """Python wrapper for Java ResourceAdapter.""" + + def __init__(self, j_resource_adapter: Any) -> None: + """Initialize with a Java ResourceAdapter.""" + self._j_resource_adapter = j_resource_adapter + + + def get_resource(self, name: str, type: ResourceType) -> Resource: + """Get a resource by name and type.""" + return self._j_resource_adapter.getResource(name, type.value) diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index 586e6cb0..8b44bd62 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -15,11 +15,22 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any +import importlib +from typing import Any, Callable, Dict import cloudpickle +from pemja import findClass +from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.events.event import InputEvent +from flink_agents.api.resource import Resource +from flink_agents.api.tools.tool import ToolMetadata +from flink_agents.api.tools.utils import create_model_from_java_tool_schema_str +from flink_agents.runtime.java.java_resource_wrapper import ( + JavaGetResourceWrapper, + JavaPrompt, + JavaTool, +) def convert_to_python_object(bytesObject: bytes) -> Any: @@ -40,3 +51,145 @@ def wrap_to_input_event(bytesObject: bytes) -> tuple[bytes, str]: def get_output_from_output_event(bytesObject: bytes) -> Any: """Get output data from OutputEvent and serialize.""" return cloudpickle.dumps(convert_to_python_object(bytesObject).output) + +def create_resource(resource_module: str, resource_clazz: str, func_kwargs: Dict[str, Any]) -> Resource: + """Dynamically create a resource instance from module and class name. + + Args: + resource_module: The module path containing the resource class + resource_clazz: The class name to instantiate + func_kwargs: Keyword arguments to pass to the class constructor + + Returns: + Resource: An instance of the specified resource class + """ + module = importlib.import_module(resource_module) + cls = getattr(module, resource_clazz) + return cls(**func_kwargs) + +def get_resource_function(j_resource_adapter: Any) -> Callable: + """Create a callable wrapper for Java resource adapter. + + Args: + j_resource_adapter: Java resource adapter object + + Returns: + Callable: A Python callable that wraps the Java resource adapter + """ + return JavaGetResourceWrapper(j_resource_adapter).get_resource + +def from_java_tool(j_tool: Any) -> JavaTool: + """Convert a Java tool object to a Python JavaTool instance. + + Args: + j_tool: Java tool object + + Returns: + JavaTool: Python wrapper for the Java tool with extracted metadata + """ + name = j_tool.getName() + metadata = ToolMetadata( + name=name, + description=j_tool.getDescription(), + args_schema=create_model_from_java_tool_schema_str(name, j_tool.getMetadata().getInputSchema()), + ) + return JavaTool(metadata=metadata) + +def from_java_prompt(j_prompt: Any) -> JavaPrompt: + """Convert a Java prompt object to a Python JavaPrompt instance. + + Args: + j_prompt: Java prompt object to be wrapped + + Returns: + JavaPrompt: Python wrapper for the Java prompt + """ + return JavaPrompt(j_prompt=j_prompt) + +def convert_tool_call_to_python_format(tool_call: Dict[str, Any]) -> Dict[str, Any]: + """Convert a tool call dictionary to Python-compatible format. + + Args: + tool_call: Dictionary containing tool call information with keys: + - id: Tool call identifier + - function: Function details + - original_id: Optional original tool call ID + + Returns: + Dict containing Python-formatted tool call with keys: + - id: String representation of tool call ID + - function: Original function details + - additional_kwargs: Dictionary containing original_tool_call_id (if present) + """ + return { + "id": str(tool_call.get("id", "")), + "function": tool_call["function"], # Will raise KeyError if missing + "additional_kwargs": {"original_tool_call_id", tool_call.get("original_id", "")}, + } + +def from_java_chat_message(j_chat_message: Any) -> ChatMessage: + """Convert a chat message to a python chat message.""" + return ChatMessage(role=MessageRole(j_chat_message.getRole().getValue()), + content=j_chat_message.getContent(), + tool_calls=[convert_tool_call_to_python_format(tool_call) for tool_call in j_chat_message.getToolCalls()], + extra_args=j_chat_message.getExtraArgs()) + +def convert_tool_call_to_java_format(tool_call: Dict[str, Any]) -> Dict[str, Any]: + """Convert a tool call dictionary to Java-compatible format. + + Args: + tool_call: Dictionary containing tool call information with keys: + - id: Tool call identifier + - function: Function details + - additional_kwargs: Optional additional parameters containing + original_tool_call_id + + Returns: + Dict containing Java-formatted tool call with keys: + - id: String representation of tool call ID + - function: Original function details + - original_id: Original tool call ID from additional_kwargs (if present) + """ + return { + "id": str(tool_call.get("id", "")), + "function": tool_call["function"], # Will raise KeyError if missing + "original_id": tool_call.get("additional_kwargs", {}).get( + "original_tool_call_id" + ), + } + +def to_java_chat_message(chat_message: ChatMessage) -> Any: + """Convert a chat message to a java chat message.""" + j_ChatMessage = findClass("org.apache.flink.agents.api.chat.messages.ChatMessage") + j_chat_message = j_ChatMessage() + + j_MessageRole = findClass("org.apache.flink.agents.api.chat.messages.MessageRole") + j_chat_message.setRole(j_MessageRole.fromValue(chat_message.role.value)) + j_chat_message.setContent(chat_message.content) + j_chat_message.setExtraArgs(chat_message.extra_args) + if chat_message.tool_calls: + tool_calls = [convert_tool_call_to_java_format(tool_call) for tool_call in chat_message.tool_calls] + j_chat_message.setToolCalls(tool_calls) + + return j_chat_message + +def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any: + """Calls a method on `obj` by name and passes in positional and keyword arguments. + + Parameters: + obj: Any Python object + method_name: A string representing the name of the method to call + kwargs: Keyword arguments to pass to the method + + Returns: + The return value of the method + + Raises: + AttributeError: If the object does not have the specified method + """ + if not hasattr(obj, method_name): + err_msg = f"Object {obj} has no attribute '{method_name}'" + raise AttributeError(err_msg) + + method = getattr(obj, method_name) + return method(**kwargs) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 2a545631..7712917a 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -27,14 +27,17 @@ import org.apache.flink.agents.api.logger.EventLoggerConfig; import org.apache.flink.agents.api.logger.EventLoggerFactory; import org.apache.flink.agents.api.logger.EventLoggerOpenParams; +import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; @@ -45,6 +48,7 @@ import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.python.operator.PythonActionTask; import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; import org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; @@ -75,6 +79,7 @@ import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import pemja.core.PythonInterpreter; import java.lang.reflect.Field; import java.util.ArrayList; @@ -119,9 +124,16 @@ public class ActionExecutionOperator extends AbstractStreamOperator shortTermMemState; + private transient PythonEnvironmentManager pythonEnvironmentManager; + + private transient PythonInterpreter pythonInterpreter; + // PythonActionExecutor for Python actions private transient PythonActionExecutor pythonActionExecutor; + // PythonResourceAdapter for Python resources in Java actions + private transient PythonResourceAdapterImpl pythonResourceAdapter; + private transient FlinkAgentsMetricGroupImpl metricGroup; private transient BuiltInMetrics builtInMetrics; @@ -251,8 +263,8 @@ public void open() throws Exception { new ListStateDescriptor<>( "currentProcessingKeys", TypeInformation.of(Object.class))); - // init PythonActionExecutor - initPythonActionExecutor(); + // init PythonActionExecutor and PythonResourceAdapter + initPythonEnvironment(); mailboxProcessor = getMailboxProcessor(); @@ -496,17 +508,29 @@ private void processActionTaskForKey(Object key) throws Exception { } } - private void initPythonActionExecutor() throws Exception { + private void initPythonEnvironment() throws Exception { boolean containPythonAction = agentPlan.getActions().values().stream() .anyMatch(action -> action.getExec() instanceof PythonFunction); - if (containPythonAction) { - LOG.debug("Begin initialize PythonActionExecutor."); + + boolean containPythonResource = + agentPlan.getResourceProviders().values().stream() + .anyMatch( + resourceProviderMap -> + resourceProviderMap.values().stream() + .anyMatch( + resourceProvider -> + resourceProvider + instanceof + PythonResourceProvider)); + + if (containPythonAction || containPythonResource) { + LOG.debug("Begin initialize PythonEnvironmentManager."); PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create( getExecutionConfig().toConfiguration(), getRuntimeContext().getDistributedCache()); - PythonEnvironmentManager pythonEnvironmentManager = + pythonEnvironmentManager = new PythonEnvironmentManager( dependencyInfo, getContainingTask() @@ -515,14 +539,39 @@ private void initPythonActionExecutor() throws Exception { .getTmpDirectories(), new HashMap<>(System.getenv()), getRuntimeContext().getJobInfo().getJobId()); - pythonActionExecutor = - new PythonActionExecutor( - pythonEnvironmentManager, - new ObjectMapper().writeValueAsString(agentPlan)); - pythonActionExecutor.open(); + pythonEnvironmentManager.open(); + EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); + pythonInterpreter = env.getInterpreter(); + if (containPythonAction) { + initPythonActionExecutor(); + } else { + initPythonResourceAdapter(); + } } } + private void initPythonActionExecutor() throws Exception { + pythonActionExecutor = + new PythonActionExecutor( + pythonInterpreter, new ObjectMapper().writeValueAsString(agentPlan)); + pythonActionExecutor.open(); + } + + private void initPythonResourceAdapter() throws Exception { + pythonResourceAdapter = + new PythonResourceAdapterImpl( + (String anotherName, ResourceType anotherType) -> { + try { + return agentPlan.getResource(anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + pythonInterpreter); + pythonResourceAdapter.open(); + agentPlan.setPythonResourceAdapter(pythonResourceAdapter); + } + @Override public void endInput() throws Exception { waitInFlightEventsFinished(); @@ -540,6 +589,12 @@ public void close() throws Exception { if (pythonActionExecutor != null) { pythonActionExecutor.close(); } + if (pythonInterpreter != null) { + pythonInterpreter.close(); + } + if (pythonEnvironmentManager != null) { + pythonEnvironmentManager.close(); + } if (eventLogger != null) { eventLogger.close(); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java index ded01408..9b8c7c44 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java @@ -19,8 +19,6 @@ import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.runtime.context.RunnerContextImpl; -import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; -import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.utils.EventUtil; import pemja.core.PythonInterpreter; @@ -59,21 +57,16 @@ public class PythonActionExecutor { private static final String GET_OUTPUT_FROM_OUTPUT_EVENT = "python_java_utils.get_output_from_output_event"; - private final PythonEnvironmentManager environmentManager; + private final PythonInterpreter interpreter; private final String agentPlanJson; - private PythonInterpreter interpreter; private Object pythonAsyncThreadPool; - public PythonActionExecutor(PythonEnvironmentManager environmentManager, String agentPlanJson) { - this.environmentManager = environmentManager; + public PythonActionExecutor(PythonInterpreter interpreter, String agentPlanJson) { + this.interpreter = interpreter; this.agentPlanJson = agentPlanJson; } public void open() throws Exception { - environmentManager.open(); - EmbeddedPythonEnvironment env = environmentManager.createEnvironment(); - - interpreter = env.getInterpreter(); interpreter.exec(PYTHON_IMPORTS); pythonAsyncThreadPool = interpreter.invoke(CREATE_ASYNC_THREAD_POOL); @@ -160,11 +153,6 @@ public void close() throws Exception { if (pythonAsyncThreadPool != null) { interpreter.invoke(CLOSE_ASYNC_THREAD_POOL, pythonAsyncThreadPool); } - interpreter.close(); - } - - if (environmentManager != null) { - environmentManager.close(); } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java new file mode 100644 index 00000000..f0465407 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.python.utils; + +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.tools.Tool; +import pemja.core.PythonInterpreter; +import pemja.core.object.PyObject; + +import java.util.Map; +import java.util.function.BiFunction; + +public class PythonResourceAdapterImpl implements PythonResourceAdapter { + + static final String PYTHON_IMPORTS = "from flink_agents.runtime import python_java_utils"; + + static final String GET_RESOURCE_KEY = "get_resource"; + + static final String PYTHON_MODULE_PREFIX = "python_java_utils."; + + static final String GET_RESOURCE_FUNCTION = PYTHON_MODULE_PREFIX + "get_resource_function"; + + static final String CALL_METHOD = PYTHON_MODULE_PREFIX + "call_method"; + + static final String CREATE_RESOURCE = PYTHON_MODULE_PREFIX + "create_resource"; + + static final String FROM_JAVA_TOOL = PYTHON_MODULE_PREFIX + "from_java_tool"; + + static final String FROM_JAVA_PROMPT = PYTHON_MODULE_PREFIX + "from_java_prompt"; + + private final BiFunction getResource; + private final PythonInterpreter interpreter; + private Object pythonGetResourceFunction; + + public PythonResourceAdapterImpl( + BiFunction getResource, PythonInterpreter interpreter) { + this.getResource = getResource; + this.interpreter = interpreter; + } + + public void open() { + interpreter.exec(PYTHON_IMPORTS); + pythonGetResourceFunction = interpreter.invoke(GET_RESOURCE_FUNCTION, this); + } + + public Object getResource(String resourceName, String resourceType) { + Resource resource = + this.getResource.apply(resourceName, ResourceType.fromValue(resourceType)); + if (resource instanceof PythonResourceWrapper) { + PythonResourceWrapper pythonResource = (PythonResourceWrapper) resource; + return pythonResource.getPythonResource(); + } + if (resource instanceof Tool) { + return convertToPythonTool((Tool) resource); + } + if (resource instanceof Prompt) { + return convertToPythonPrompt((Prompt) resource); + } + return resource; + } + + @Override + public PyObject initPythonResource(String module, String clazz, Map kwargs) { + kwargs.put(GET_RESOURCE_KEY, pythonGetResourceFunction); + return (PyObject) interpreter.invoke(CREATE_RESOURCE, module, clazz, kwargs); + } + + private Object convertToPythonTool(Tool tool) { + return interpreter.invoke(FROM_JAVA_TOOL, tool); + } + + private Object convertToPythonPrompt(Prompt prompt) { + return interpreter.invoke(FROM_JAVA_PROMPT, prompt); + } + + @Override + public Object callMethod(Object obj, String methodName, Map kwargs) { + return interpreter.invoke(CALL_METHOD, obj, methodName, kwargs); + } + + @Override + public Object invoke(String name, Object... args) { + return interpreter.invoke(name, args); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java new file mode 100644 index 00000000..18ee02e9 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.python.utils; + +import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.tools.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import pemja.core.PythonInterpreter; +import pemja.core.object.PyObject; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +public class PythonResourceAdapterImplTest { + @Mock private PythonInterpreter mockInterpreter; + + @Mock private BiFunction getResource; + + private PythonResourceAdapterImpl pythonResourceAdapter; + private AutoCloseable mocks; + + @BeforeEach + void setUp() throws Exception { + mocks = MockitoAnnotations.openMocks(this); + pythonResourceAdapter = new PythonResourceAdapterImpl(getResource, mockInterpreter); + } + + @AfterEach + void tearDown() throws Exception { + if (mocks != null) { + mocks.close(); + } + } + + @Test + void testInitPythonResource() { + String module = "test_module"; + String clazz = "TestClass"; + Map kwargs = new HashMap<>(); + kwargs.put("param1", "value1"); + kwargs.put("param2", 42); + + PyObject expectedResult = mock(PyObject.class); + when(mockInterpreter.invoke( + PythonResourceAdapterImpl.CREATE_RESOURCE, module, clazz, kwargs)) + .thenReturn(expectedResult); + + PyObject result = pythonResourceAdapter.initPythonResource(module, clazz, kwargs); + + assertThat(result).isEqualTo(expectedResult); + assertThat(kwargs).containsKey(PythonResourceAdapterImpl.GET_RESOURCE_KEY); + verify(mockInterpreter) + .invoke(PythonResourceAdapterImpl.CREATE_RESOURCE, module, clazz, kwargs); + } + + @Test + void testOpen() { + + pythonResourceAdapter.open(); + + verify(mockInterpreter).exec(PythonResourceAdapterImpl.PYTHON_IMPORTS); + verify(mockInterpreter) + .invoke(PythonResourceAdapterImpl.GET_RESOURCE_FUNCTION, pythonResourceAdapter); + } + + @Test + void testGetResourceWithPythonResourceWrapper() { + String resourceName = "test_resource"; + String resourceType = "chat_model"; + PythonResourceWrapper mockPythonChatModelSetup = mock(PythonChatModelSetup.class); + Object expectedPythonResource = new Object(); + + when(getResource.apply(resourceName, ResourceType.CHAT_MODEL)) + .thenReturn((Resource) mockPythonChatModelSetup); + when(mockPythonChatModelSetup.getPythonResource()).thenReturn(expectedPythonResource); + + Object result = pythonResourceAdapter.getResource(resourceName, resourceType); + + assertThat(result).isEqualTo(expectedPythonResource); + verify(getResource).apply(resourceName, ResourceType.CHAT_MODEL); + verify(mockPythonChatModelSetup).getPythonResource(); + } + + @Test + void testGetResourceWithTool() { + String resourceName = "test_tool"; + String resourceType = "tool"; + Tool mockTool = mock(Tool.class); + Object expectedPythonTool = new Object(); + + when(getResource.apply(resourceName, ResourceType.TOOL)).thenReturn(mockTool); + when(mockInterpreter.invoke(PythonResourceAdapterImpl.FROM_JAVA_TOOL, mockTool)) + .thenReturn(expectedPythonTool); + + Object result = pythonResourceAdapter.getResource(resourceName, resourceType); + + assertThat(result).isEqualTo(expectedPythonTool); + verify(getResource).apply(resourceName, ResourceType.TOOL); + verify(mockInterpreter).invoke(PythonResourceAdapterImpl.FROM_JAVA_TOOL, mockTool); + } + + @Test + void testGetResourceWithPrompt() { + String resourceName = "test_prompt"; + String resourceType = "prompt"; + Prompt mockPrompt = mock(Prompt.class); + Object expectedPythonPrompt = new Object(); + + when(getResource.apply(resourceName, ResourceType.PROMPT)).thenReturn(mockPrompt); + when(mockInterpreter.invoke(PythonResourceAdapterImpl.FROM_JAVA_PROMPT, mockPrompt)) + .thenReturn(expectedPythonPrompt); + + Object result = pythonResourceAdapter.getResource(resourceName, resourceType); + + assertThat(result).isEqualTo(expectedPythonPrompt); + verify(getResource).apply(resourceName, ResourceType.PROMPT); + verify(mockInterpreter).invoke(PythonResourceAdapterImpl.FROM_JAVA_PROMPT, mockPrompt); + } + + @Test + void testGetResourceWithRegularResource() { + String resourceName = "test_resource"; + String resourceType = "chat_model"; + Resource mockResource = mock(Resource.class); + + when(getResource.apply(resourceName, ResourceType.CHAT_MODEL)).thenReturn(mockResource); + + Object result = pythonResourceAdapter.getResource(resourceName, resourceType); + + assertThat(result).isEqualTo(mockResource); + verify(getResource).apply(resourceName, ResourceType.CHAT_MODEL); + } + + @Test + void testCallMethod() { + // Arrange + Object obj = new Object(); + String methodName = "test_method"; + Map kwargs = Map.of("param", "value"); + Object expectedResult = "method_result"; + + when(mockInterpreter.invoke(PythonResourceAdapterImpl.CALL_METHOD, obj, methodName, kwargs)) + .thenReturn(expectedResult); + + Object result = pythonResourceAdapter.callMethod(obj, methodName, kwargs); + + assertThat(result).isEqualTo(expectedResult); + verify(mockInterpreter) + .invoke(PythonResourceAdapterImpl.CALL_METHOD, obj, methodName, kwargs); + } + + @Test + void testInvoke() { + String name = "test_function"; + Object[] args = {"arg1", 42, true}; + Object expectedResult = "invoke_result"; + + when(mockInterpreter.invoke(name, args)).thenReturn(expectedResult); + + Object result = pythonResourceAdapter.invoke(name, args); + + assertThat(result).isEqualTo(expectedResult); + verify(mockInterpreter).invoke(name, args); + } +}