Skip to content
Merged

123 #28

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.timemachinelab.controller;

import io.github.timemachinelab.core.qatree.QaTree;
import io.github.timemachinelab.core.qatree.QaTreeDomain;
import io.github.timemachinelab.core.session.application.ConversationService;
import io.github.timemachinelab.core.session.application.MessageProcessingService;
import io.github.timemachinelab.core.session.application.SessionManagementService;
Expand All @@ -11,6 +13,7 @@
import io.github.timemachinelab.entity.resp.ApiResult;
import io.github.timemachinelab.entity.resp.RetryResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
Expand All @@ -20,6 +23,7 @@
import javax.annotation.Resource;
import javax.validation.Valid;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -44,7 +48,9 @@ public class UserInteractionController {
@Resource
private SessionManagementService sessionManagementService;
private final Map<String, SseEmitter> sseEmitters = new ConcurrentHashMap<>();

@Autowired
private QaTreeDomain qaTreeDomain;

/**
* 建立SSE连接
*/
Expand Down Expand Up @@ -88,8 +94,8 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
// 根据会话状态返回nodeId
if (isNewSession) {
// 新会话返回根节点ID
connectionData.put("nodeId", "root");
log.info("新会话返回根节点ID: root - 会话: {}", sessionId);
connectionData.put("nodeId", "1");
log.info("新会话返回根节点ID: 1 - 会话: {}", sessionId);
} else if (session.getQaTree() != null && session.getQaTree().getRoot() != null) {
// 已存在会话,返回根节点ID(因为qaTree只有根节点)
String rootNodeId = session.getQaTree().getRoot().getId();
Expand All @@ -105,8 +111,8 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
}
} else {
// 兜底情况,返回根节点ID
connectionData.put("nodeId", "root");
log.info("兜底返回根节点ID: root - 会话: {}", sessionId);
connectionData.put("nodeId", "1");
log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId);
}

emitter.send(SseEmitter.event()
Expand Down Expand Up @@ -219,8 +225,8 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
return ResponseEntity.badRequest().body("现有会话必须提供nodeId");
}
log.info("新建会话的第一个问题 - 会话: {}", session.getSessionId());
} else if ("root".equals(nodeId)) {
// nodeId为'root',表示这是根节点的回答
} else if ("1".equals(nodeId)) {
// nodeId为'1',表示这是根节点的回答
if (session.getQaTree() == null || session.getQaTree().getRoot() == null) {
log.info("根节点回答,但qaTree未初始化 - 会话: {}", session.getSessionId());
// 允许继续处理,后续会创建qaTree
Expand All @@ -242,6 +248,29 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
return ResponseEntity.badRequest().body("答案格式不正确");
}

QaTree qaTree = session.getQaTree();

// 根据问题类型获取正确的答案数据
Object answerData;
switch (request.getQuestionType().toLowerCase()) {
case "input":
answerData = request.getInputAnswer();
break;
case "single":
case "multi":
answerData = request.getChoiceAnswer();
break;
case "form":
answerData = request.getFormAnswer();
break;
default:
log.warn("未知的问题类型: {}", request.getQuestionType());
answerData = request.getAnswerString();
break;
}

qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData);

// 4. 处理答案并转换为消息
String processedMessage = messageProcessingService.preprocessMessage(
null, // 没有额外的原始消息
Expand All @@ -256,6 +285,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
response -> sendSseMessage(session.getSessionId(), response)
);


