diff --git a/prompto-lab-app/pom.xml b/prompto-lab-app/pom.xml
index 6942066..a5ec5f2 100644
--- a/prompto-lab-app/pom.xml
+++ b/prompto-lab-app/pom.xml
@@ -19,6 +19,11 @@
17
+
+ commons-lang
+ commons-lang
+ 2.6
+
org.springframework.boot
spring-boot-starter-web
diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java
index f31cf4a..e310f6c 100644
--- a/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java
+++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java
@@ -1,19 +1,14 @@
package io.github.timemachinelab.controller;
-import io.github.timemachinelab.core.qatree.QaTree;
-import io.github.timemachinelab.core.qatree.QaTreeDomain;
-import io.github.timemachinelab.core.session.application.ConversationService;
import io.github.timemachinelab.core.session.application.MessageProcessingService;
import io.github.timemachinelab.core.session.application.SessionManagementService;
+import io.github.timemachinelab.core.session.application.SseNotificationService;
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
-import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest;
-import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
import io.github.timemachinelab.entity.req.RetryRequest;
import io.github.timemachinelab.entity.resp.ApiResult;
import io.github.timemachinelab.entity.resp.RetryResponse;
import lombok.extern.slf4j.Slf4j;
-import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
@@ -23,7 +18,6 @@
import javax.annotation.Resource;
import javax.validation.Valid;
import java.io.IOException;
-import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@@ -40,16 +34,12 @@
@RequestMapping("/api/user-interaction")
@Validated
public class UserInteractionController {
-
- @Resource
- private ConversationService conversationService;
@Resource
private MessageProcessingService messageProcessingService;
@Resource
private SessionManagementService sessionManagementService;
- private final Map sseEmitters = new ConcurrentHashMap<>();
- @Autowired
- private QaTreeDomain qaTreeDomain;
+ @Resource
+ private SseNotificationService sseNotificationService;
/**
* 建立SSE连接
@@ -60,7 +50,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId);
boolean isNewSession = false;
- ConversationSession session = null;
+ ConversationSession session;
try {
if (sessionId == null || sessionId.isEmpty()) {
@@ -82,7 +72,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
}
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
- sseEmitters.put(sessionId, emitter);
+ sseNotificationService.registerSseConnection(sessionId, emitter);
// 连接建立时发送会话信息
Map connectionData = new ConcurrentHashMap<>();
@@ -115,9 +105,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId);
}
- emitter.send(SseEmitter.event()
- .name("connected")
- .data(connectionData));
+ sseNotificationService.sendWelcomeMessage(sessionId, connectionData);
// 设置连接事件处理
String finalSessionId = sessionId;
@@ -127,12 +115,12 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
emitter.onTimeout(() -> {
log.info("SSE连接超时: {}", finalSessionId);
- sseEmitters.remove(finalSessionId);
+ sseNotificationService.removeSseConnection(finalSessionId);
});
emitter.onError((ex) -> {
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
- sseEmitters.remove(finalSessionId);
+ sseNotificationService.removeSseConnection(finalSessionId);
});
return emitter;
@@ -163,8 +151,45 @@ public ResponseEntity> retry(@Valid @RequestBody RetryR
log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}",
request.getNodeId(), request.getSessionId(), request.getWhyretry());
+ // 使用应用服务验证节点存在性
+ //todo: 有可能水平越权 不传userId的话
+ if (!sessionManagementService.validateNodeExists(request.getSessionId(), request.getNodeId())) {
+ log.warn("节点不存在 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
+ return ResponseEntity.badRequest().body(ApiResult.error("指定的节点不存在"));
+ }
+
+ // 使用应用服务获取问题内容
+ String question = sessionManagementService.getNodeQuestion(request.getSessionId(), request.getNodeId());
+ if (question == null) {
+ log.warn("节点问题内容为空 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
+ return ResponseEntity.badRequest().body(ApiResult.error("节点问题内容为空"));
+ }
-
+ // 获取会话对象
+ ConversationSession session = sessionManagementService.getSessionById(request.getSessionId());
+ if (session == null) {
+ log.warn("会话不存在 - sessionId: {}", request.getSessionId());
+ return ResponseEntity.badRequest().body(ApiResult.error("会话不存在"));
+ }
+
+ // 移除要重试的节点(AI会基于parentId重新创建节点)
+ boolean nodeRemoved = sessionManagementService.removeNode(request.getSessionId(), request.getNodeId());
+ if (!nodeRemoved) {
+ log.warn("移除节点失败,但继续处理重试 - sessionId: {}, nodeId: {}",
+ request.getSessionId(), request.getNodeId());
+ }
+
+ // 使用MessageProcessingService处理重试消息
+ String processedMessage = messageProcessingService.processRetryMessage(
+ request.getSessionId(),
+ request.getNodeId(),
+ request.getWhyretry(),
+ session
+ );
+
+ // 发送处理后的消息给AI服务
+ messageProcessingService.processAndSendMessage(session, processedMessage);
+
// 构建响应数据
RetryResponse response = RetryResponse.builder()
.nodeId(request.getNodeId())
@@ -196,12 +221,6 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe
request.getNodeId(),
request.getQuestionType());
- // 1. 强制要求sessionId
- if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) {
- log.warn("缺少必需的sessionId参数");
- return ResponseEntity.badRequest().body("sessionId参数是必需的");
- }
-
// 2. 会话管理和验证
String userId = request.getUserId();
if (userId == null || userId.trim().isEmpty()) {
@@ -248,28 +267,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe
return ResponseEntity.badRequest().body("答案格式不正确");
}
- QaTree qaTree = session.getQaTree();
-
- // 根据问题类型获取正确的答案数据
- Object answerData;
- switch (request.getQuestionType().toLowerCase()) {
- case "input":
- answerData = request.getInputAnswer();
- break;
- case "single":
- case "multi":
- answerData = request.getChoiceAnswer();
- break;
- case "form":
- answerData = request.getFormAnswer();
- break;
- default:
- log.warn("未知的问题类型: {}", request.getQuestionType());
- answerData = request.getAnswerString();
- break;
- }
-
- qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData);
+ // 答案更新逻辑已在MessageProcessingService中处理
// 4. 处理答案并转换为消息
String processedMessage = messageProcessingService.preprocessMessage(
@@ -279,11 +277,7 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe
);
// 5. 发送处理后的消息给AI服务
- conversationService.processUserMessage(
- session.getUserId(),
- processedMessage,
- response -> sendSseMessage(session.getSessionId(), response)
- );
+ messageProcessingService.processAndSendMessage(session, processedMessage);
return ResponseEntity.ok("答案处理成功");
@@ -294,89 +288,11 @@ public ResponseEntity processAnswer(@Validated @RequestBody UnifiedAnswe
}
}
- /**
- * 通过SSE发送消息给客户端
- * 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新
- *
- * @param sessionId 会话ID
- * @param response 消息响应对象
- */
- private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) {
- SseEmitter emitter = sseEmitters.get(sessionId);
- if (emitter != null) {
- try {
- String currentNodeId = null;
-
- // 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空)
- ConversationSession session = sessionManagementService.getSessionById(sessionId);
- if (session != null && session.getQaTree() != null && response.getQuestion() != null) {
- // 使用QaTreeDomain添加新节点,answer字段会自动为空
- // appendNode方法内部会调用session.getNextNodeId()获取新节点ID
- QaTree qaTree = qaTreeDomain.appendNode(
- session.getQaTree(),
- response.getParentId(),
- response.getQuestion(),
- session
- );
-
- // 获取刚刚创建的节点ID(当前计数器的值)
- currentNodeId = String.valueOf(session.getNodeIdCounter().get());
-
- log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}",
- sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType());
- } else {
- log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}",
- sessionId, session != null,
- session != null && session.getQaTree() != null,
- response.getQuestion() != null);
- }
-
- // 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId
- Map modifiedResponse = new HashMap<>();
- modifiedResponse.put("question", response.getQuestion());
- modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId());
- modifiedResponse.put("parentNodeId", response.getParentId());
-
- // 3. 发送SSE消息给前端
- emitter.send(SseEmitter.event()
- .name("message")
- .data(modifiedResponse));
- log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId);
- } catch (IOException e) {
- log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
- sseEmitters.remove(sessionId);
- } catch (Exception e) {
- log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
- // 即使QaTree更新失败,仍然发送SSE消息给前端
- try {
- Map fallbackResponse = new HashMap<>();
- fallbackResponse.put("question", response.getQuestion());
- fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback
- fallbackResponse.put("parentNodeId", response.getParentId());
-
- emitter.send(SseEmitter.event()
- .name("message")
- .data(fallbackResponse));
- log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId);
- } catch (IOException ioException) {
- log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage());
- sseEmitters.remove(sessionId);
- }
- }
- } else {
- log.warn("SSE连接不存在 - 会话: {}", sessionId);
- }
- }
-
/**
* 获取SSE连接状态
*/
@GetMapping("/sse-status")
public ResponseEntity