Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions prompto-lab-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
<maven.compiler.target>17</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>commons-lang</groupId>
<artifactId>commons-lang</artifactId>
<version>2.6</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package io.github.timemachinelab.controller;

import io.github.timemachinelab.core.qatree.QaTree;
import io.github.timemachinelab.core.qatree.QaTreeDomain;
import io.github.timemachinelab.core.session.application.ConversationService;
import io.github.timemachinelab.core.session.application.MessageProcessingService;
import io.github.timemachinelab.core.session.application.SessionManagementService;
import io.github.timemachinelab.core.session.application.SseNotificationService;
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest;
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
import io.github.timemachinelab.entity.req.RetryRequest;
import io.github.timemachinelab.entity.resp.ApiResult;
import io.github.timemachinelab.entity.resp.RetryResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
Expand All @@ -23,7 +18,6 @@
import javax.annotation.Resource;
import javax.validation.Valid;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -40,16 +34,12 @@
@RequestMapping("/api/user-interaction")
@Validated
public class UserInteractionController {

@Resource
private ConversationService conversationService;
@Resource
private MessageProcessingService messageProcessingService;
@Resource
private SessionManagementService sessionManagementService;
private final Map<String, SseEmitter> sseEmitters = new ConcurrentHashMap<>();
@Autowired
private QaTreeDomain qaTreeDomain;
@Resource
private SseNotificationService sseNotificationService;

/**
* 建立SSE连接
Expand All @@ -60,7 +50,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId);

boolean isNewSession = false;
ConversationSession session = null;
ConversationSession session;

try {
if (sessionId == null || sessionId.isEmpty()) {
Expand All @@ -82,7 +72,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
}

SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
sseEmitters.put(sessionId, emitter);
sseNotificationService.registerSseConnection(sessionId, emitter);

// 连接建立时发送会话信息
Map<String, Object> connectionData = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -115,9 +105,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId);
}

emitter.send(SseEmitter.event()
.name("connected")
.data(connectionData));
sseNotificationService.sendWelcomeMessage(sessionId, connectionData);

// 设置连接事件处理
String finalSessionId = sessionId;
Expand All @@ -127,12 +115,12 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess

emitter.onTimeout(() -> {
log.info("SSE连接超时: {}", finalSessionId);
sseEmitters.remove(finalSessionId);
sseNotificationService.removeSseConnection(finalSessionId);
});

emitter.onError((ex) -> {
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
sseEmitters.remove(finalSessionId);
sseNotificationService.removeSseConnection(finalSessionId);
});

return emitter;
Expand Down Expand Up @@ -163,8 +151,45 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}",
request.getNodeId(), request.getSessionId(), request.getWhyretry());

// 使用应用服务验证节点存在性
//todo: 有可能水平越权 不传userId的话
if (!sessionManagementService.validateNodeExists(request.getSessionId(), request.getNodeId())) {
log.warn("节点不存在 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
return ResponseEntity.badRequest().body(ApiResult.error("指定的节点不存在"));
}

// 使用应用服务获取问题内容
String question = sessionManagementService.getNodeQuestion(request.getSessionId(), request.getNodeId());
if (question == null) {
log.warn("节点问题内容为空 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
return ResponseEntity.badRequest().body(ApiResult.error("节点问题内容为空"));
}


// 获取会话对象
ConversationSession session = sessionManagementService.getSessionById(request.getSessionId());
if (session == null) {
log.warn("会话不存在 - sessionId: {}", request.getSessionId());
return ResponseEntity.badRequest().body(ApiResult.error("会话不存在"));
}

// 移除要重试的节点(AI会基于parentId重新创建节点)
boolean nodeRemoved = sessionManagementService.removeNode(request.getSessionId(), request.getNodeId());
if (!nodeRemoved) {
log.warn("移除节点失败,但继续处理重试 - sessionId: {}, nodeId: {}",
request.getSessionId(), request.getNodeId());
}

// 使用MessageProcessingService处理重试消息
String processedMessage = messageProcessingService.processRetryMessage(
request.getSessionId(),
request.getNodeId(),
request.getWhyretry(),
session
);

// 发送处理后的消息给AI服务
messageProcessingService.processAndSendMessage(session, processedMessage);

// 构建响应数据
RetryResponse response = RetryResponse.builder()
.nodeId(request.getNodeId())
Expand Down Expand Up @@ -196,12 +221,6 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
request.getNodeId(),
request.getQuestionType());

// 1. 强制要求sessionId
if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) {
log.warn("缺少必需的sessionId参数");
return ResponseEntity.badRequest().body("sessionId参数是必需的");
}

// 2. 会话管理和验证
String userId = request.getUserId();
if (userId == null || userId.trim().isEmpty()) {
Expand Down Expand Up @@ -248,28 +267,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
return ResponseEntity.badRequest().body("答案格式不正确");
}

QaTree qaTree = session.getQaTree();

// 根据问题类型获取正确的答案数据
Object answerData;
switch (request.getQuestionType().toLowerCase()) {
case "input":
answerData = request.getInputAnswer();
break;
case "single":
case "multi":
answerData = request.getChoiceAnswer();
break;
case "form":
answerData = request.getFormAnswer();
break;
default:
log.warn("未知的问题类型: {}", request.getQuestionType());
answerData = request.getAnswerString();
break;
}

qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData);
// 答案更新逻辑已在MessageProcessingService中处理

// 4. 处理答案并转换为消息
String processedMessage = messageProcessingService.preprocessMessage(
Expand All @@ -279,11 +277,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
);

// 5. 发送处理后的消息给AI服务
conversationService.processUserMessage(
session.getUserId(),
processedMessage,
response -> sendSseMessage(session.getSessionId(), response)
);
messageProcessingService.processAndSendMessage(session, processedMessage);


return ResponseEntity.ok("答案处理成功");
Expand All @@ -294,89 +288,11 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
}
}

/**
* 通过SSE发送消息给客户端
* 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新
*
* @param sessionId 会话ID
* @param response 消息响应对象
*/
private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) {
SseEmitter emitter = sseEmitters.get(sessionId);
if (emitter != null) {
try {
String currentNodeId = null;

// 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空)
ConversationSession session = sessionManagementService.getSessionById(sessionId);
if (session != null && session.getQaTree() != null && response.getQuestion() != null) {
// 使用QaTreeDomain添加新节点,answer字段会自动为空
// appendNode方法内部会调用session.getNextNodeId()获取新节点ID
QaTree qaTree = qaTreeDomain.appendNode(
session.getQaTree(),
response.getParentId(),
response.getQuestion(),
session
);

// 获取刚刚创建的节点ID(当前计数器的值)
currentNodeId = String.valueOf(session.getNodeIdCounter().get());

log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}",
sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType());
} else {
log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}",
sessionId, session != null,
session != null && session.getQaTree() != null,
response.getQuestion() != null);
}

