Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class MistralAiChatProperties extends MistralAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.mistralai.chat";

public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.SMALL.getValue();
public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.MISTRAL_SMALL.getValue();

private static final Double DEFAULT_TOP_P = 1.0;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,7 +60,8 @@ class PaymentStatusBeanIT {
void functionCallTest() {

this.contextRunner
.withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue())
.withPropertyValues(
"spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.MISTRAL_LARGE.getValue())
.run(context -> {

MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,8 @@ class PaymentStatusBeanOpenAiIT {
void functionCallTest() {

this.contextRunner
.withPropertyValues("spring.ai.openai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue())
.withPropertyValues(
"spring.ai.openai.chat.options.model=" + MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.run(context -> {

OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -56,7 +56,8 @@ public class PaymentStatusPromptIT {
@Test
void functionCallTest() {
this.contextRunner
.withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue())
.withPropertyValues(
"spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.run(context -> {

MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,7 +62,8 @@ public class WeatherServicePromptIT {
@Test
void promptFunctionCall() {
this.contextRunner
.withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue())
.withPropertyValues(
"spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.MISTRAL_LARGE.getValue())
.run(context -> {

MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class);
Expand Down Expand Up @@ -91,7 +92,8 @@ void promptFunctionCall() {
@Test
void functionCallWithPortableFunctionCallingOptions() {
this.contextRunner
.withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue())
.withPropertyValues(
"spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.MISTRAL_LARGE.getValue())
.run(context -> {

MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ public static final class Builder {
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.build();

private ToolCallingManager toolCallingManager = DEFAULT_TOOL_CALLING_MANAGER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ public enum ChatCompletionFinishReason {
/**
* List of well-known Mistral chat models.
*
* @see <a href=
* "https://docs.mistral.ai/getting-started/models/models_overview/">Mistral AI Models
* Overview</a>
* @see <a href="https://docs.mistral.ai/getting-started/models">Mistral AI Models</a>
*/
public enum ChatModel implements ChatModelDescription {

Expand All @@ -259,15 +257,17 @@ public enum ChatModel implements ChatModelDescription {
MAGISTRAL_MEDIUM("magistral-medium-latest"),
MISTRAL_MEDIUM("mistral-medium-latest"),
CODESTRAL("codestral-latest"),
LARGE("mistral-large-latest"),
DEVSTRAL_MEDIUM("devstral-medium-latest"),
MISTRAL_LARGE("mistral-large-latest"),
PIXTRAL_LARGE("pixtral-large-latest"),
MINISTRAL_3B_LATEST("ministral-3b-latest"),
MINISTRAL_8B_LATEST("ministral-8b-latest"),
// Free Models
MINISTRAL_3B("ministral-3b-latest"),
MINISTRAL_8B("ministral-8b-latest"),
MINISTRAL_14B("ministral-14b-latest"),
MAGISTRAL_SMALL("magistral-small-latest"),
DEVSTRAL_SMALL("devstral-small-latest"),
SMALL("mistral-small-latest"),
PIXTRAL("pixtral-12b-2409"),
MISTRAL_SMALL("mistral-small-latest"),
PIXTRAL_12B("pixtral-12b-latest"),
// Free Models - Research
OPEN_MISTRAL_NEMO("open-mistral-nemo");
// @formatter:on
Expand All @@ -292,9 +292,7 @@ public String getName() {
/**
* List of well-known Mistral embedding models.
*
* @see <a href=
* "https://docs.mistral.ai/getting-started/models/models_overview/">Mistral AI Models
* Overview</a>
* @see <a href="https://docs.mistral.ai/getting-started/models">Mistral AI Models</a>
*/
public enum EmbeddingModel {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -226,7 +226,7 @@ void functionCallTest() {

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).toolChoice(ToolChoice.AUTO).build())
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.MISTRAL_SMALL).toolChoice(ToolChoice.AUTO).build())
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
Expand All @@ -248,7 +248,7 @@ void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.MISTRAL_SMALL).build())
.defaultToolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand All @@ -270,7 +270,7 @@ void streamFunctionCallTest() {

// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.MISTRAL_SMALL).build())
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
Expand All @@ -291,7 +291,7 @@ void streamFunctionCallTest() {
@Test
void validateCallResponseMetadata() {
// String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName();
String model = MistralAiApi.ChatModel.PIXTRAL.getName();
String model = MistralAiApi.ChatModel.PIXTRAL_12B.getName();
// String model = MistralAiApi.ChatModel.PIXTRAL_LARGE.getName();
// @formatter:off
ChatResponse response = ChatClient.create(this.chatModel).prompt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.mistralai;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -224,7 +223,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand All @@ -249,7 +248,7 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand Down Expand Up @@ -291,7 +290,7 @@ void multiModalityEmbeddedImage(String modelName) {

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "pixtral-large-latest" })
void multiModalityImageUrl(String modelName) throws IOException {
void multiModalityImageUrl(String modelName) {
var userMessage = UserMessage.builder()
.text("Explain what do you see on this picture?")
.media(List.of(Media.builder()
Expand All @@ -309,7 +308,7 @@ void multiModalityImageUrl(String modelName) throws IOException {
}

@Test
void streamingMultiModalityImageUrl() throws IOException {
void streamingMultiModalityImageUrl() {
var userMessage = UserMessage.builder()
.text("Explain what do you see on this picture?")
.media(List.of(Media.builder()
Expand Down Expand Up @@ -341,7 +340,7 @@ void streamFunctionCallUsageTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void beforeEach() {
@Test
void observationForChatOperation() {
var options = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.maxTokens(2048)
.stop(List.of("this-is-the-end"))
.temperature(0.7)
Expand All @@ -94,7 +94,7 @@ void observationForChatOperation() {
@Test
void observationForStreamingChatOperation() {
var options = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.maxTokens(2048)
.stop(List.of("this-is-the-end"))
.temperature(0.7)
Expand Down Expand Up @@ -131,12 +131,12 @@ private void validate(ChatResponseMetadata responseMetadata) {
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
.that()
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.SMALL.getValue())
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
AiOperationType.CHAT.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.MISTRAL_AI.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
MistralAiApi.ChatModel.SMALL.getValue())
MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
: KeyValue.NONE_VALUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ void testBuilderWithAllFields() {
@Test
void testBuilderWithEnum() {
MistralAiChatOptions optionsWithEnum = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.MINISTRAL_8B_LATEST)
.model(MistralAiApi.ChatModel.MINISTRAL_8B)
.build();
assertThat(optionsWithEnum.getModel()).isEqualTo(MistralAiApi.ChatModel.MINISTRAL_8B_LATEST.getValue());
assertThat(optionsWithEnum.getModel()).isEqualTo(MistralAiApi.ChatModel.MINISTRAL_8B.getValue());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void beforeEach() {
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.SMALL.getValue())
.model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue())
.build())
.retryTemplate(this.retryTemplate)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
return MistralAiChatModel.builder()
.mistralAiApi(mistralAiApi)
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL.getValue()).build())
.defaultOptions(
MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.MISTRAL_SMALL.getValue()).build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class MistralAiApiIT {
void chatCompletionEntity() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
List.of(chatCompletionMessage), MistralAiApi.ChatModel.MISTRAL_SMALL.getValue(), 0.8, false));

assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();
Expand All @@ -65,7 +65,7 @@ void chatCompletionEntityWithSystemMessage() {
""", Role.SYSTEM);

ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.MISTRAL_SMALL.getValue(), 0.8, false));

assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();
Expand All @@ -75,7 +75,7 @@ void chatCompletionEntityWithSystemMessage() {
void chatCompletionStream() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
Flux<ChatCompletionChunk> response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest(
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, true));
List.of(chatCompletionMessage), MistralAiApi.ChatModel.MISTRAL_SMALL.getValue(), 0.8, true));

assertThat(response).isNotNull();
assertThat(response.collectList().block()).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
public class MistralAiApiToolFunctionCallIT {

static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.LARGE.getValue();
static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.MISTRAL_LARGE.getValue();

private final Logger logger = LoggerFactory.getLogger(MistralAiApiToolFunctionCallIT.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private static <T> T jsonToObject(String json, Class<T> targetClass) {

@Test
@SuppressWarnings("null")
public void toolFunctionCall() throws JsonProcessingException {
public void toolFunctionCall() {

var transactionJsonSchema = """
{
Expand Down Expand Up @@ -109,8 +109,9 @@ public void toolFunctionCall() throws JsonProcessingException {

MistralAiApi mistralApi = MistralAiApi.builder().apiKey(System.getenv("MISTRAL_AI_API_KEY")).build();

ResponseEntity<ChatCompletion> response = mistralApi.chatCompletionEntity(new ChatCompletionRequest(messages,
MistralAiApi.ChatModel.LARGE.getValue(), List.of(paymentStatusTool, paymentDateTool), ToolChoice.AUTO));
ResponseEntity<ChatCompletion> response = mistralApi
.chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.MISTRAL_LARGE.getValue(),
List.of(paymentStatusTool, paymentDateTool), ToolChoice.AUTO));

ChatCompletionMessage responseMessage = response.getBody().choices().get(0).message();

Expand All @@ -135,7 +136,7 @@ public void toolFunctionCall() throws JsonProcessingException {
}

response = mistralApi
.chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.LARGE.getValue()));
.chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.MISTRAL_LARGE.getValue()));

var responseContent = response.getBody().choices().get(0).message().content();
logger.info("Final response: " + responseContent);
Expand Down