return ResponseEntity.ok("答案处理成功");

} catch (Exception e) {
Expand All @@ -266,6 +296,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe

/**
* 通过SSE发送消息给客户端
* 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新
*
* @param sessionId 会话ID
* @param response 消息响应对象
Expand All @@ -274,13 +305,63 @@ private void sendSseMessage(String sessionId, QuestionGenerationOperation.Questi
SseEmitter emitter = sseEmitters.get(sessionId);
if (emitter != null) {
try {
String currentNodeId = null;

// 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空)
ConversationSession session = sessionManagementService.getSessionById(sessionId);
if (session != null && session.getQaTree() != null && response.getQuestion() != null) {
// 使用QaTreeDomain添加新节点,answer字段会自动为空
// appendNode方法内部会调用session.getNextNodeId()获取新节点ID
QaTree qaTree = qaTreeDomain.appendNode(
session.getQaTree(),
response.getParentId(),
response.getQuestion(),
session
);

// 获取刚刚创建的节点ID(当前计数器的值)
currentNodeId = String.valueOf(session.getNodeIdCounter().get());

log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}",
sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType());
} else {
log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}",
sessionId, session != null,
session != null && session.getQaTree() != null,
response.getQuestion() != null);
}

// 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId
Map<String, Object> modifiedResponse = new HashMap<>();
modifiedResponse.put("question", response.getQuestion());
modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId());
modifiedResponse.put("parentNodeId", response.getParentId());

// 3. 发送SSE消息给前端
emitter.send(SseEmitter.event()
.name("message")
.data(response));
log.info("SSE消息发送成功 - 会话: {}, 消息: {}", sessionId, response);
.data(modifiedResponse));
log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId);
} catch (IOException e) {
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
sseEmitters.remove(sessionId);
} catch (Exception e) {
log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
// 即使QaTree更新失败,仍然发送SSE消息给前端
try {
Map<String, Object> fallbackResponse = new HashMap<>();
fallbackResponse.put("question", response.getQuestion());
fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback
fallbackResponse.put("parentNodeId", response.getParentId());

emitter.send(SseEmitter.event()
.name("message")
.data(fallbackResponse));
log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId);
} catch (IOException ioException) {
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage());
sseEmitters.remove(sessionId);
}
}
} else {
log.warn("SSE连接不存在 - 会话: {}", sessionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ public QaTree createTree(String userStartQuestion) {

/**
* 使用ConversationSession的自增ID创建QaTree
* @param userStartQuestion 用户开始问题
* @param question 用户开始问题
* @param session 会话对象,用于获取自增ID
* @return 创建的QaTree
*/
public QaTree createTree(String userStartQuestion, ConversationSession session) {
public QaTree createTree(String question, ConversationSession session) {
InputQuestion startQA = new InputQuestion();
startQA.setQuestion(userStartQuestion);
startQA.setAnswer(userStartQuestion);
startQA.setQuestion(question);
// 使用会话的自增ID创建根节点
String rootNodeId = session.getNextNodeId();
QaTreeNode startNode = new QaTreeNode(startQA, rootNodeId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public ConversationSession createNewSession(String userId) {
ConversationSession session = new ConversationSession(userId, newSessionId, null);

// 使用会话的自增ID创建QaTree,确保根节点ID=1
QaTree tree = qaTreeDomain.createTree("default", session);
QaTree tree = qaTreeDomain.createTree("你好,我有什么可以帮你?", session);

// 设置QaTree到会话中
session.setQaTree(tree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ private void updateQaTreeWithAnswer(ConversationSession session, UnifiedAnswerRe
}

String nodeId = request.getNodeId();
// 如果nodeId为'root',使用根节点ID
if ("root".equals(nodeId) && qaTree.getRoot() != null) {
// 如果nodeId为'1'(根节点),使用根节点ID
if ("1".equals(nodeId) && qaTree.getRoot() != null) {
nodeId = qaTree.getRoot().getId();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package io.github.timemachinelab.core.session.domain.entity;

import io.github.timemachinelab.core.qatree.*;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;

import java.time.LocalDateTime;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

@Getter
@Data
public class ConversationSession {

private final String sessionId;
private final String userId;
/**
* -- SETTER --
* 设置QaTree(仅用于初始化)
*
* @param qaTree QA树对象
*/
@Setter
private QaTree qaTree; // 移除final,允许后续设置
private final LocalDateTime createTime;
private LocalDateTime updateTime;
Expand All @@ -36,10 +45,10 @@ public String getNextNodeId() {
}

/**
* 设置QaTree(仅用于初始化)
* @param qaTree QA树对象
* 获取节点ID计数器
* @return 节点ID计数器
*/
public void setQaTree(QaTree qaTree) {
this.qaTree = qaTree;
public AtomicInteger getNodeIdCounter() {
return nodeIdCounter;
}
}
80 changes: 64 additions & 16 deletions prompto-lab-ui/src/components/Chat/AIChatPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,22 @@ const handleSSEMessage = (response: any) => {

console.log('会话已建立:', session.value)

// 后端总是会返回nodeId,新会话返回'root',已存在会话返回实际的nodeId
// 后端总是会返回nodeId,新会话返回'1',已存在会话返回实际的nodeId
if (response.nodeId) {
currentNodeId.value = response.nodeId
console.log('会话节点ID:', response.nodeId)

// 如果是根节点,初始化根节点
if (response.nodeId === 'root') {
if (response.nodeId === '1') {
const rootNode: ConversationNode = {
id: 'root',
id: '1',
content: '您好!我是AI助手,有什么可以帮助您的吗?',
type: 'assistant',
timestamp: new Date(),
children: [],
isActive: true
}
conversationTree.value.set('root', rootNode)
conversationTree.value.set('1', rootNode)
}
}

Expand Down Expand Up @@ -295,14 +295,52 @@ const handleSSEMessage = (response: any) => {
// 这是新的问题格式
currentQuestion.value = response.question

// 更新当前节点ID为问题的parentId
if (response.parentId) {
currentNodeId.value = response.parentId
console.log('更新当前节点ID为:', response.parentId)
// 更新当前节点ID为新创建的问题节点ID
if (response.currentNodeId) {
// 创建问题节点并添加到对话树
const questionContent = `${response.question.question}${response.question.desc ? '\n' + response.question.desc : ''}`

const questionNode: ConversationNode = {
id: response.currentNodeId,
content: questionContent,
type: 'assistant',
timestamp: new Date(),
parentId: response.parentNodeId,
children: [],
isActive: true
}

// 更新父节点的children数组
if (response.parentNodeId) {
const parentNode = conversationTree.value.get(response.parentNodeId)
if (parentNode) {
// 将父节点的其他子节点设为非活跃状态
parentNode.children.forEach(childId => {
const childNode = conversationTree.value.get(childId)
if (childNode) {
setNodeAndDescendantsInactive(childId)
}
})
parentNode.children.push(response.currentNodeId)
}
}

// 添加新问题节点到对话树
conversationTree.value.set(response.currentNodeId, questionNode)
currentNodeId.value = response.currentNodeId
console.log('更新当前节点ID为:', response.currentNodeId)

// 在聊天界面显示问题内容
addAIMessage(response.currentNodeId, questionContent)
}

// 记录父节点ID,用于后续构建树形关系图
if (response.parentNodeId) {
console.log('父节点ID:', response.parentNodeId)
}

isLoading.value = false
console.log('收到新格式问题:', response.question, '父节点ID:', response.parentId)
console.log('收到新格式问题:', response.question, '当前节点ID:', response.currentNodeId, '父节点ID:', response.parentNodeId)
return
}

Expand Down Expand Up @@ -639,10 +677,13 @@ const handleSubmitAnswer = async (answerData: any) => {
isLoading.value = true

try {
// 保存当前问题节点ID,用于后端验证
const questionNodeId = currentNodeId.value

// 构建统一答案请求,必须包含sessionId和正确的nodeId
const request: UnifiedAnswerRequest = {
sessionId: session.value.sessionId, // 必需的sessionId
nodeId: currentNodeId.value, // 当前节点ID,用于后端验证
nodeId: questionNodeId, // 问题节点ID,用于后端验证
questionType: currentQuestion.value.type,
answer: answerData,
userId: session.value.userId // 必需的userId
Expand Down Expand Up @@ -708,10 +749,11 @@ const handleSubmitAnswer = async (answerData: any) => {
}

conversationTree.value.set(userNodeId, userNode)
currentNodeId.value = userNodeId
// 不更新currentNodeId为用户节点ID,保持为问题节点ID直到收到新问题
// currentNodeId.value = userNodeId

// 清除当前问题状态
currentQuestion.value = null
// 不清除当前问题状态,保持显示直到收到新问题
// currentQuestion.value = null

toast.success({
title: '提交成功',
Expand Down Expand Up @@ -799,13 +841,19 @@ const handleBranchDeleted = (nodeId: string) => {
deleteNodeAndDescendants(nodeId)

if (!conversationTree.value.has(currentNodeId.value)) {
let newCurrentId = 'root'
// 动态查找根节点(没有parentId的节点)
let rootNodeId = ''
let newCurrentId = ''
conversationTree.value.forEach((node, id) => {
if (node.isActive && id !== 'root') {
if (!node.parentId) {
rootNodeId = id
}
if (node.isActive) {
newCurrentId = id
}
})
currentNodeId.value = newCurrentId
// 优先使用活跃节点,否则回退到根节点
currentNodeId.value = newCurrentId || rootNodeId
}
}
</script>
Expand Down
Loading
Loading