Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions prompto-lab-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
<description>提示词工程</description>
<properties>
<java.version>17</java.version>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
</properties>
<dependencies>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ public void addCorsMappings(CorsRegistry registry) {
public CorsConfigurationSource corsConfigurationSource() {
CorsConfiguration configuration = new CorsConfiguration();

// 允许的源
configuration.setAllowedOriginPatterns(Arrays.asList("http://localhost:*", "http://127.0.0.1:*"));
// 允许的源 - 使用具体的域名模式而不是通配符
configuration.setAllowedOriginPatterns(Arrays.asList(
"http://localhost:*",
"http://127.0.0.1:*",
"https://localhost:*",
"https://127.0.0.1:*"
));

// 允许的HTTP方法
configuration.setAllowedMethods(Arrays.asList("GET", "POST", "PUT", "DELETE", "OPTIONS"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
package io.github.timemachinelab.controller;

import io.github.timemachinelab.core.session.application.ConversationService;
import io.github.timemachinelab.core.session.application.MessageProcessingService;
import io.github.timemachinelab.core.session.application.SessionManagementService;
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest;
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
import io.github.timemachinelab.entity.req.RetryRequest;
import io.github.timemachinelab.entity.resp.ApiResult;
import io.github.timemachinelab.entity.resp.RetryResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import javax.validation.Valid;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

/**
* 用户交互控制器
Expand All @@ -23,6 +37,55 @@
@Validated
public class UserInteractionController {

@Resource
private ConversationService conversationService;
@Resource
private MessageProcessingService messageProcessingService;
@Resource
private SessionManagementService sessionManagementService;
private final Map<String, SseEmitter> sseEmitters = new ConcurrentHashMap<>();

/**
* 建立SSE连接
*/
@GetMapping(value = "/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter streamConversation(@RequestParam(required = false) String sessionId) {
log.info("建立SSE连接 - 会话ID: {}", sessionId);

if(sessionId == null || sessionId.isEmpty()) {
sessionId = UUID.randomUUID().toString();
}
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
sseEmitters.put(sessionId, emitter);

// 连接建立时发送欢迎消息
try {
emitter.send(SseEmitter.event()
.name("connected")
.data("SSE连接已建立,会话ID: " + sessionId));
} catch (IOException e) {
log.error("发送欢迎消息失败: {}", e.getMessage());
}

// 设置连接事件处理
String finalSessionId = sessionId;
emitter.onCompletion(() -> {
log.info("SSE连接完成: {}", finalSessionId);
});

emitter.onTimeout(() -> {
log.info("SSE连接超时: {}", finalSessionId);
sseEmitters.remove(finalSessionId);
});

emitter.onError((ex) -> {
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
sseEmitters.remove(finalSessionId);
});

return emitter;
}

/**
* 重试接口
*
Expand All @@ -34,7 +97,9 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
try {
log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}",
request.getNodeId(), request.getSessionId(), request.getWhyretry());




// 构建响应数据
RetryResponse response = RetryResponse.builder()
.nodeId(request.getNodeId())
Expand All @@ -53,4 +118,90 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
return ResponseEntity.badRequest().body(ApiResult.serverError("重试请求处理失败: " + e.getMessage()));
}
}

/**
* 处理统一答案请求
* 支持单选、多选、输入框、表单等多种问题类型的回答
*/
@PostMapping("/message")
public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswerRequest request) {
try {
log.info("接收到答案请求 - 会话ID: {}, 节点ID: {}, 问题类型: {}",
request.getSessionId(),
request.getNodeId(),
request.getQuestionType());

// 1. 会话管理和验证
String userId = request.getUserId();

ConversationSession session = sessionManagementService.getOrCreateSession(userId, request.getSessionId());

// 2. 验证nodeId是否属于该会话
if (request.getNodeId() != null && !sessionManagementService.validateNodeId(session.getSessionId(), request.getNodeId())) {
log.warn("无效的节点ID - 会话: {}, 节点: {}", session.getSessionId(), request.getNodeId());
return ResponseEntity.badRequest().body("无效的节点ID");
}

// 3. 验证答案格式
if (!messageProcessingService.validateAnswer(request)) {
log.warn("答案格式验证失败: {}", request);
return ResponseEntity.badRequest().body("答案格式不正确");
}

// 4. 处理答案并转换为消息
String processedMessage = messageProcessingService.preprocessMessage(
null, // 没有额外的原始消息
request,
session
);

// 5. 发送处理后的消息给AI服务
conversationService.processUserMessage(
session.getUserId(),
processedMessage,
response -> sendSseMessage(session.getSessionId(), response)
);

return ResponseEntity.ok("答案处理成功");

} catch (Exception e) {
log.error("处理答案失败 - 会话ID: {}, 错误: {}", request.getSessionId(), e.getMessage(), e);
return ResponseEntity.internalServerError().body("答案处理失败: " + e.getMessage());
}
}

/**
* 通过SSE发送消息给客户端
*
* @param sessionId 会话ID
* @param response 消息响应对象
*/
private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) {
SseEmitter emitter = sseEmitters.get(sessionId);
if (emitter != null) {
try {
emitter.send(SseEmitter.event()
.name("message")
.data(response));
log.info("SSE消息发送成功 - 会话: {}, 消息: {}", sessionId, response);
} catch (IOException e) {
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
sseEmitters.remove(sessionId);
}
} else {
log.warn("SSE连接不存在 - 会话: {}", sessionId);
}
}

/**
* 获取SSE连接状态
*/
@GetMapping("/sse-status")
public ResponseEntity<Map<String, Object>> getSseStatus() {
Map<String, Object> status = new ConcurrentHashMap<>();
status.put("connectedSessions", sseEmitters.keySet());
status.put("totalConnections", sseEmitters.size());
status.put("timestamp", System.currentTimeMillis());
return ResponseEntity.ok(status);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONException;
import com.alibaba.fastjson2.JSONObject;
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
import lombok.extern.slf4j.Slf4j;

import java.util.Arrays;
Expand Down Expand Up @@ -37,7 +38,7 @@ public class QuestionParser {
* @return BaseQuestion对象
* @throws QuestionParseException 解析失败时抛出异常
*/
public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseException {
public static QuestionGenerationOperation.QuestionGenerationResponse parseQuestion(String jsonStr) throws QuestionParseException {
if (jsonStr == null || jsonStr.trim().isEmpty()) {
throw new QuestionParseException("JSON字符串不能为空", jsonStr, "输入为空或null");
}
Expand All @@ -52,7 +53,7 @@ public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseExc

// 收集所有解析失败的原因
List<String> failureReasons = new ArrayList<>();

String parentId = jsonObject.getString("parentId");
// 依次尝试解析成不同类型
for (Class<? extends BaseQuestion> questionType : QUESTION_TYPES) {
try {
Expand All @@ -61,7 +62,7 @@ public static BaseQuestion parseQuestion(String jsonStr) throws QuestionParseExc
String validationResult = validateQuestion(question, jsonObject);
if (validationResult == null) {
log.info("成功解析为: {}", questionType.getSimpleName());
return question;
return new QuestionGenerationOperation.QuestionGenerationResponse(question,parentId);
} else {
failureReasons.add(questionType.getSimpleName() + ": " + validationResult);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
package io.github.timemachinelab.core.session.application;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.github.timemachinelab.core.qatree.QaTree;
import io.github.timemachinelab.core.qatree.QaTreeDomain;
import io.github.timemachinelab.core.qatree.QaTreeNode;
import io.github.timemachinelab.core.question.BaseQuestion;
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
import io.github.timemachinelab.core.session.infrastructure.ai.ConversationOperation;
import io.github.timemachinelab.sfchain.core.AIService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Map;
import java.util.List;
Expand All @@ -18,93 +26,43 @@
@RequiredArgsConstructor
@Slf4j
public class ConversationService {


@Resource
private final AIService aiService;
@Resource
private SessionManagementService sessionManagementService;
private final QaTreeDomain qaTreeDomain;

private final Map<String, ConversationSession> sessions = new ConcurrentHashMap<>();

public ConversationSession createSession(String userId) {
ConversationSession session = new ConversationSession(userId);
sessions.put(session.getSessionId(), session);
return session;
}

public ConversationSession getSession(String sessionId) {
return sessions.get(sessionId);
}

public void processUserMessage(String sessionId, String userMessage, Consumer<MessageResponse> sseCallback) {
ConversationSession session = sessions.get(sessionId);
public void processUserMessage(String userId, String userMessage, Consumer<QuestionGenerationOperation.QuestionGenerationResponse> sseCallback) {
ConversationSession session = sessionManagementService.getUserCurrentSession(userId);
if (session == null) {
log.warn("会话不存在: {}", sessionId);
log.warn("会话不存在");
return;
}

// 1. 添加用户消息到会话历史
session.addMessage(userMessage, "user");

// 2. 发送用户消息确认
sseCallback.accept(MessageResponse.userAnswer("user_" + System.currentTimeMillis(), userMessage));

// 3. 调用AI服务获取回复
processAIResponse(session, userMessage, sseCallback);

processAIResponse(userMessage, sseCallback);
}

private void processAIResponse(ConversationSession session, String userMessage, Consumer<MessageResponse> sseCallback) {
private void processAIResponse(String userMessage, Consumer<QuestionGenerationOperation.QuestionGenerationResponse> sseCallback) {
try {
// 构建对话历史
List<ConversationOperation.ConversationHistory> history = buildConversationHistory(session);

JSONObject object = JSON.parseObject(userMessage);

// 创建AI请求
ConversationOperation.ConversationRequest request = new ConversationOperation.ConversationRequest(
session.getSessionId(),
"current",
userMessage
);
request.setConversationHistory(history);

QuestionGenerationOperation.QuestionGenerationRequest request = new QuestionGenerationOperation.QuestionGenerationRequest(object.getString("prompt"),object.getString("tree"),object.getString("input"));
// 调用AI服务
ConversationOperation.ConversationResponse aiResponse = aiService.execute("CONVERSATION_OP", request);
log.info("AI服务调用成功: {}", aiResponse);
QuestionGenerationOperation.QuestionGenerationResponse aiResponse = aiService.execute("QUESTION_GENERATION_OP", request);

// 添加AI回复到会话历史
session.addMessage(aiResponse.getAnswer(), "assistant");

// 根据响应类型处理AI回复
String nodeId = "ai_" + System.currentTimeMillis();
sseCallback.accept(MessageResponse.aiAnswer("ai_" + System.currentTimeMillis(), aiResponse.getAnswer()));
// if (aiResponse.getResponseType() == ConversationOperation.ResponseType.SELECTION) {
// // 选择题类型
// sseCallback.accept(MessageResponse.aiSelectionQuestion(nodeId, aiResponse.getAnswer(), aiResponse.getOptions()));
// } else {
// // 普通文本回复
// sseCallback.accept(MessageResponse.aiQuestion(nodeId, aiResponse.getAnswer()));
// }
sseCallback.accept(aiResponse);
log.info("AI服务调用成功: {}", aiResponse);

} catch (Exception e) {
log.error("AI服务调用失败: {}", e.getMessage(), e);
// 降级处理
String fallbackResponse = "抱歉,我暂时无法处理您的请求,请稍后再试。";
session.addMessage(fallbackResponse, "assistant");
String nodeId = "ai_" + System.currentTimeMillis();
sseCallback.accept(MessageResponse.aiQuestion(nodeId, fallbackResponse));
}
}

private List<ConversationOperation.ConversationHistory> buildConversationHistory(ConversationSession session) {
List<ConversationOperation.ConversationHistory> history = new ArrayList<>();

// 从会话消息构建对话历史
for (ConversationSession.ConversationMessage message : session.getMessages()) {
ConversationOperation.ConversationHistory historyItem = new ConversationOperation.ConversationHistory(
message.getRole(),
message.getContent(),
message.getRole() + "_" + message.getTimestamp().toString()
);
history.add(historyItem);
}

return history;
}


Expand Down
Loading
Loading