Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import io.github.timemachinelab.core.session.application.RetryProcessingService;
import io.github.timemachinelab.core.qatree.QaTreeDomain;
import io.github.timemachinelab.core.qatree.QaTree;
import io.github.timemachinelab.core.qatree.QaTreeNode;
import io.github.timemachinelab.core.question.InputQuestion;

import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
import io.github.timemachinelab.core.session.infrastructure.web.dto.*;
Expand All @@ -21,6 +23,7 @@
import io.github.timemachinelab.entity.resp.ApiResult;
import io.github.timemachinelab.entity.resp.RetryResponse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
Expand All @@ -32,6 +35,7 @@
import javax.validation.Valid;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -253,6 +257,17 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
// 2. 检查sessionId和answer的逻辑
String sessionId = request.getSessionId();
Object answer = request.getAnswer();

// 将Object类型的answer转换为String类型
String answerStr = "";
if (answer != null) {
if (answer instanceof String) {
answerStr = (String) answer;
} else {
// 对于其他类型,转换为JSON字符串
answerStr = JSONObject.toJSONString(answer);
}
}

if (sessionId == null || sessionId.trim().isEmpty()) {
// 如果没有sessionId,必须检查answer是否为空
Expand All @@ -269,11 +284,6 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
if (sessionId == null || sessionId.trim().isEmpty()) {
// 新建会话
session = sessionManagementService.createNewSession(fingerprint);
if (session == null) {
log.error("会话创建失败 - 指纹: {}", fingerprint);
sseNotificationService.sendErrorMessage(fingerprint, "会话创建失败,请重试"); // 保持原样,因为错误消息的发送方式未改变
return ResponseEntity.internalServerError().body("会话处理失败");
}
} else {
// 3. 如果存在sessionId,获取conversation的currentNodeId,表示当前node节点需要过滤
session = sessionManagementService.validateAndGetSession(fingerprint, sessionId);
Expand All @@ -283,15 +293,37 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
return ResponseEntity.badRequest().body("会话不存在或已失效");
}

// 获取当前节点ID并过滤qaTree
// 获取当前节点ID
String currentNodeId = session.getCurrentNode();
QaTree originalQaTree = session.getQaTree();

// answerStr已在方法开始处定义,这里直接使用

// 检查是否是在回答当前问题
if (!StringUtils.isBlank(answerStr)) {
// 用户提供了答案,说明是在回答当前问题
// 先将答案插入到qaTree中
try {
boolean updateSuccess = qaTreeDomain.updateNodeAnswer(originalQaTree, currentNodeId, answerStr);
if (updateSuccess) {
log.info("已将用户答案插入qaTree - 会话: {}, 节点: {}, 答案: {}", sessionId, currentNodeId, answerStr);
} else {
log.warn("更新节点答案失败,节点可能不存在 - 会话: {}, 节点: {}", sessionId, currentNodeId);
sseNotificationService.sendErrorMessage(fingerprint, "当前问题节点不存在,请刷新页面重试");
return ResponseEntity.badRequest().body("当前问题节点不存在");
}
} catch (Exception e) {
log.error("插入用户答案失败 - 会话: {}, 节点: {}", sessionId, currentNodeId, e);
sseNotificationService.sendErrorMessage(fingerprint, "处理用户答案失败,请重试");
return ResponseEntity.internalServerError().body("处理用户答案失败");
}
}

// 4. 在qaTreeDomain里过滤qaNode(如果answer不存在则过滤),返回整个qaTree
filteredQaTree = qaTreeDomain.filterQaTreeByAnswer(originalQaTree, currentNodeId);
log.info("已过滤qaTree - 会话: {}, 过滤节点: {}", sessionId, currentNodeId);
}

// 5. 走现在有的逻辑(从创建会话开始) - 调用AI服务生成提示词
// 如果有过滤后的qaTree,临时替换session中的qaTree
QaTree originalQaTree = null;
if (filteredQaTree != null) {
Expand All @@ -300,16 +332,7 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
}

final QaTree finalOriginalQaTree = originalQaTree;
// 将Object类型的answer转换为String类型
String answerStr = "";
if (request.getAnswer() != null) {
if (request.getAnswer() instanceof String) {
answerStr = (String) request.getAnswer();
} else {
// 对于其他类型,转换为JSON字符串
answerStr = JSONObject.toJSONString(request.getAnswer());
}
}
// answerStr已在上面定义,这里不需要重复定义