// 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId
Map<String, Object> modifiedResponse = new HashMap<>();
modifiedResponse.put("question", response.getQuestion());
modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId());
modifiedResponse.put("parentNodeId", response.getParentId());

// 3. 发送SSE消息给前端
emitter.send(SseEmitter.event()
.name("message")
.data(modifiedResponse));
log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId);
} catch (IOException e) {
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
sseEmitters.remove(sessionId);
} catch (Exception e) {
log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
// 即使QaTree更新失败,仍然发送SSE消息给前端
try {
Map<String, Object> fallbackResponse = new HashMap<>();
fallbackResponse.put("question", response.getQuestion());
fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback
fallbackResponse.put("parentNodeId", response.getParentId());

emitter.send(SseEmitter.event()
.name("message")
.data(fallbackResponse));
log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId);
} catch (IOException ioException) {
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage());
sseEmitters.remove(sessionId);
}
}
} else {
log.warn("SSE连接不存在 - 会话: {}", sessionId);
}
}

/**
* 获取SSE连接状态
*/
@GetMapping("/sse-status")
public ResponseEntity<Map<String, Object>> getSseStatus() {
Map<String, Object> status = new ConcurrentHashMap<>();
status.put("connectedSessions", sseEmitters.keySet());
status.put("totalConnections", sseEmitters.size());
status.put("timestamp", System.currentTimeMillis());
return ResponseEntity.ok(status);
return ResponseEntity.ok(sseNotificationService.getSseStatus());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,59 @@ public void addNode(String parentId, QaTreeNode node) {
public QaTreeNode getNodeById(String id) {
return nodeMap.get(id);
}

/**
* 移除指定节点及其所有子节点
* @param nodeId 要移除的节点ID
* @return 是否移除成功
*/
public boolean removeNode(String nodeId) {
QaTreeNode nodeToRemove = nodeMap.get(nodeId);
if (nodeToRemove == null) {
return false;
}

// 递归移除所有子节点
removeNodeAndChildren(nodeToRemove);

// 从父节点的children中移除该节点
removeFromParent(nodeToRemove);

return true;
}

/**
* 递归移除节点及其所有子节点
* @param node 要移除的节点
*/
private void removeNodeAndChildren(QaTreeNode node) {
if (node == null) {
return;
}

// 递归移除所有子节点
if (node.getChildren() != null) {
for (QaTreeNode child : node.getChildren().values()) {
removeNodeAndChildren(child);
}
}

// 从nodeMap中移除当前节点
nodeMap.remove(node.getId());
}

/**
* 从父节点的children中移除指定节点
* @param nodeToRemove 要移除的节点
*/
private void removeFromParent(QaTreeNode nodeToRemove) {
// 遍历所有节点找到父节点
for (QaTreeNode node : nodeMap.values()) {
if (node.getChildren() != null && node.getChildren().containsKey(nodeToRemove.getId())) {
node.removeChild(nodeToRemove.getId());
break;
}
}
}

}
Loading
Loading