diff --git a/prompto-lab-app/pom.xml b/prompto-lab-app/pom.xml index 6942066..a5ec5f2 100644 --- a/prompto-lab-app/pom.xml +++ b/prompto-lab-app/pom.xml @@ -19,6 +19,11 @@ 17 + + commons-lang + commons-lang + 2.6 + org.springframework.boot spring-boot-starter-web diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java index f31cf4a..e310f6c 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java @@ -1,19 +1,14 @@ package io.github.timemachinelab.controller; -import io.github.timemachinelab.core.qatree.QaTree; -import io.github.timemachinelab.core.qatree.QaTreeDomain; -import io.github.timemachinelab.core.session.application.ConversationService; import io.github.timemachinelab.core.session.application.MessageProcessingService; import io.github.timemachinelab.core.session.application.SessionManagementService; +import io.github.timemachinelab.core.session.application.SseNotificationService; import io.github.timemachinelab.core.session.domain.entity.ConversationSession; -import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation; import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest; -import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse; import io.github.timemachinelab.entity.req.RetryRequest; import io.github.timemachinelab.entity.resp.ApiResult; import io.github.timemachinelab.entity.resp.RetryResponse; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.validation.annotation.Validated; @@ -23,7 +18,6 @@ import javax.annotation.Resource; import javax.validation.Valid; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -40,16 +34,12 @@ @RequestMapping("/api/user-interaction") @Validated public class UserInteractionController { - - @Resource - private ConversationService conversationService; @Resource private MessageProcessingService messageProcessingService; @Resource private SessionManagementService sessionManagementService; - private final Map sseEmitters = new ConcurrentHashMap<>(); - @Autowired - private QaTreeDomain qaTreeDomain; + @Resource + private SseNotificationService sseNotificationService; /** * 建立SSE连接 @@ -60,7 +50,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId); boolean isNewSession = false; - ConversationSession session = null; + ConversationSession session; try { if (sessionId == null || sessionId.isEmpty()) { @@ -82,7 +72,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess } SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); - sseEmitters.put(sessionId, emitter); + sseNotificationService.registerSseConnection(sessionId, emitter); // 连接建立时发送会话信息 Map connectionData = new ConcurrentHashMap<>(); @@ -115,9 +105,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId); } - emitter.send(SseEmitter.event() - .name("connected") - .data(connectionData)); + sseNotificationService.sendWelcomeMessage(sessionId, connectionData); // 设置连接事件处理 String finalSessionId = sessionId; @@ -127,12 +115,12 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess emitter.onTimeout(() -> { log.info("SSE连接超时: {}", finalSessionId); - sseEmitters.remove(finalSessionId); + sseNotificationService.removeSseConnection(finalSessionId); }); emitter.onError((ex) -> { log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage()); - sseEmitters.remove(finalSessionId); + sseNotificationService.removeSseConnection(finalSessionId); }); return emitter; @@ -163,8 +151,45 @@ public ResponseEntity> retry(@Valid @RequestBody RetryR log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}", request.getNodeId(), request.getSessionId(), request.getWhyretry()); + // 使用应用服务验证节点存在性 + //todo: 有可能水平越权 不传userId的话 + if (!sessionManagementService.validateNodeExists(request.getSessionId(), request.getNodeId())) { + log.warn("节点不存在 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId()); + return ResponseEntity.badRequest().body(ApiResult.error("指定的节点不存在")); + } + + // 使用应用服务获取问题内容 + String question = sessionManagementService.getNodeQuestion(request.getSessionId(), request.getNodeId()); + if (question == null) { + log.warn("节点问题内容为空 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId()); + return ResponseEntity.badRequest().body(ApiResult.error("节点问题内容为空")); + } - + // 获取会话对象 + ConversationSession session = sessionManagementService.getSessionById(request.getSessionId()); + if (session == null) { + log.warn("会话不存在 - sessionId: {}", request.getSessionId()); + return ResponseEntity.badRequest().body(ApiResult.error("会话不存在")); + } + + // 移除要重试的节点(AI会基于parentId重新创建节点) + boolean nodeRemoved = sessionManagementService.removeNode(request.getSessionId(), request.getNodeId()); + if (!nodeRemoved) { + log.warn("移除节点失败,但继续处理重试 - sessionId: {}, nodeId: {}", + request.getSessionId(), request.getNodeId()); + } + + // 使用MessageProcessingService处理重试消息 + String processedMessage = messageProcessingService.processRetryMessage( + request.getSessionId(), + request.getNodeId(), + request.getWhyretry(), + session + ); + + // 发送处理后的消息给AI服务 + messageProcessingService.processAndSendMessage(session, processedMessage); + // 构建响应数据 RetryResponse response = RetryResponse.builder() .nodeId(request.getNodeId()) @@ -196,12 +221,6 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe request.getNodeId(), request.getQuestionType()); - // 1. 强制要求sessionId - if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) { - log.warn("缺少必需的sessionId参数"); - return ResponseEntity.badRequest().body("sessionId参数是必需的"); - } - // 2. 会话管理和验证 String userId = request.getUserId(); if (userId == null || userId.trim().isEmpty()) { @@ -248,28 +267,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe return ResponseEntity.badRequest().body("答案格式不正确"); } - QaTree qaTree = session.getQaTree(); - - // 根据问题类型获取正确的答案数据 - Object answerData; - switch (request.getQuestionType().toLowerCase()) { - case "input": - answerData = request.getInputAnswer(); - break; - case "single": - case "multi": - answerData = request.getChoiceAnswer(); - break; - case "form": - answerData = request.getFormAnswer(); - break; - default: - log.warn("未知的问题类型: {}", request.getQuestionType()); - answerData = request.getAnswerString(); - break; - } - - qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData); + // 答案更新逻辑已在MessageProcessingService中处理 // 4. 处理答案并转换为消息 String processedMessage = messageProcessingService.preprocessMessage( @@ -279,11 +277,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe ); // 5. 发送处理后的消息给AI服务 - conversationService.processUserMessage( - session.getUserId(), - processedMessage, - response -> sendSseMessage(session.getSessionId(), response) - ); + messageProcessingService.processAndSendMessage(session, processedMessage); return ResponseEntity.ok("答案处理成功"); @@ -294,89 +288,11 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe } } - /** - * 通过SSE发送消息给客户端 - * 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新 - * - * @param sessionId 会话ID - * @param response 消息响应对象 - */ - private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) { - SseEmitter emitter = sseEmitters.get(sessionId); - if (emitter != null) { - try { - String currentNodeId = null; - - // 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空) - ConversationSession session = sessionManagementService.getSessionById(sessionId); - if (session != null && session.getQaTree() != null && response.getQuestion() != null) { - // 使用QaTreeDomain添加新节点,answer字段会自动为空 - // appendNode方法内部会调用session.getNextNodeId()获取新节点ID - QaTree qaTree = qaTreeDomain.appendNode( - session.getQaTree(), - response.getParentId(), - response.getQuestion(), - session - ); - - // 获取刚刚创建的节点ID(当前计数器的值) - currentNodeId = String.valueOf(session.getNodeIdCounter().get()); - - log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}", - sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType()); - } else { - log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}", - sessionId, session != null, - session != null && session.getQaTree() != null, - response.getQuestion() != null); - } - - // 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId - Map modifiedResponse = new HashMap<>(); - modifiedResponse.put("question", response.getQuestion()); - modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId()); - modifiedResponse.put("parentNodeId", response.getParentId()); - - // 3. 发送SSE消息给前端 - emitter.send(SseEmitter.event() - .name("message") - .data(modifiedResponse)); - log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId); - } catch (IOException e) { - log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage()); - sseEmitters.remove(sessionId); - } catch (Exception e) { - log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage()); - // 即使QaTree更新失败,仍然发送SSE消息给前端 - try { - Map fallbackResponse = new HashMap<>(); - fallbackResponse.put("question", response.getQuestion()); - fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback - fallbackResponse.put("parentNodeId", response.getParentId()); - - emitter.send(SseEmitter.event() - .name("message") - .data(fallbackResponse)); - log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId); - } catch (IOException ioException) { - log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage()); - sseEmitters.remove(sessionId); - } - } - } else { - log.warn("SSE连接不存在 - 会话: {}", sessionId); - } - } - /** * 获取SSE连接状态 */ @GetMapping("/sse-status") public ResponseEntity> getSseStatus() { - Map status = new ConcurrentHashMap<>(); - status.put("connectedSessions", sseEmitters.keySet()); - status.put("totalConnections", sseEmitters.size()); - status.put("timestamp", System.currentTimeMillis()); - return ResponseEntity.ok(status); + return ResponseEntity.ok(sseNotificationService.getSseStatus()); } } \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTree.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTree.java index d0d5430..9cd794b 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTree.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTree.java @@ -29,5 +29,59 @@ public void addNode(String parentId, QaTreeNode node) { public QaTreeNode getNodeById(String id) { return nodeMap.get(id); } + + /** + * 移除指定节点及其所有子节点 + * @param nodeId 要移除的节点ID + * @return 是否移除成功 + */ + public boolean removeNode(String nodeId) { + QaTreeNode nodeToRemove = nodeMap.get(nodeId); + if (nodeToRemove == null) { + return false; + } + + // 递归移除所有子节点 + removeNodeAndChildren(nodeToRemove); + + // 从父节点的children中移除该节点 + removeFromParent(nodeToRemove); + + return true; + } + + /** + * 递归移除节点及其所有子节点 + * @param node 要移除的节点 + */ + private void removeNodeAndChildren(QaTreeNode node) { + if (node == null) { + return; + } + + // 递归移除所有子节点 + if (node.getChildren() != null) { + for (QaTreeNode child : node.getChildren().values()) { + removeNodeAndChildren(child); + } + } + + // 从nodeMap中移除当前节点 + nodeMap.remove(node.getId()); + } + + /** + * 从父节点的children中移除指定节点 + * @param nodeToRemove 要移除的节点 + */ + private void removeFromParent(QaTreeNode nodeToRemove) { + // 遍历所有节点找到父节点 + for (QaTreeNode node : nodeMap.values()) { + if (node.getChildren() != null && node.getChildren().containsKey(nodeToRemove.getId())) { + node.removeChild(nodeToRemove.getId()); + break; + } + } + } } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java index 1ecccf3..037740c 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTreeDomain.java @@ -88,4 +88,62 @@ public boolean updateNodeAnswer(QaTree tree, String nodeId, Object answer) { return true; } + + /** + * 获取指定节点的问题内容 + * @param tree QA树 + * @param nodeId 节点ID + * @return 问题内容,如果节点不存在或问题为空则返回null + */ + public String getNodeQuestion(QaTree tree, String nodeId) { + if (tree == null || nodeId == null) { + return null; + } + + QaTreeNode node = tree.getNodeById(nodeId); + if (node == null || node.getQa() == null) { + return null; + } + + return node.getQa().getQuestion(); + } + + /** + * 验证节点是否存在 + * @param tree QA树 + * @param nodeId 节点ID + * @return 节点是否存在 + */ + public boolean nodeExists(QaTree tree, String nodeId) { + if (tree == null || nodeId == null) { + return false; + } + + return tree.getNodeById(nodeId) != null; + } + + /** + * 移除指定节点及其所有子节点 + * @param tree QA树 + * @param nodeId 要移除的节点ID + * @return 是否移除成功 + */ + public boolean removeNode(QaTree tree, String nodeId) { + if (tree == null || nodeId == null) { + return false; + } + + QaTreeNode nodeToRemove = tree.getNodeById(nodeId); + if (nodeToRemove == null) { + return false; + } + + // 不能移除根节点 + if (tree.getRoot() != null && tree.getRoot().getId().equals(nodeId)) { + return false; + } + + // 从树中移除节点(包括从父节点的children中移除和从nodeMap中移除) + return tree.removeNode(nodeId); + } } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java index 7ca17e5..fde0a41 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/ConversationService.java @@ -2,24 +2,14 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; -import io.github.timemachinelab.core.qatree.QaTree; -import io.github.timemachinelab.core.qatree.QaTreeDomain; -import io.github.timemachinelab.core.qatree.QaTreeNode; -import io.github.timemachinelab.core.question.BaseQuestion; import io.github.timemachinelab.core.session.domain.entity.ConversationSession; import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation; -import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse; -import io.github.timemachinelab.core.session.infrastructure.ai.ConversationOperation; import io.github.timemachinelab.sfchain.core.AIService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import javax.annotation.Resource; -import java.util.concurrent.ConcurrentHashMap; -import java.util.Map; -import java.util.List; -import java.util.ArrayList; import java.util.function.Consumer; @Service @@ -31,7 +21,6 @@ public class ConversationService { private final AIService aiService; @Resource private SessionManagementService sessionManagementService; - private final QaTreeDomain qaTreeDomain; public void processUserMessage(String userId, String userMessage, Consumer sseCallback) { diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/MessageProcessingService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/MessageProcessingService.java index 2d86c89..b415f1b 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/MessageProcessingService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/MessageProcessingService.java @@ -38,4 +38,25 @@ public interface MessageProcessingService { * @return 是否有效 */ boolean validateAnswer(UnifiedAnswerRequest request); + + /** + * 处理重试请求 + * 将重试信息转换为适合大模型处理的格式 + * + * @param sessionId 会话ID + * @param nodeId 节点ID + * @param whyRetry 重试原因 + * @param conversationSession 会话对象 + * @return 处理后的消息内容 + */ + String processRetryMessage(String sessionId, String nodeId, String whyRetry, ConversationSession conversationSession); + + /** + * 处理并发送消息给AI服务 + * 统一的消息处理和发送逻辑 + * + * @param session 会话对象 + * @param processedMessage 处理后的消息 + */ + void processAndSendMessage(ConversationSession session, String processedMessage); } \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java index ac390ff..715799c 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/SessionManagementService.java @@ -219,6 +219,72 @@ public boolean userOwnsSession(String userId, String sessionId) { return userSessions != null && userSessions.contains(sessionId); } + /** + * 获取指定会话中节点的问题内容 + * + * @param sessionId 会话ID + * @param nodeId 节点ID + * @return 问题内容,如果节点不存在或问题为空则返回null + */ + public String getNodeQuestion(String sessionId, String nodeId) { + ConversationSession session = sessions.get(sessionId); + if (session == null) { + log.warn("会话不存在: {}", sessionId); + return null; + } + + QaTree tree = session.getQaTree(); + return qaTreeDomain.getNodeQuestion(tree, nodeId); + } + + /** + * 验证指定会话中的节点是否存在 + * + * @param sessionId 会话ID + * @param nodeId 节点ID + * @return 节点是否存在 + */ + public boolean validateNodeExists(String sessionId, String nodeId) { + ConversationSession session = sessions.get(sessionId); + if (session == null) { + log.warn("会话不存在: {}", sessionId); + return false; + } + + QaTree tree = session.getQaTree(); + return qaTreeDomain.nodeExists(tree, nodeId); + } + + /** + * 移除指定会话中的节点 + * + * @param sessionId 会话ID + * @param nodeId 节点ID + * @return 是否移除成功 + */ + public boolean removeNode(String sessionId, String nodeId) { + ConversationSession session = sessions.get(sessionId); + if (session == null) { + log.warn("会话不存在: {}", sessionId); + return false; + } + + QaTree tree = session.getQaTree(); + if (tree == null) { + log.warn("会话的QaTree不存在: {}", sessionId); + return false; + } + + boolean removed = qaTreeDomain.removeNode(tree, nodeId); + if (removed) { + log.info("成功移除节点 - 会话: {}, 节点: {}", sessionId, nodeId); + } else { + log.warn("移除节点失败 - 会话: {}, 节点: {}", sessionId, nodeId); + } + + return removed; + } + /** * 获取会话统计信息 */ diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java index 5b197d7..2f31fe5 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/application/impl/DefaultMessageProcessingService.java @@ -9,6 +9,8 @@ import io.github.timemachinelab.core.qatree.QaTreeDomain; import io.github.timemachinelab.core.session.application.MessageProcessingService; import io.github.timemachinelab.core.session.application.SessionManagementService; +import io.github.timemachinelab.core.session.application.ConversationService; +import io.github.timemachinelab.core.session.application.SseNotificationService; import io.github.timemachinelab.core.session.domain.entity.ConversationSession; import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest; import io.github.timemachinelab.util.QaTreeSerializeUtil; @@ -32,6 +34,10 @@ public class DefaultMessageProcessingService implements MessageProcessingService SessionManagementService sessionManagementService; @Resource QaTreeDomain qaTreeDomain; + @Resource + ConversationService conversationService; + @Resource + SseNotificationService sseNotificationService; @Override public String processAnswer(UnifiedAnswerRequest request) { @@ -152,5 +158,52 @@ public boolean validateAnswer(UnifiedAnswerRequest request) { return false; } } - -} \ No newline at end of file + + @Override + public String processRetryMessage(String sessionId, String nodeId, String whyRetry, ConversationSession conversationSession) { + try { + // 构建重试消息的JSON格式 + JSONObject retryInput = new JSONObject(); + retryInput.put("action", "retry"); + retryInput.put("nodeId", nodeId); + retryInput.put("whyRetry", whyRetry != null ? whyRetry : "用户要求重新生成问题"); + + // 获取节点的问题内容 + String preQuestion = sessionManagementService.getNodeQuestion(sessionId, nodeId); + if (preQuestion != null) { + retryInput.put("preQuestion", preQuestion); + } + + JSONObject object = new JSONObject(); + object.put("prompt", AllPrompt.GLOBAL_PROMPT); + object.put("tree", QaTreeSerializeUtil.serialize(conversationSession.getQaTree())); + object.put("input", retryInput.toString()); + + log.info("处理重试消息 - 会话: {}, 节点: {}, 原因: {}", sessionId, nodeId, whyRetry); + return object.toString(); + + } catch (JsonProcessingException e) { + log.error("处理重试消息失败 - 会话: {}, 错误: {}", sessionId, e.getMessage(), e); + throw new RuntimeException("重试消息处理失败", e); + } + } + + @Override + public void processAndSendMessage(ConversationSession session, String processedMessage) { + try { + log.info("发送消息给AI服务 - 会话: {}, 用户: {}", session.getSessionId(), session.getUserId()); + + conversationService.processUserMessage( + session.getUserId(), + processedMessage, + response -> sseNotificationService.sendSseMessage(session.getSessionId(), response) + ); + + log.info("消息发送成功 - 会话: {}", session.getSessionId()); + } catch (Exception e) { + log.error("发送消息失败 - 会话: {}, 错误: {}", session.getSessionId(), e.getMessage(), e); + throw new RuntimeException("消息发送失败: " + e.getMessage(), e); + } + } + + } \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/dto/UnifiedAnswerRequest.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/dto/UnifiedAnswerRequest.java index 368a548..12755ee 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/dto/UnifiedAnswerRequest.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/core/session/infrastructure/web/dto/UnifiedAnswerRequest.java @@ -25,6 +25,7 @@ public class UnifiedAnswerRequest { /** * 会话ID */ + @NotBlank(message = "会话ID不能为空") private String sessionId; /** diff --git a/prompto-lab-ui/src/components/Chat/AIChatPage.vue b/prompto-lab-ui/src/components/Chat/AIChatPage.vue index 69559b2..0a0b8af 100644 --- a/prompto-lab-ui/src/components/Chat/AIChatPage.vue +++ b/prompto-lab-ui/src/components/Chat/AIChatPage.vue @@ -41,6 +41,7 @@ :is-loading="isLoading" @send-message="handleSendMessage" @submit-answer="handleSubmitAnswer" + @retry-question="handleRetryQuestion" /> @@ -77,7 +78,7 @@ import { ref, computed, onMounted, onUnmounted } from 'vue' import QuestionRenderer from '../QuestionRenderer.vue' import ChatTree from './ChatTree.vue' import MindMapTree from './MindMapTree.vue' -import { startConversation, sendMessage, sendUserMessage, connectSSE, closeSSE, processAnswer, connectUserInteractionSSE, type MessageRequest, type MessageResponse, type ConversationSession, type UnifiedAnswerRequest, type FormAnswerItem } from '@/services/conversationApi' +import { startConversation, sendMessage, sendUserMessage, connectSSE, closeSSE, processAnswer, connectUserInteractionSSE, retryQuestion, type MessageRequest, type MessageResponse, type ConversationSession, type UnifiedAnswerRequest, type FormAnswerItem, type RetryRequest } from '@/services/conversationApi' import { toast } from '@/utils/toast' interface Message { @@ -416,7 +417,7 @@ const handleSSEError = (error: Event) => { eventSource.value = connectUserInteractionSSE( session.value?.sessionId || null, - session.value?.userId || userId, + session.value?.userId || 'demo-user-' + Date.now(), handleSSEMessage, handleSSEError ) @@ -660,6 +661,65 @@ const handleSendMessage = async (content: string) => { } } +// 处理重试问题 +const handleRetryQuestion = async (reason: string = '用户要求重新生成问题') => { + if (!session.value || !currentQuestion.value) { + toast.error({ + title: '重试失败', + message: '会话未建立或没有当前问题', + duration: 3000 + }) + return + } + + // 更新活跃时间 + updateActivity() + + isLoading.value = true + + try { + // 构建重试请求 + const retryRequest: RetryRequest = { + sessionId: session.value.sessionId, + nodeId: currentNodeId.value, // 当前问题节点ID + whyretry: reason + } + + // 调用重试接口 + await retryQuestion(retryRequest) + + toast.success({ + title: '重试成功', + message: '正在重新生成问题,请稍候', + duration: 2000 + }) + + console.log('重试请求已发送,等待AI重新生成问题...') + + } catch (error: any) { + console.error('重试失败:', error) + isLoading.value = false + + // 检查是否是会话相关错误 + if (error.message && (error.message.includes('sessionId') || error.message.includes('nodeId') || error.message.includes('会话') || error.message.includes('节点'))) { + toast.error({ + title: '会话异常', + message: '会话或节点状态异常,请刷新页面重新建立连接', + duration: 5000 + }) + // 清理当前会话状态 + session.value = null + closeConnection() + } else { + toast.error({ + title: '重试失败', + message: error.message || '重试请求失败,请重试', + duration: 4000 + }) + } + } +} + // 处理答案提交 const handleSubmitAnswer = async (answerData: any) => { if (!session.value || !currentQuestion.value) { diff --git a/prompto-lab-ui/src/components/QuestionRenderer.vue b/prompto-lab-ui/src/components/QuestionRenderer.vue index 4c3421a..a0547cb 100644 --- a/prompto-lab-ui/src/components/QuestionRenderer.vue +++ b/prompto-lab-ui/src/components/QuestionRenderer.vue @@ -182,11 +182,21 @@ 提交答案 + + @@ -198,6 +208,29 @@ :type="loadingType" :duration="3000" /> + + +
+
+
+

重试原因

+ +
+
+

请说明为什么要重新生成这个问题:

+ +
+
+ + +
+
+
@@ -264,6 +297,7 @@ const props = defineProps() const emit = defineEmits<{ sendMessage: [content: string] submitAnswer: [answer: any] + retryQuestion: [reason: string] }>() // 响应式数据 @@ -440,6 +474,28 @@ const submitAnswer = () => { emit('submitAnswer', answerData) } +// 重试相关状态 +const showRetryDialog = ref(false) +const retryReason = ref('') + +const retryQuestion = () => { + // 显示重试原因输入对话框 + showRetryDialog.value = true + retryReason.value = '' +} + +const confirmRetry = () => { + // 发送重试事件,包含用户输入的原因 + emit('retryQuestion', retryReason.value || '用户要求重新生成问题') + showRetryDialog.value = false + retryReason.value = '' +} + +const cancelRetry = () => { + showRetryDialog.value = false + retryReason.value = '' +} + const resetQuestion = () => { // 重置所有答案 answers.input = '' @@ -1205,6 +1261,35 @@ const resetQuestion = () => { inset 0 1px 0 rgba(255, 255, 255, 0.4); } +.retry-btn { + display: flex; + align-items: center; + gap: 8px; + padding: 16px 24px; + background: rgba(15, 15, 15, 0.8); + border: 1px solid rgba(255, 165, 0, 0.2); + border-radius: 16px; + color: #ffb366; + font-size: 14px; + font-weight: 500; + cursor: pointer; + transition: all 0.3s ease; + backdrop-filter: blur(10px); +} + +.retry-btn:hover { + border-color: rgba(255, 165, 0, 0.4); + color: #ffffff; + transform: translateY(-1px); + background: rgba(255, 165, 0, 0.1); +} + +.retry-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; +} + .reset-btn { display: flex; align-items: center; @@ -1387,4 +1472,143 @@ const resetQuestion = () => { justify-content: center; } } + +/* 重试对话框样式 */ +.retry-dialog-overlay { + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.7); + display: flex; + align-items: center; + justify-content: center; + z-index: 1000; + backdrop-filter: blur(8px); +} + +.retry-dialog { + background: linear-gradient(135deg, rgba(15, 15, 15, 0.95), rgba(25, 25, 25, 0.95)); + border: 1px solid rgba(255, 165, 0, 0.3); + border-radius: 16px; + width: 90%; + max-width: 500px; + backdrop-filter: blur(20px); + box-shadow: 0 20px 60px rgba(0, 0, 0, 0.5); +} + +.retry-dialog-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 20px 24px; + border-bottom: 1px solid rgba(255, 165, 0, 0.2); +} + +.retry-dialog-header h3 { + margin: 0; + color: #ffb366; + font-size: 18px; + font-weight: 600; +} + +.close-btn { + background: none; + border: none; + color: #888; + font-size: 24px; + cursor: pointer; + padding: 0; + width: 32px; + height: 32px; + display: flex; + align-items: center; + justify-content: center; + border-radius: 8px; + transition: all 0.3s ease; +} + +.close-btn:hover { + background: rgba(255, 165, 0, 0.1); + color: #ffb366; +} + +.retry-dialog-content { + padding: 24px; +} + +.retry-dialog-content p { + margin: 0 0 16px 0; + color: #e8e8e8; + font-size: 14px; + line-height: 1.5; +} + +.retry-reason-input { + width: 100%; + background: rgba(10, 10, 10, 0.8); + border: 1px solid rgba(255, 165, 0, 0.3); + border-radius: 12px; + padding: 12px 16px; + color: #e8e8e8; + font-size: 14px; + font-family: inherit; + resize: vertical; + min-height: 100px; + transition: all 0.3s ease; +} + +.retry-reason-input:focus { + outline: none; + border-color: rgba(255, 165, 0, 0.6); + box-shadow: 0 0 0 2px rgba(255, 165, 0, 0.1); +} + +.retry-reason-input::placeholder { + color: #666; +} + +.retry-dialog-actions { + display: flex; + gap: 12px; + padding: 0 24px 24px 24px; + justify-content: flex-end; +} + +.cancel-btn, .confirm-btn { + padding: 10px 20px; + border-radius: 10px; + font-size: 14px; + font-weight: 500; + cursor: pointer; + transition: all 0.3s ease; + border: 1px solid; +} + +.cancel-btn { + background: rgba(128, 128, 128, 0.1); + border-color: rgba(128, 128, 128, 0.3); + color: #cccccc; +} + +.cancel-btn:hover { + background: rgba(128, 128, 128, 0.2); + border-color: rgba(128, 128, 128, 0.5); + color: #ffffff; +} + +.confirm-btn { + background: linear-gradient(135deg, rgba(255, 165, 0, 0.2), rgba(255, 140, 0, 0.2)); + border-color: rgba(255, 165, 0, 0.5); + color: #ffb366; +} + +.confirm-btn:hover { + background: linear-gradient(135deg, rgba(255, 165, 0, 0.3), rgba(255, 140, 0, 0.3)); + border-color: rgba(255, 165, 0, 0.7); + color: #ffffff; + transform: translateY(-1px); + box-shadow: 0 4px 15px rgba(255, 165, 0, 0.3); +} \ No newline at end of file diff --git a/prompto-lab-ui/src/services/conversationApi.ts b/prompto-lab-ui/src/services/conversationApi.ts index a18b5b4..7a6838b 100644 --- a/prompto-lab-ui/src/services/conversationApi.ts +++ b/prompto-lab-ui/src/services/conversationApi.ts @@ -57,7 +57,7 @@ export const startConversation = async (userId: string): Promise void, onError?: (error: Event) => void): EventSource => { +export const connectUserInteractionSSE = (sessionId: string | null, userId: string, onMessage: (response: any) => void, onError?: (error: Event) => void): EventSource => { + // 构建查询参数 const params = new URLSearchParams() if (sessionId) { params.append('sessionId', sessionId) } params.append('userId', userId) + const url = `${USER_INTERACTION_BASE}/sse?${params.toString()}` const eventSource = new EventSource(url) @@ -226,10 +228,13 @@ export const connectUserInteractionSSE = (sessionId: string | null, userId: stri eventSource.addEventListener('connected', (event: MessageEvent) => { console.log('用户交互SSE连接已建立:', event.data) try { - const response = JSON.parse(event.data) - onMessage(response) + // 解析连接数据并传递给onMessage回调 + const connectionData = JSON.parse(event.data) + onMessage(connectionData) } catch (error) { - console.error('解析连接建立消息失败:', error) + console.error('解析连接数据失败:', error) + // 如果解析失败,传递原始数据 + onMessage({ type: 'connected', data: event.data }) } }) @@ -253,6 +258,28 @@ export const connectUserInteractionSSE = (sessionId: string | null, userId: stri return eventSource } +/** + * 重试请求接口 + * 对接后端的retry接口 + */ +export interface RetryRequest { + nodeId: string + sessionId: string + whyretry: string +} + +export const retryQuestion = async (request: RetryRequest): Promise => { + const url = `${USER_INTERACTION_BASE}/retry` + await apiRequest(url, { + method: 'POST', + body: JSON.stringify(request), + headers: { + 'Content-Type': 'application/json' + }, + requireAuth: false + }) +} + /** * 关闭SSE连接 */