Skip to content

Commit 8968df7

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
refactor: GH-3998 Refactor org.springframework.ai.ollama.OllamaChatModel#ollamaChatRequest to support custom implementations of AbstractMessage and align with other ChatModel implementations.
Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
1 parent 14d6f58 commit 8968df7

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import reactor.core.scheduler.Schedulers;
3333

3434
import org.springframework.ai.chat.messages.AssistantMessage;
35-
import org.springframework.ai.chat.messages.SystemMessage;
35+
import org.springframework.ai.chat.messages.MessageType;
3636
import org.springframework.ai.chat.messages.ToolResponseMessage;
3737
import org.springframework.ai.chat.messages.UserMessage;
3838
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -439,18 +439,21 @@ Prompt buildRequestPrompt(Prompt prompt) {
439439
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
440440

441441
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
442-
if (message instanceof UserMessage userMessage) {
442+
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
443443
var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText());
444-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
445-
messageBuilder.images(
446-
userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
444+
if (message instanceof UserMessage userMessage) {
445+
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
446+
messageBuilder.images(userMessage.getMedia()
447+
.stream()
448+
.map(media -> this.fromMediaData(media.getData()))
449+
.toList());
450+
}
447451
}
452+
448453
return List.of(messageBuilder.build());
449454
}
450-
else if (message instanceof SystemMessage systemMessage) {
451-
return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(systemMessage.getText()).build());
452-
}
453-
else if (message instanceof AssistantMessage assistantMessage) {
455+
else if (message.getMessageType() == MessageType.ASSISTANT) {
456+
var assistantMessage = (AssistantMessage) message;
454457
List<ToolCall> toolCalls = null;
455458
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
456459
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
@@ -465,7 +468,8 @@ else if (message instanceof AssistantMessage assistantMessage) {
465468
.toolCalls(toolCalls)
466469
.build());
467470
}
468-
else if (message instanceof ToolResponseMessage toolMessage) {
471+
else if (message.getMessageType() == MessageType.TOOL) {
472+
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
469473
return toolMessage.getResponses()
470474
.stream()
471475
.map(tr -> OllamaApi.Message.builder(Role.TOOL).content(tr.responseData()).build())

0 commit comments

Comments
 (0)