conversationService.genPrompt(session.getSessionId(), answerStr, response -> {
try {
Expand All @@ -318,15 +341,34 @@ public ResponseEntity<String> genPrompt(@RequestBody GenPromptRequest request, H
session.setQaTree(finalOriginalQaTree);
}

// 更新currentNode - 在AI回答后创建新节点
String newNodeId = session.getNextNodeId();
session.setCurrentNode(newNodeId);
String parentId = session.getCurrentNode();

// 创建一个文本类型的问题,内容是生成的提示词
InputQuestion promptQuestion = new InputQuestion();
promptQuestion.setQuestion(response.getGenPrompt());
promptQuestion.setAnswer(""); // 初始无答案,等待用户回答

// 添加到qaTree
QaTreeNode promptNode = qaTreeDomain.appendNode(session.getQaTree(), parentId, promptQuestion, session);

// 更新currentNode为新创建的提示词节点
session.setCurrentNode(promptNode.getId());
session.setUpdateTime(LocalDateTime.now());

// 发送AI生成的提示词
sseNotificationService.sendSuccessMessage(fingerprint, response.getGenPrompt()); // 保持原样,因为成功消息的发送方式未改变

log.info("genPrompt处理完成 - 会话: {}, 新节点: {}", session.getSessionId(), newNodeId);
// 发送question格式的SSE消息,就像普通的AI回答一样
Map<String, Object> questionResponse = new HashMap<>();
questionResponse.put("question", Map.of(
"type", "input",
"question", response.getGenPrompt(),
"desc", "这是为您生成的提示词,您可以基于此内容继续对话"
));
questionResponse.put("sessionId", session.getSessionId());
questionResponse.put("currentNodeId", promptNode.getId());
questionResponse.put("parentNodeId", session.getCurrentNode());

sseNotificationService.sendSuccessMessage(fingerprint, JSONObject.toJSONString(questionResponse));

log.info("genPrompt处理完成 - 会话: {}, 新节点: {}", session.getSessionId(), promptNode.getId());
} catch (Exception e) {
// 恢复原始qaTree(异常情况下)
if (finalOriginalQaTree != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import io.github.timemachinelab.core.qatree.QaTree;
import io.github.timemachinelab.core.qatree.QaTreeNode;
import io.github.timemachinelab.core.serializable.JsonNode;
import io.github.timemachinelab.core.question.*;
import io.github.timemachinelab.core.serializable.TempFormQuestion;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashMap;

public class QaTreeSerializeUtil {

Expand All @@ -16,33 +20,210 @@ public static String serialize(QaTree t) throws JsonProcessingException {
return "[]";
}

List<JsonNode> result = new ArrayList<>();
List<Map<String, Object>> result = new ArrayList<>();

firstOrderTraversal(t.getRoot(), null, result);
firstOrderTraversalEnhanced(t.getRoot(), null, result);

return JSONObject.toJSONString(result);
}

private static void firstOrderTraversal(QaTreeNode node, String parentId, List<JsonNode> result) throws JsonProcessingException {
/**
* 增强版遍历方法,返回SSE兼容的格式
*/
private static void firstOrderTraversalEnhanced(QaTreeNode node, String parentId, List<Map<String, Object>> result) throws JsonProcessingException {
if (node == null) {
return;
}

// 获取子节点列表
List<QaTreeNode> children = new ArrayList<>();

if (node.getChildren() != null) {
children.addAll(node.getChildren().values());
}

// 访问当前节点
JsonNode jsonNode = JsonNode.Convert2JsonNode(node, parentId);

result.add(jsonNode);
// 创建增强的节点数据
Map<String, Object> enhancedNode = createEnhancedNode(node, parentId);
result.add(enhancedNode);

// 先序遍历
for (QaTreeNode child : children) {
firstOrderTraversal(child, node.getId(), result);
firstOrderTraversalEnhanced(child, node.getId(), result);
}
}

/**
* 创建SSE兼容的增强节点数据
*/
private static Map<String, Object> createEnhancedNode(QaTreeNode node, String parentId) {
Map<String, Object> enhancedNode = new HashMap<>();
enhancedNode.put("nodeId", node.getId());
enhancedNode.put("parentId", parentId);

String answer = "";
Map<String, Object> questionData = null;

BaseQuestion qa = node.getQa();
if (qa != null) {
// 根据问题类型创建questionData
QuestionType type = QuestionType.fromString(qa.getType());
switch (type) {
case INPUT:
InputQuestion inputQA = (InputQuestion) qa;
questionData = createInputQuestionData(inputQA);
answer = inputQA.getAnswer() != null ? inputQA.getAnswer() : "";
break;
case SINGLE:
SingleChoiceQuestion singleQA = (SingleChoiceQuestion) qa;
questionData = createSingleQuestionData(singleQA);
answer = formatSingleAnswer(singleQA);
break;
case MULTI:
MultipleChoiceQuestion multiQA = (MultipleChoiceQuestion) qa;
questionData = createMultiQuestionData(multiQA);
answer = formatMultiAnswer(multiQA);
break;
case FORM:
FormQuestion formQA = (FormQuestion) qa;
questionData = createFormQuestionData(formQA);
answer = formQA.getAnswer() != null ? JSONObject.toJSONString(formQA.getAnswer()) : "";
break;
default:
// 普通文本问题
questionData = createTextQuestionData(qa.getQuestion());
break;
}
}

enhancedNode.put("questionData", questionData);
enhancedNode.put("answer", answer);

return enhancedNode;
}

/**
* 创建输入问题数据
*/
private static Map<String, Object> createInputQuestionData(InputQuestion inputQA) {
Map<String, Object> questionData = new HashMap<>();
questionData.put("type", "input");
questionData.put("question", inputQA.getQuestion() != null ? inputQA.getQuestion() : "");
questionData.put("desc", inputQA.getDesc() != null ? inputQA.getDesc() : "");
return questionData;
}

/**
* 创建单选问题数据
*/
private static Map<String, Object> createSingleQuestionData(SingleChoiceQuestion singleQA) {
Map<String, Object> questionData = new HashMap<>();
questionData.put("type", "single");
questionData.put("question", singleQA.getQuestion() != null ? singleQA.getQuestion() : "");
questionData.put("desc", singleQA.getDesc() != null ? singleQA.getDesc() : "");
questionData.put("options", singleQA.getOptions() != null ? singleQA.getOptions() : new ArrayList<>());
return questionData;
}

/**
* 创建多选问题数据
*/
private static Map<String, Object> createMultiQuestionData(MultipleChoiceQuestion multiQA) {
Map<String, Object> questionData = new HashMap<>();
questionData.put("type", "multi");
questionData.put("question", multiQA.getQuestion() != null ? multiQA.getQuestion() : "");
questionData.put("desc", multiQA.getDesc() != null ? multiQA.getDesc() : "");
questionData.put("options", multiQA.getOptions() != null ? multiQA.getOptions() : new ArrayList<>());
return questionData;
}

/**
* 创建表单问题数据
*/
private static Map<String, Object> createFormQuestionData(FormQuestion formQA) {
Map<String, Object> questionData = new HashMap<>();
questionData.put("type", "form");
questionData.put("question", formQA.getQuestion() != null ? formQA.getQuestion() : "");
questionData.put("desc", formQA.getDesc() != null ? formQA.getDesc() : "");
questionData.put("fields", formQA.getFields() != null ? formQA.getFields() : new ArrayList<>());
return questionData;
}

/**
* 创建普通文本问题数据
*/
private static Map<String, Object> createTextQuestionData(String question) {
Map<String, Object> questionData = new HashMap<>();
questionData.put("type", "text");
questionData.put("question", question != null ? question : "");
questionData.put("desc", "");
return questionData;
}

/**
* 格式化单选答案
*/
private static String formatSingleAnswer(SingleChoiceQuestion singleQA) {
if (singleQA.getAnswer() != null && !singleQA.getAnswer().isEmpty()) {
List<String> answerLabels = new ArrayList<>();
for (String answerId : singleQA.getAnswer()) {
String label = findOptionLabel(singleQA.getOptions(), answerId);
answerLabels.add(label != null ? label : answerId);
}
return String.join(",", answerLabels);
}
return "";
}

/**
* 格式化多选答案
*/
private static String formatMultiAnswer(MultipleChoiceQuestion multiQA) {
if (multiQA.getAnswer() != null && !multiQA.getAnswer().isEmpty()) {
List<String> answerLabels = new ArrayList<>();
for (String answerId : multiQA.getAnswer()) {
String label = findOptionLabel(multiQA.getOptions(), answerId);
answerLabels.add(label != null ? label : answerId);
}
return String.join(",", answerLabels);
}
return "";
}

/**
* 根据选项id查找对应的标签
*/
private static String findOptionLabel(List<Option> options, String id) {
if (options == null || id == null) {
return null;
}
for (Option option : options) {
if (id.equals(option.getId())) {
return option.getLabel();
}
}
return null;
}

// 保留原有的序列化方法作为备用
private static void firstOrderTraversal(QaTreeNode node, String parentId, List<JsonNode> result) throws JsonProcessingException {
if (node == null) {
return;
}

// 获取子节点列表
List<QaTreeNode> children = new ArrayList<>();

if (node.getChildren() != null) {
children.addAll(node.getChildren().values());
}

// 访问当前节点
JsonNode jsonNode = JsonNode.Convert2JsonNode(node, parentId);

result.add(jsonNode);

// 先序遍历
for (QaTreeNode child : children) {
firstOrderTraversal(child, node.getId(), result);
}
}
}
Loading
Loading