From 5a0cc945cbd3116a116082d6439e7e3ce86a5bee Mon Sep 17 00:00:00 2001 From: opensearch-ci-bot Date: Fri, 29 Aug 2025 00:10:02 +0000 Subject: [PATCH 01/10] Increment version to 2.19.4-SNAPSHOT Signed-off-by: opensearch-ci-bot Signed-off-by: Brian Flores --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index ecc794f5a2..be85478a58 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "2.19.3-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "2.19.4-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") // 2.0.0-rc1-SNAPSHOT -> 2.0.0.0-rc1-SNAPSHOT From 4e1e1203c16f79d671101f7a29db29171f4ef9fd Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Wed, 22 Oct 2025 15:57:41 -0700 Subject: [PATCH 02/10] fix CVE-2025-55163, CVE-2025-48924 (#4298) * address commons-lang3 CVE-2025-48924 Signed-off-by: Brian Flores * pin netty to 4.2.5.Final version address CVE-2025-55163 Signed-off-by: Brian Flores * force all subProjects to use updated common-lang3 version Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores --- build.gradle | 3 +++ ml-algorithms/build.gradle | 4 +++- search-processors/build.gradle | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index be85478a58..509643b5a4 100644 --- a/build.gradle +++ b/build.gradle @@ -71,6 +71,7 @@ allprojects { } + subprojects { configurations { testImplementation.extendsFrom compileOnly @@ -80,6 +81,8 @@ subprojects { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' + resolutionStrategy.force "org.apache.commons:commons-lang3:${versions.commonslang}" + } } diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 9405f6a9ee..c50f0ab54a 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -88,7 +88,9 @@ dependencies { } implementation('net.minidev:json-smart:2.5.2') implementation group: 'org.json', name: 'json', version: '20231013' - implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18" + implementation(enforcedPlatform("io.netty:netty-bom:4.2.5.Final")) + implementation("software.amazon.awssdk:netty-nio-client") + testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' diff --git a/search-processors/build.gradle b/search-processors/build.gradle index e9fbc9a585..2f9f8bb380 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -27,6 +27,7 @@ repositories { mavenLocal() } + dependencies { implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" From 61ad7f76abd1ced131a06a6aca3d4dd0bb021b7a Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 30 Sep 2025 14:11:54 +0800 Subject: [PATCH 03/10] Move HttpClientFactory to common to expose to other components (#4175) * Move HttpClientFactory to common to expose to other componenets Signed-off-by: zane-neo * optimize code for better maintainability Signed-off-by: zane-neo * Optimize code and increase UT coverage Signed-off-by: zane-neo * Address comments Signed-off-by: zane-neo * Use amazon aws version from opensearch core Signed-off-by: zane-neo * address comments Signed-off-by: zane-neo --------- Signed-off-by: zane-neo Signed-off-by: Brian Flores --- common/build.gradle | 4 + .../httpclient/MLHttpClientFactory.java | 67 +++--- .../httpclient/MLHttpClientFactoryTests.java | 191 ++++++++++++++++++ ml-algorithms/build.gradle | 5 +- .../remote/AwsConnectorExecutor.java | 2 +- .../remote/HttpJsonConnectorExecutor.java | 2 +- .../httpclient/MLHttpClientFactoryTests.java | 112 ---------- plugin/build.gradle | 12 +- 8 files changed, 230 insertions(+), 165 deletions(-) rename {ml-algorithms/src/main/java/org/opensearch/ml/engine => common/src/main/java/org/opensearch/ml/common}/httpclient/MLHttpClientFactory.java (58%) create mode 100644 common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java delete mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java diff --git a/common/build.gradle b/common/build.gradle index b84024b2b4..da0afd564d 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -44,6 +44,10 @@ dependencies { compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' // Multi-tenant SDK Client compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + exclude(group: 'org.reactivestreams', module: 'reactive-streams') + exclude(group: 'org.slf4j', module: 'slf4j-api') + } } lombok { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java similarity index 58% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java rename to common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java index ffc95c30de..109a5de5f8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java @@ -3,19 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.engine.httpclient; +package org.opensearch.ml.common.httpclient; import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.util.Arrays; import java.util.Locale; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.common.util.concurrent.ThreadContextAccess; + import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; @@ -24,19 +23,15 @@ public class MLHttpClientFactory { public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { - try { - return AccessController - .doPrivileged( - (PrivilegedExceptionAction) () -> NettyNioAsyncHttpClient - .builder() - .connectionTimeout(connectionTimeout) - .readTimeout(readTimeout) - .maxConcurrency(maxConnections) - .build() - ); - } catch (PrivilegedActionException e) { - return null; - } + return ThreadContextAccess + .doPrivileged( + () -> NettyNioAsyncHttpClient + .builder() + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .maxConcurrency(maxConnections) + .build() + ); } /** @@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { - log.error("Remote inference protocol is not http or https: " + protocol); + log.error("Remote inference protocol is not http or https: {}", protocol); throw new IllegalArgumentException("Protocol is not http or https: " + protocol); } // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. @@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea } } if (port < 0 || port > 65536) { - log.error("Remote inference port out of range: " + port); + log.error("Remote inference port out of range: {}", port); throw new IllegalArgumentException("Port out of range: " + port); } validateIp(host, connectorPrivateIpEnabled); @@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) { - log.error("Remote inference host name has private ip address: " + hostName); + log.error("Remote inference host name has private ip address: {}", hostName); throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); } } @@ -83,23 +78,8 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { if (bytes.length != 4) { return true; } else { - int firstOctets = bytes[0] & 0xff; - int firstInOctal = parseWithOctal(String.valueOf(firstOctets)); - int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16); - if (firstInOctal == 127 || firstInHex == 127) { - return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1; - } else if (firstInOctal == 10 || firstInHex == 10) { + if (isPrivateIPv4(bytes)) { return true; - } else if (firstInOctal == 172 || firstInHex == 172) { - int secondOctets = bytes[1] & 0xff; - int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); - int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); - return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32); - } else if (firstInOctal == 192 || firstInHex == 192) { - int secondOctets = bytes[1] & 0xff; - int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); - int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); - return secondInOctal == 168 || secondInHex == 168; } } } @@ -107,11 +87,14 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); } - private static int parseWithOctal(String input) { - try { - return Integer.parseInt(input, 8); - } catch (NumberFormatException e) { - return Integer.parseInt(input); - } + private static boolean isPrivateIPv4(byte[] bytes) { + int first = bytes[0] & 0xff; + int second = bytes[1] & 0xff; + + // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x + return (first == 10) + || (first == 172 && second >= 16 && second <= 31) + || (first == 192 && second == 168) + || (first == 169 && second == 254); } } diff --git a/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java new file mode 100644 index 0000000000..1c01172344 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.httpclient; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + +public class MLHttpClientFactoryTests { + + private static final String TEST_HOST = "api.openai.com"; + private static final String HTTP = "http"; + private static final String HTTPS = "https"; + private static final AtomicBoolean PRIVATE_IP_DISABLED = new AtomicBoolean(false); + private static final AtomicBoolean PRIVATE_IP_ENABLED = new AtomicBoolean(true); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void test_getSdkAsyncHttpClient_success() { + SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100); + assertNotNull(client); + } + + @Test + public void test_invalidIP_localHost_privateIPDisabled() { + IllegalArgumentException e1 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage()); + + IllegalArgumentException e3 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage()); + + IllegalArgumentException e4 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage()); + + IllegalArgumentException e5 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage()); + } + + @Test + public void test_validateIp_validIp_noException() throws Exception { + MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTPS, TEST_HOST, 443, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTPS, "127.0.0.1", 443, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "177.16.0.1", 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED); + } + + @Test + public void test_validateIp_rarePrivateIp_throwException() throws Exception { + try { + MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_DISABLED); + } catch (Exception e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "177.0.0.1", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validate(HTTP, "12.16.2.3", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + } + + @Test + public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception { + MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTPS, "0254.020.00.01", 443, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTPS, "0:0:0:0:0:ffff:127.0.0.1", 443, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "10.24.76.186", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTPS, "10.24.76.186", 443, PRIVATE_IP_ENABLED); + } + + @Test + public void test_validateSchemaAndPort_success() throws Exception { + MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + MLHttpClientFactory.validate("ftp", TEST_HOST, 80, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_portNotInRange1_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port out of range: 65537"); + MLHttpClientFactory.validate(HTTPS, TEST_HOST, 65537, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_portNotInRange2_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port out of range: -10"); + MLHttpClientFactory.validate(HTTP, TEST_HOST, -10, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validatePort_boundaries_success() throws Exception { + MLHttpClientFactory.validate(HTTP, TEST_HOST, 65536, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, TEST_HOST, 0, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, TEST_HOST, -1, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTPS, TEST_HOST, -1, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(null, TEST_HOST, -1, PRIVATE_IP_DISABLED); + } + +} diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index c50f0ab54a..22c46163ba 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -58,6 +58,7 @@ dependencies { // Multi-tenant SDK Client implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" implementation 'commons-beanutils:commons-beanutils:1.11.0' + implementation "org.opensearch:opensearch-remote-metadata-sdk-ddb-client:${opensearch_build}" def os = DefaultNativePlatform.currentOperatingSystem //arm/macos doesn't support GPU @@ -88,9 +89,7 @@ dependencies { } implementation('net.minidev:json-smart:2.5.2') implementation group: 'org.json', name: 'json', version: '20231013' - implementation(enforcedPlatform("io.netty:netty-bom:4.2.5.Final")) - implementation("software.amazon.awssdk:netty-nio-client") - + implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18" testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index e0bcd1bc73..e7e5b7c71f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -24,11 +24,11 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import lombok.Getter; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 5ac0245701..57f2a7019c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -26,11 +26,11 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import lombok.Getter; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java deleted file mode 100644 index 1d79ac995e..0000000000 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.httpclient; - -import static org.junit.Assert.assertNotNull; - -import java.time.Duration; -import java.util.concurrent.atomic.AtomicBoolean; - -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - -import software.amazon.awssdk.http.async.SdkAsyncHttpClient; - -public class MLHttpClientFactoryTests { - - @Rule - public ExpectedException expectedException = ExpectedException.none(); - - @Test - public void test_getSdkAsyncHttpClient_success() { - SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100); - assertNotNull(client); - } - - @Test - public void test_validateIp_validIp_noException() throws Exception { - AtomicBoolean privateIpEnabled = new AtomicBoolean(false); - MLHttpClientFactory.validate("http", "api.openai.com", 80, privateIpEnabled); - } - - @Test - public void test_validateIp_rarePrivateIp_throwException() throws Exception { - AtomicBoolean privateIpEnabled = new AtomicBoolean(false); - try { - MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled); - } catch (Exception e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - } - - @Test - public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception { - AtomicBoolean privateIpEnabled = new AtomicBoolean(true); - MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled); - MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled); - } - - @Test - public void test_validateSchemaAndPort_success() throws Exception { - MLHttpClientFactory.validate("http", "api.openai.com", 80, new AtomicBoolean(false)); - } - - @Test - public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { - expectedException.expect(IllegalArgumentException.class); - MLHttpClientFactory.validate("ftp", "api.openai.com", 80, new AtomicBoolean(false)); - } - - @Test - public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Port out of range: 65537"); - MLHttpClientFactory.validate("https", "api.openai.com", 65537, new AtomicBoolean(false)); - } - -} diff --git a/plugin/build.gradle b/plugin/build.gradle index a352702c2e..e569d9bcb4 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -54,15 +54,15 @@ dependencies { implementation project(':opensearch-ml-memory') compileOnly "com.google.guava:guava:32.1.3-jre" - implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18" - implementation group: 'software.amazon.awssdk', name: 's3', version: "2.30.18" - implementation group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "${versions.aws}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" From 9cd7bd7b2e2d49bddadb3f87bacf6d06da53057e Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Wed, 22 Oct 2025 16:32:09 -0700 Subject: [PATCH 04/10] use mainline versions.aws via hardcode Signed-off-by: Brian Flores --- build.gradle | 1 + common/build.gradle | 2 +- ml-algorithms/build.gradle | 23 +++++++++++++---------- plugin/build.gradle | 12 ++++++------ 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/build.gradle b/build.gradle index 509643b5a4..f46051d0b2 100644 --- a/build.gradle +++ b/build.gradle @@ -82,6 +82,7 @@ subprojects { resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "org.apache.commons:commons-lang3:${versions.commonslang}" + resolutionStrategy.force 'software.amazon.awssdk:bom:2.32.29' } } diff --git a/common/build.gradle b/common/build.gradle index da0afd564d..660df0b1cb 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -44,7 +44,7 @@ dependencies { compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' // Multi-tenant SDK Client compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" - compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.32.29") { exclude(group: 'org.reactivestreams', module: 'reactive-streams') exclude(group: 'org.slf4j', module: 'slf4j-api') } diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 22c46163ba..9800a8bb55 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -72,24 +72,27 @@ dependencies { } } - implementation platform('software.amazon.awssdk:bom:2.30.18') - api 'software.amazon.awssdk:auth:2.30.18' - implementation 'software.amazon.awssdk:apache-client' + implementation platform(group: 'software.amazon.awssdk', name: 'bom', version:"2.32.29") + api "software.amazon.awssdk:auth:2.32.29" + implementation group: 'software.amazon.awssdk', name:'apache-client', version: "2.32.29" + implementation (group: 'software.amazon.awssdk', name: 'bedrockruntime', version: "2.32.29") { + exclude group: 'io.netty' + } implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') { exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on' } - implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1' - - compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18" - compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.30.18" - compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18" + // needed by aws-encryption-sdk-java + implementation "org.bouncycastle:bc-fips:2.1.1" + compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.32.29" + compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.32.29" + compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.32.29" implementation ('com.jayway.jsonpath:json-path:2.9.0') { exclude group: 'net.minidev', module: 'json-smart' } implementation('net.minidev:json-smart:2.5.2') implementation group: 'org.json', name: 'json', version: '20231013' - implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.32.29" testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' @@ -102,7 +105,7 @@ lombok { configurations.all { resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5' resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' - resolutionStrategy.force 'software.amazon.awssdk:bom:2.30.18' + resolutionStrategy.force group: 'software.amazon.awssdk', name:'bom', version:"2.32.29" resolutionStrategy.force 'commons-beanutils:commons-beanutils:1.11.0' } diff --git a/plugin/build.gradle b/plugin/build.gradle index e569d9bcb4..b1a0b05a27 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -54,15 +54,15 @@ dependencies { implementation project(':opensearch-ml-memory') compileOnly "com.google.guava:guava:32.1.3-jre" - implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.32.29" + implementation group: 'software.amazon.awssdk', name: 's3', version: "2.32.29" + implementation group: 'software.amazon.awssdk', name: 'regions', version: "2.32.29" - implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "2.32.29" - implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "2.32.29" - implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.32.29" zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" From 3cf3ad1c23c0bfba322f792ffe6ba7fa41905d69 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Wed, 22 Oct 2025 17:02:48 -0700 Subject: [PATCH 05/10] address CVE-2025-58057 Signed-off-by: Brian Flores --- build.gradle | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index f46051d0b2..2c5c348efd 100644 --- a/build.gradle +++ b/build.gradle @@ -83,8 +83,15 @@ subprojects { resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "org.apache.commons:commons-lang3:${versions.commonslang}" resolutionStrategy.force 'software.amazon.awssdk:bom:2.32.29' - - } + resolutionStrategy.force 'io.netty:netty-buffer:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-codec:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-codec-http:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-codec-http2:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-common:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-handler:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-resolver:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-transport:4.1.125.Final' + resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.125.Final' } } ext { From 5a0c7f1656625d5a9efed2d25264e22f7a87a3dd Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Wed, 22 Oct 2025 17:04:21 -0700 Subject: [PATCH 06/10] fix code format Signed-off-by: Brian Flores --- build.gradle | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 2c5c348efd..4d9549cba8 100644 --- a/build.gradle +++ b/build.gradle @@ -83,6 +83,7 @@ subprojects { resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "org.apache.commons:commons-lang3:${versions.commonslang}" resolutionStrategy.force 'software.amazon.awssdk:bom:2.32.29' + resolutionStrategy.force 'io.netty:netty-buffer:4.1.125.Final' resolutionStrategy.force 'io.netty:netty-codec:4.1.125.Final' resolutionStrategy.force 'io.netty:netty-codec-http:4.1.125.Final' @@ -91,7 +92,8 @@ subprojects { resolutionStrategy.force 'io.netty:netty-handler:4.1.125.Final' resolutionStrategy.force 'io.netty:netty-resolver:4.1.125.Final' resolutionStrategy.force 'io.netty:netty-transport:4.1.125.Final' - resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.125.Final' } + resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.125.Final' + } } ext { From 214c8b617fef1cb663401fd11f1576cd1717a768 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Thu, 23 Oct 2025 14:52:10 -0700 Subject: [PATCH 07/10] empty commit to trigger CI Signed-off-by: Brian Flores From 42c89edaf1db78bdbf66755180c79111c14191e8 Mon Sep 17 00:00:00 2001 From: Xinyuan Lu Date: Fri, 12 Sep 2025 04:59:30 +0800 Subject: [PATCH 08/10] Fix claude model it (#4167) * fix model it by replace claude v1/v2 Signed-off-by: xinyual * remove useless change Signed-off-by: xinyual --------- Signed-off-by: xinyual Signed-off-by: Brian Flores --- .../ml/rest/RestConnectorToolIT.java | 42 ++++---- ...tMLInferenceSearchResponseProcessorIT.java | 42 ++++---- .../ml/rest/RestMLRAGSearchProcessorIT.java | 98 +++++++++++++++++-- 3 files changed, 140 insertions(+), 42 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java index 0f75be90ac..0f90380b0a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java @@ -19,7 +19,9 @@ public class RestConnectorToolIT extends RestBaseAgentToolsIT { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"; private String bedrockClaudeConnectorId; private String bedrockClaudeConnectorIdForPredict; @@ -35,8 +37,8 @@ public void setUp() throws Exception { private String createBedrockClaudeConnector(String action) throws IOException, InterruptedException { String bedrockClaudeConnectorEntity = "{\n" - + " \"name\": \"BedRock Claude instant-v1 Connector \",\n" - + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"name\": \"Bedrock Connector: claude 3.5\",\n" + + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" + " \"version\": 1,\n" + " \"protocol\": \"aws_sigv4\",\n" + " \"parameters\": {\n" @@ -44,10 +46,11 @@ private String createBedrockClaudeConnector(String action) throws IOException, I + GITHUB_CI_AWS_REGION + "\",\n" + " \"service_name\": \"bedrock\",\n" - + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" - + " \"max_tokens_to_sample\": 8000,\n" - + " \"temperature\": 0.0001,\n" - + " \"response_filter\": \"$.completion\"\n" + + " \"model\": \"" + + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET + + "\",\n" + + " \"system_prompt\": \"You are a helpful assistant.\",\n" + + "\"response_filter\": \"$.output.message.content[0].text\"" + " },\n" + " \"credential\": {\n" + " \"access_key\": \"" @@ -61,19 +64,22 @@ private String createBedrockClaudeConnector(String action) throws IOException, I + "\"\n" + " },\n" + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"" + + " {\n" + + " \"action_type\": \"" + action + "\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" - + " }\n" - + " ]\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/" + + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET + + "/converse\",\n" + + " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.messages}\\\"}]}] , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" + + " }\n" + + " ]\n" + "}"; return registerConnector(bedrockClaudeConnectorEntity); } @@ -135,7 +141,7 @@ public void testConnectorToolInFlowAgent() throws IOException { + " ]\n" + "}"; String agentId = createAgent(registerAgentRequestBody); - String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + String agentInput = "{\n" + " \"parameters\": {\n" + " \"messages\": \"hello\"\n" + " }\n" + "}"; String result = executeAgent(agentId, agentInput); assertNotNull(result); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java index f2ef5495fc..315f7da3e0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java @@ -67,6 +67,7 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; private final String bedrockEmbeddingModelConnectorEntity = "{\n" @@ -109,8 +110,8 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC + "}"; private final String bedrockClaudeModelConnectorEntity = "{\n" - + " \"name\": \"BedRock Claude instant-v1 Connector\",\n" - + " \"description\": \"The connector to bedrock for claude model\",\n" + + " \"name\": \"Bedrock Connector: claude 3.5\",\n" + + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" + " \"version\": 1,\n" + " \"protocol\": \"aws_sigv4\",\n" + " \"parameters\": {\n" @@ -118,11 +119,11 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC + GITHUB_CI_AWS_REGION + "\",\n" + " \"service_name\": \"bedrock\",\n" - + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" - + " \"max_tokens_to_sample\": 8000,\n" - + " \"temperature\": 0.0001,\n" - + " \"response_filter\": \"$.completion\",\n" - + " \"stop_sequences\": [\"\\n\\nHuman:\",\"\\nObservation:\",\"\\n\\tObservation:\",\"\\nObservation\",\"\\n\\tObservation\",\"\\n\\nQuestion\"]\n" + + " \"model\": \"" + + "anthropic.claude-3-5-sonnet-20240620-v1:0" + + "\",\n" + + " \"system_prompt\": \"You are a helpful assistant.\",\n" + + "\"response_filter\": \"$.output.message.content[0].text\"" + " },\n" + " \"credential\": {\n" + " \"access_key\": \"" @@ -136,17 +137,22 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC + "\"\n" + " },\n" + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"stop_sequences\\\": ${parameters.stop_sequences}, \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" - + " }\n" - + " ]\n" + + " {\n" + + " \"action_type\": \"" + + "predict" + + "\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/" + + "anthropic.claude-3-5-sonnet-20240620-v1:0" + + "/converse\",\n" + + " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.prompt}\\\"}]}] , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" + + " }\n" + + " ]\n" + "}"; private final String bedrockMultiModalEmbeddingModelConnectorEntity = "{\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index b16abef59f..52cba14041 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -111,6 +111,52 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"; private static final String BEDROCK_ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT_INVOKE = "{\n" + + " \"name\": \"Bedrock Connector: claude 3.5\",\n" + + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model\": \"" + + "anthropic.claude-3-5-sonnet-20240620-v1:0" + + "\",\n" + + " \"system_prompt\": \"You are a helpful assistant.\",\n" + + "\"response_filter\": \"$.content[0].text\"" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"" + + "predict" + + "\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"url\": \"https://bedrock-runtime." + + GITHUB_CI_AWS_REGION + + ".amazonaws.com/model/" + + "anthropic.claude-3-5-sonnet-20240620-v1:0" + + "/invoke\",\n" + + " \"request_body\": \"{\\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.inputs}\\\"}]}], \\\"max_tokens\\\":300, \\\"temperature\\\":0.5, \\\"anthropic_version\\\":\\\"bedrock-2023-05-31\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n" + " \"name\": \"Bedrock Connector: claude2\",\n" + " \"description\": \"The connector to bedrock claude2 model\",\n" @@ -181,7 +227,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " ]\n" + "}"; - private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" + static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" + " \"name\": \"Bedrock Connector: claude 3.5\",\n" + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" + " \"version\": 1,\n" @@ -268,8 +314,8 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + "}"; private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null - ? BEDROCK_CONNECTOR_BLUEPRINT2 - : BEDROCK_CONNECTOR_BLUEPRINT1; + ? BEDROCK_CONNECTOR_BLUEPRINT_INVOKE + : BEDROCK_CONNECTOR_BLUEPRINT_INVOKE; private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null ? BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 @@ -425,6 +471,26 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " }\n" + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n" + + " \"_source\": [\"%s\"],\n" + + " \"query\" : {\n" + + " \"match\": {\"%s\": \"%s\"}\n" + + " },\n" + + " \"ext\": {\n" + + " \"generative_qa_parameters\": {\n" + + " \"llm_model\": \"%s\",\n" + + " \"llm_question\": \"%s\",\n" + + " \"memory_id\": \"%s\",\n" + + " \"system_prompt\": \"%s\",\n" + + " \"user_instructions\": \"%s\",\n" + + " \"context_size\": %d,\n" + + " \"message_size\": %d,\n" + + " \"timeout\": %d,\n" + + " \"llm_response_field\": \"%s\"\n" + + " }\n" + + " }\n" + + "}"; + private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -705,6 +771,7 @@ public void testBM25WithBedrock() throws Exception { requestParameters.contextSize = 5; requestParameters.interactionSize = 5; requestParameters.timeout = 60; + requestParameters.llmResponseField = "response"; Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); assertEquals(200, response2.getStatusLine().getStatusCode()); @@ -1068,6 +1135,7 @@ public void testBM25WithBedrockWithConversation() throws Exception { requestParameters.interactionSize = 5; requestParameters.timeout = 60; requestParameters.conversationId = conversationId; + requestParameters.llmResponseField = "response"; Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); assertEquals(200, response2.getStatusLine().getStatusCode()); @@ -1240,7 +1308,7 @@ private Response performSearch(String indexName, String pipeline, int size, Sear throws Exception { // TODO build these templates dynamically - String httpEntity = requestParameters.llmResponseField != null + String httpEntity = requestParameters.llmResponseField != null && requestParameters.conversationId == null ? String .format( Locale.ROOT, @@ -1351,10 +1419,27 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.interactionSize, requestParameters.timeout ) + : (requestParameters.llmResponseField == null) + ? String + .format( + Locale.ROOT, + BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE, + requestParameters.source, + requestParameters.source, + requestParameters.match, + requestParameters.llmModel, + requestParameters.llmQuestion, + requestParameters.conversationId, + requestParameters.systemPrompt, + requestParameters.userInstructions, + requestParameters.contextSize, + requestParameters.interactionSize, + requestParameters.timeout + ) : String .format( Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE, + BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE, requestParameters.source, requestParameters.source, requestParameters.match, @@ -1365,7 +1450,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.userInstructions, requestParameters.contextSize, requestParameters.interactionSize, - requestParameters.timeout + requestParameters.timeout, + requestParameters.llmResponseField ); return makeRequest( client(), From c91cb5c211e19aff8f226d8e7da718ee61c6b615 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 10 Jun 2025 16:45:49 -0700 Subject: [PATCH 09/10] Don't convert schema-defined strings to other types during validation (#3761) Signed-off-by: Daniel Widdis Signed-off-by: Brian Flores --- .../TransportPredictionTaskAction.java | 2 +- .../org/opensearch/ml/utils/MLNodeUtils.java | 23 ++++++--- .../opensearch/ml/utils/MLNodeUtilsTests.java | 51 +++++++++++++++---- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 0a58954826..5a469b35aa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -262,7 +262,7 @@ public void validateInputSchema(String modelId, MLInput mlInput) { try { String InputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString(); // Process the parameters field in the input dataset to convert it back to its original datatype, instead of a string - String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString); + String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString, inputSchemaString); MLNodeUtils.validateSchema(inputSchemaString, processedInputString); } catch (Exception e) { throw new OpenSearchStatusException( diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index 3cbbc62ef5..ed7bba9053 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -91,13 +91,18 @@ public static void validateSchema(String schemaString, String instanceString) th } /** - * This method processes the input JSON string and replaces the string values of the parameters with JSON objects if the string is a valid JSON. + * This method processes the input JSON string and replaces the string values of the parameters with JSON objects if the string is a valid JSON, unless the schema defines the value as a string. * @param inputJson The input JSON string + * @param schemaJson The schema matching the input JSON string * @return The processed JSON string */ - public static String processRemoteInferenceInputDataSetParametersValue(String inputJson) throws IOException { + public static String processRemoteInferenceInputDataSetParametersValue(String inputJson, String schemaJson) throws IOException { ObjectMapper mapper = new ObjectMapper(); JsonNode rootNode = mapper.readTree(inputJson); + JsonNode schemaNode = mapper.readTree(schemaJson); + + // Get the schema properties for parameters if they exist + JsonNode parametersSchema = schemaNode.path("properties").path("parameters").path("properties"); if (rootNode.has("parameters") && rootNode.get("parameters").isObject()) { ObjectNode parametersNode = (ObjectNode) rootNode.get("parameters"); @@ -106,15 +111,12 @@ public static String processRemoteInferenceInputDataSetParametersValue(String in String key = entry.getKey(); JsonNode value = entry.getValue(); - if (value.isTextual()) { - String textValue = value.asText(); + if (value.isTextual() && !isStringTypeInSchema(parametersSchema, key)) { try { - // Try to parse the string as JSON - JsonNode parsedValue = mapper.readTree(textValue); - // If successful, replace the string with the parsed JSON + JsonNode parsedValue = mapper.readTree(value.asText()); parametersNode.set(key, parsedValue); } catch (IOException e) { - // If parsing fails, it's not a valid JSON string, so keep it as is + // If parsing fails, keep it as is parametersNode.set(key, value); } } @@ -123,6 +125,11 @@ public static String processRemoteInferenceInputDataSetParametersValue(String in return mapper.writeValueAsString(rootNode); } + private static boolean isStringTypeInSchema(JsonNode schema, String fieldName) { + JsonNode typeNode = schema.path(fieldName).path("type"); + return typeNode.isTextual() && typeNode.asText().equals("string"); + } + public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) { ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB(); if (openCircuitBreaker != null) { diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index 40bb230bbf..e082571ad6 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -118,61 +118,94 @@ public void testValidateRemoteInputWithTitanMultiModalRemoteSchema() throws IOEx @Test public void testProcessRemoteInferenceInputDataSetParametersValueNoParameters() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(json, processedJson); } @Test public void testProcessRemoteInferenceInputDataSetInvalidJson() { + String schema = "{\"type\": \"object\",\"properties\": {}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"a\"}}"; - assertThrows(JsonParseException.class, () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json)); + assertThrows(JsonParseException.class, () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema)); } @Test public void testProcessRemoteInferenceInputDataSetEmptyParameters() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\"}}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{}}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(json, processedJson); } @Test public void testProcessRemoteInferenceInputDataSetParametersValueParametersWrongType() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"array\"}}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":[\"Hello\",\"world\"]}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(json, processedJson); } @Test public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessArray() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {" + + "\"texts\": {\"type\": \"array\",\"items\": {\"type\": \"string\"}}" + + "}}}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":\"[\\\"Hello\\\",\\\"world\\\"]\"}}"; String expectedJson = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":[\"Hello\",\"world\"]}}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(expectedJson, processedJson); } @Test public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessObject() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {" + + "\"messages\": {\"type\": \"object\"}" + + "}}}}"; String json = - "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}}"; + "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}"; String expectedJson = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":{\"role\":\"system\",\"foo\":\"{\\\"a\\\": \\\"b\\\"}\",\"content\":{\"a\":\"b\"}}}}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(expectedJson, processedJson); } + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersQuotedNumber() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {" + + "\"key1\": {\"type\": \"string\"}," + + "\"key2\": {\"type\": \"integer\"}," + + "\"key3\": {\"type\": \"boolean\"}" + + "}}}}"; + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"123\",\"key2\":123,\"key3\":true}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); + assertEquals(json, processedJson); + } + @Test public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersNoProcess() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {" + + "\"key1\": {\"type\": \"string\"}," + + "\"key2\": {\"type\": \"integer\"}," + + "\"key3\": {\"type\": \"boolean\"}" + + "}}}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true}}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(json, processedJson); } @Test public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersInvalidJson() throws IOException { + String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {" + + "\"key1\": {\"type\": \"string\"}," + + "\"key2\": {\"type\": \"integer\"}," + + "\"key3\": {\"type\": \"boolean\"}," + + "\"texts\": {\"type\": \"array\"}" + + "}}}}"; String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"texts\":\"[\\\"Hello\\\",\\\"world\\\"\"}}"; - String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema); assertEquals(json, processedJson); } } From b8e69bdacae9d90164e53d145be72b74382195e2 Mon Sep 17 00:00:00 2001 From: Xinyuan Lu Date: Wed, 17 Sep 2025 14:12:36 +0800 Subject: [PATCH 10/10] fix Cohere IT (#4174) * fix Cohere IT Signed-off-by: xinyual * apply spotless Signed-off-by: xinyual * delete useless it Signed-off-by: xinyual --------- Signed-off-by: xinyual Signed-off-by: Brian Flores --- .../ml/rest/RestMLRAGSearchProcessorIT.java | 2 +- .../ml/rest/RestMLRemoteInferenceIT.java | 70 ------------------- 2 files changed, 1 insertion(+), 71 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 52cba14041..0bd56a9823 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -333,7 +333,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + "\"\n" + " },\n" + " \"parameters\": {\n" - + " \"model\": \"command\"\n" + + " \"model\": \"command-a-03-2025\"\n" + " },\n" + " \"actions\": [\n" + " {\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 9402db1d71..d1e00090d5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -670,76 +670,6 @@ private void testOpenAITextEmbeddingModel(String charset, Consumer verifyRe } } - public void testCohereGenerateTextModel() throws IOException, InterruptedException { - // Skip test if key is null - if (COHERE_KEY == null) { - return; - } - String entity = "{\n" - + " \"name\": \"Cohere generate text model Connector\",\n" - + " \"description\": \"The connector to public Cohere generate text model service\",\n" - + " \"version\": 1,\n" - + "\"client_config\": {\n" - + " \"max_connection\": 20,\n" - + " \"connection_timeout\": 50000,\n" - + " \"read_timeout\": 50000\n" - + " },\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" - + COHERE_KEY - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; - Response response = createConnector(entity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("cohere generate text model", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"prompt\": \"Once upon a time in a magical land called\",\n" - + " \"max_tokens\": 40\n" - + " }\n" - + "}"; - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("generations"); - responseMap = (Map) responseList.get(0); - assertFalse(((String) responseMap.get("text")).isEmpty()); - } - public static Response createConnector(String input) throws IOException { try { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);