From 73a523a08de0c0ef2b1774ae69cbdef4d097a4ee Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 05:25:11 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20Add=20docstrings=20to=20`main`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstrings generation was requested by @KimmyXYC. * https://github.com/hduhelp/backend_2025_freshman_task/pull/8#issuecomment-3375254862 The following files were modified: * `KimmyXYC/cmd/server/main.go` * `KimmyXYC/internal/db/db.go` * `KimmyXYC/internal/httpserver/router.go` * `KimmyXYC/internal/provider/openai.go` * `KimmyXYC/internal/provider/provider.go` * `KimmyXYC/internal/services/auth_service.go` * `KimmyXYC/internal/services/chat_service.go` * `KimmyXYC/pkg/auth/token.go` * `KimmyXYC/pkg/middleware/auth.go` * `KimmyXYC/web/js/api.js` * `KimmyXYC/web/js/auth.js` * `KimmyXYC/web/js/chat.js` * `KimmyXYC/web/js/main.js` * `KimmyXYC/web/js/state.js` --- KimmyXYC/cmd/server/main.go | 46 +++++ KimmyXYC/internal/db/db.go | 37 ++++ KimmyXYC/internal/httpserver/router.go | 200 +++++++++++++++++++++ KimmyXYC/internal/provider/openai.go | 174 ++++++++++++++++++ KimmyXYC/internal/provider/provider.go | 74 ++++++++ KimmyXYC/internal/services/auth_service.go | 71 ++++++++ KimmyXYC/internal/services/chat_service.go | 123 +++++++++++++ KimmyXYC/pkg/auth/token.go | 67 +++++++ KimmyXYC/pkg/middleware/auth.go | 81 +++++++++ KimmyXYC/web/js/api.js | 186 +++++++++++++++++++ KimmyXYC/web/js/auth.js | 51 ++++++ KimmyXYC/web/js/chat.js | 150 ++++++++++++++++ KimmyXYC/web/js/main.js | 65 +++++++ KimmyXYC/web/js/state.js | 61 +++++++ 14 files changed, 1386 insertions(+) create mode 100644 KimmyXYC/cmd/server/main.go create mode 100644 KimmyXYC/internal/db/db.go create mode 100644 KimmyXYC/internal/httpserver/router.go create mode 100644 KimmyXYC/internal/provider/openai.go create mode 100644 KimmyXYC/internal/provider/provider.go create mode 100644 KimmyXYC/internal/services/auth_service.go create mode 100644 KimmyXYC/internal/services/chat_service.go create mode 100644 KimmyXYC/pkg/auth/token.go create mode 100644 KimmyXYC/pkg/middleware/auth.go create mode 100644 KimmyXYC/web/js/api.js create mode 100644 KimmyXYC/web/js/auth.js create mode 100644 KimmyXYC/web/js/chat.js create mode 100644 KimmyXYC/web/js/main.js create mode 100644 KimmyXYC/web/js/state.js diff --git a/KimmyXYC/cmd/server/main.go b/KimmyXYC/cmd/server/main.go new file mode 100644 index 0000000..41b6bd3 --- /dev/null +++ b/KimmyXYC/cmd/server/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "log" + "os" + + "github.com/joho/godotenv" + + "AIBackend/internal/db" + "AIBackend/internal/httpserver" + "AIBackend/internal/provider" +) + +// main 是程序的入口点。它可选加载 `.env`,使用 `DATABASE_URL` 建立并迁移数据库, +// 从环境创建 LLM 提供者,并使用 `ADDR`(默认 ":8080")启动 HTTP 服务器;在连接、迁移或启动失败时记录致命错误,在缺少 `DATABASE_URL` 时记录警告。 +func main() { + // Load .env if present (dev convenience) + _ = godotenv.Load() + + // Initialize DB + pgURL := os.Getenv("DATABASE_URL") + if pgURL == "" { + log.Println("WARNING: DATABASE_URL is not set. The server may fail to start when DB is required.") + } + gormDB, err := db.Connect(pgURL) + if err != nil { + log.Fatalf("failed to connect database: %v", err) + } + if err := db.AutoMigrate(gormDB); err != nil { + log.Fatalf("failed to migrate database: %v", err) + } + + // Initialize LLM provider (Mock by default) + llm := provider.NewProviderFromEnv() + + // Start HTTP server + r := httpserver.NewRouter(gormDB, llm) + addr := os.Getenv("ADDR") + if addr == "" { + addr = ":8080" + } + log.Printf("Server listening on %s", addr) + if err := r.Run(addr); err != nil { + log.Fatalf("server error: %v", err) + } +} \ No newline at end of file diff --git a/KimmyXYC/internal/db/db.go b/KimmyXYC/internal/db/db.go new file mode 100644 index 0000000..77cfa0d --- /dev/null +++ b/KimmyXYC/internal/db/db.go @@ -0,0 +1,37 @@ +package db + +import ( + "fmt" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + + "AIBackend/internal/models" +) + +// Connect 使用提供的数据库 URL 打开一个 PostgreSQL 连接。 +// 如果传入的 databaseURL 为空,则使用本地默认 DSN: +// postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable。 +// 返回已打开的 *gorm.DB;在无法建立连接时返回非 nil 错误。 +func Connect(databaseURL string) (*gorm.DB, error) { + if databaseURL == "" { + // Provide a friendly default to help first run; it will still fail if DB not available. + databaseURL = "postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable" + } + dsn := databaseURL + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("connect postgres: %w", err) + } + return db, nil +} + +// AutoMigrate 在数据库上应用 User、Conversation 和 Message 模型的自动迁移。 +// 如果迁移失败,返回相应的错误。 +func AutoMigrate(db *gorm.DB) error { + return db.AutoMigrate( + &models.User{}, + &models.Conversation{}, + &models.Message{}, + ) +} \ No newline at end of file diff --git a/KimmyXYC/internal/httpserver/router.go b/KimmyXYC/internal/httpserver/router.go new file mode 100644 index 0000000..84ab775 --- /dev/null +++ b/KimmyXYC/internal/httpserver/router.go @@ -0,0 +1,200 @@ +package httpserver + +import ( + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" + + "AIBackend/internal/provider" + "AIBackend/internal/services" + "AIBackend/pkg/middleware" +) + +type Server struct { + Auth *services.AuthService + Chat *services.ChatService +} + +// NewRouter 创建并返回已配置的 Gin 引擎,注册健康检查、认证相关路由、带鉴权的会话与聊天 API(包含可选的 SSE 流式聊天)并提供前端静态文件服务。 +func NewRouter(db *gorm.DB, llm provider.LLMProvider) *gin.Engine { + g := gin.Default() + + server := &Server{ + Auth: services.NewAuthService(db), + Chat: services.NewChatService(db, llm), + } + + g.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) + + api := g.Group("/api") + { + auth := api.Group("/auth") + auth.POST("/register", server.handleRegister) + auth.POST("/login", server.handleLogin) + } + + protected := api.Group("") + protected.Use(middleware.AuthRequired()) + { + protected.GET("/me", server.handleMe) + protected.GET("/conversations", server.handleListConversations) + protected.GET("/conversations/:id/messages", server.handleGetMessages) + protected.POST("/chat", middleware.ModelAccess(), server.handleChat) + } + + // Serve static frontend files without conflicting wildcard + g.StaticFile("/", "./web/index.html") + g.Static("/css", "./web/css") + g.Static("/js", "./web/js") + + return g +} + +type registerReq struct { + Email string `json:"email" binding:"required"` + Password string `json:"password" binding:"required"` + Role string `json:"role"` +} + +type loginReq struct { + Email string `json:"email" binding:"required"` + Password string `json:"password" binding:"required"` +} + +func (s *Server) handleRegister(c *gin.Context) { + var req registerReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + user, token, err := s.Auth.Register(req.Email, req.Password, req.Role) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"user": user, "token": token}) +} + +func (s *Server) handleLogin(c *gin.Context) { + var req loginReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + user, token, err := s.Auth.Login(req.Email, req.Password) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"user": user, "token": token}) +} + +func (s *Server) handleMe(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "user_id": c.GetUint("user_id"), + "user_email": c.GetString("user_email"), + "user_role": c.GetString("user_role"), + }) +} + +func (s *Server) handleListConversations(c *gin.Context) { + uid := c.GetUint("user_id") + convs, err := s.Chat.ListConversations(uid) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"conversations": convs}) +} + +func (s *Server) handleGetMessages(c *gin.Context) { + uid := c.GetUint("user_id") + idStr := c.Param("id") + id64, _ := strconv.ParseUint(idStr, 10, 64) + msgs, err := s.Chat.GetMessages(uid, uint(id64)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"messages": msgs}) +} + +type chatReq struct { + ConversationID uint `json:"conversation_id"` + Model string `json:"model"` + Message string `json:"message" binding:"required"` + Stream *bool `json:"stream"` +} + +func (s *Server) handleChat(c *gin.Context) { + uid := c.GetUint("user_id") + var req chatReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + // Fallback to query param model for middleware check compatibility + if req.Model == "" { + req.Model = c.Query("model") + } + // Enforce model access if provided in body + role := c.GetString("user_role") + if !middleware.CheckModelAccess(role, req.Model) { + c.JSON(http.StatusForbidden, gin.H{"error": "model access denied for role"}) + return + } + streaming := false + if req.Stream != nil { + streaming = *req.Stream + } + if c.Query("stream") == "1" || c.Query("stream") == "true" { + streaming = true + } + if !streaming { + convID, reply, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, nil) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"conversation_id": convID, "reply": reply}) + return + } + // Streaming via SSE + w := c.Writer + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Status(http.StatusOK) + flusher, _ := w.(http.Flusher) + sentAny := false + convID, _, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, func(chunk string) error { + sentAny = true + _, err := w.Write([]byte("data: " + chunk + "\n\n")) + if err == nil && flusher != nil { + flusher.Flush() + } + return err + }) + if err != nil { + // send error as SSE comment and 0-length event end + _, _ = w.Write([]byte(": error: " + err.Error() + "\n\n")) + if flusher != nil { + flusher.Flush() + } + return + } + if !sentAny { + // send at least one empty event to keep clients happy + _, _ = w.Write([]byte("data: \n\n")) + } + // end event + _, _ = w.Write([]byte("event: done\n" + "data: {\"conversation_id\": " + strconv.FormatUint(uint64(convID), 10) + "}\n\n")) + if flusher != nil { + flusher.Flush() + } + // allow connection to close shortly after + time.Sleep(50 * time.Millisecond) +} \ No newline at end of file diff --git a/KimmyXYC/internal/provider/openai.go b/KimmyXYC/internal/provider/openai.go new file mode 100644 index 0000000..0dcbc97 --- /dev/null +++ b/KimmyXYC/internal/provider/openai.go @@ -0,0 +1,174 @@ +package provider + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "strings" + "time" +) + +// OpenAIProvider implements LLMProvider using OpenAI-compatible Chat Completions API. +// It supports custom endpoint and token via environment variables: +// OPENAI_API_KEY - required to enable this provider +// OPENAI_API_BASE - optional, defaults to https://api.openai.com +// The API path used is {BASE}/v1/chat/completions with stream=true. +// The "model" passed from caller is forwarded as-is. + +type OpenAIProvider struct { + BaseURL string + APIKey string + Client *http.Client +} + +// NewOpenAIProviderFromEnv 从环境变量创建并返回一个指向 OpenAIProvider 的指针。 +// 当环境变量 OPENAI_API_KEY 为空时返回 nil。若未设置 OPENAI_API_BASE,则使用 +// "https://api.openai.com" 作为默认值;会去除 base URL 的尾部斜杠并创建一个带有 +// 90 秒超时的 http.Client。 +func NewOpenAIProviderFromEnv() *OpenAIProvider { + key := os.Getenv("OPENAI_API_KEY") + if key == "" { + return nil + } + base := os.Getenv("OPENAI_API_BASE") + if base == "" { + base = "https://api.openai.com" + } + return &OpenAIProvider{ + BaseURL: strings.TrimRight(base, "/"), + APIKey: key, + Client: &http.Client{Timeout: 90 * time.Second}, + } +} + +type openAIChatRequest struct { + Model string `json:"model"` + Messages []openAIChatMessage `json:"messages"` + Stream bool `json:"stream"` +} + +type openAIChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type openAIStreamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []openAIStreamChunkChoice `json:"choices"` +} + +type openAIStreamChunkChoice struct { + Index int `json:"index"` + Delta openAIStreamDelta `json:"delta"` + // finish_reason may be "stop" etc. + FinishReason *string `json:"finish_reason"` +} + +type openAIStreamDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +// ChatCompletionStream implements streaming chat using OpenAI SSE. +func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, messages []ChatMessage) (<-chan StreamChunk, error) { + if p == nil || p.APIKey == "" { + return nil, errors.New("openai provider not configured") + } + url := p.BaseURL + "/v1/chat/completions" + + reqPayload := openAIChatRequest{ + Model: model, + Stream: true, + } + for _, m := range messages { + role := strings.ToLower(m.Role) + if role == "assistant" || role == "user" || role == "system" { + // ok + } else { + // map unknown roles to user to avoid API errors + role = "user" + } + reqPayload.Messages = append(reqPayload.Messages, openAIChatMessage{Role: role, Content: m.Content}) + } + buf, err := json.Marshal(reqPayload) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(buf)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Authorization", "Bearer "+p.APIKey) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + + resp, err := p.Client.Do(httpReq) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + return nil, errors.New(strings.TrimSpace(string(b))) + } + + ch := make(chan StreamChunk) + go func() { + defer close(ch) + defer resp.Body.Close() + r := bufio.NewReader(resp.Body) + for { + select { + case <-ctx.Done(): + ch <- StreamChunk{Err: ctx.Err()} + return + default: + } + line, err := r.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) { + // end of stream + ch <- StreamChunk{Done: true} + } + return + } + line = strings.TrimRight(line, "\r\n") + if line == "" || strings.HasPrefix(line, ":") { // comments/keepalive + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "[DONE]" { + ch <- StreamChunk{Done: true} + return + } + var chunk openAIStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + // send as raw text if JSON parse fails + ch <- StreamChunk{Content: data} + continue + } + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + ch <- StreamChunk{Content: choice.Delta.Content} + } + if choice.FinishReason != nil && *choice.FinishReason != "" { + // when finish reason received, mark done soon + // we won't break immediately because there could be other choices + } + } + } + }() + return ch, nil +} \ No newline at end of file diff --git a/KimmyXYC/internal/provider/provider.go b/KimmyXYC/internal/provider/provider.go new file mode 100644 index 0000000..7df7ba5 --- /dev/null +++ b/KimmyXYC/internal/provider/provider.go @@ -0,0 +1,74 @@ +package provider + +import ( + "context" + "os" + "strings" + "time" +) + +// ChatMessage represents a message sent to/from the model. +type ChatMessage struct { + Role string + Content string +} + +// StreamChunk represents a chunk of streamed content. +type StreamChunk struct { + Content string + Done bool + Err error +} + +// LLMProvider is an abstraction over an AI chat model provider. +type LLMProvider interface { + // ChatCompletionStream streams the assistant reply for given messages and model. + ChatCompletionStream(ctx context.Context, model string, messages []ChatMessage) (<-chan StreamChunk, error) +} + +// NewProviderFromEnv selects a provider based on environment variables. +// NewProviderFromEnv 根据环境变量选择并返回一个 LLMProvider。 +// 如果存在可用的 OpenAI 兼容配置(由 NewOpenAIProviderFromEnv 提供),则返回该提供者;否则返回一个用于测试/占位的 MockProvider。 +// 读取 VOLC_API_KEY 保留用于将来真实提供者的支持。 +func NewProviderFromEnv() LLMProvider { + if p := NewOpenAIProviderFromEnv(); p != nil { + return p + } + _ = os.Getenv("VOLC_API_KEY") // reserved for future real provider + return &MockProvider{} +} + +// MockProvider is a simple echo-based provider with fake streaming. +type MockProvider struct{} + +func (m *MockProvider) ChatCompletionStream(ctx context.Context, model string, messages []ChatMessage) (<-chan StreamChunk, error) { + ch := make(chan StreamChunk) + go func() { + defer close(ch) + // naive: concatenate last user message and reply with a friendly echo + var prompt string + for i := len(messages) - 1; i >= 0; i-- { + if strings.ToLower(messages[i].Role) == "user" { + prompt = messages[i].Content + break + } + } + if prompt == "" { + prompt = "Hello! Ask me anything." + } + reply := "[Mock-" + model + "] " + "You said: " + prompt + // stream in word chunks + words := strings.Split(reply, " ") + for i, w := range words { + select { + case <-ctx.Done(): + ch <- StreamChunk{Err: ctx.Err()} + return + case ch <- StreamChunk{Content: func() string { if i == 0 { return w } ; return " " + w }()}: + time.Sleep(50 * time.Millisecond) + } + } + ch <- StreamChunk{Done: true} + }() + return ch, nil +} \ No newline at end of file diff --git a/KimmyXYC/internal/services/auth_service.go b/KimmyXYC/internal/services/auth_service.go new file mode 100644 index 0000000..18db5ca --- /dev/null +++ b/KimmyXYC/internal/services/auth_service.go @@ -0,0 +1,71 @@ +package services + +import ( + "errors" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" + + "AIBackend/internal/models" + "AIBackend/pkg/auth" +) + +type AuthService struct { + DB *gorm.DB +} + +// NewAuthService 创建并返回一个使用提供的数据库句柄的 AuthService 实例。 +// db 是用于执行用户相关持久化操作的 GORM 数据库连接。 +func NewAuthService(db *gorm.DB) *AuthService { + return &AuthService{DB: db} +} + +func (s *AuthService) Register(email, password, role string) (*models.User, string, error) { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || password == "" { + return nil, "", errors.New("email and password required") + } + if role == "" { + role = "free" + } + var existing models.User + if err := s.DB.Where("email = ?", email).First(&existing).Error; err == nil { + return nil, "", errors.New("email already registered") + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, "", err + } + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, "", err + } + user := &models.User{Email: email, PasswordHash: string(hash), Role: role} + if err := s.DB.Create(user).Error; err != nil { + return nil, "", err + } + token, err := auth.CreateToken(user.ID, user.Email, user.Role, 24*time.Hour) + if err != nil { + return nil, "", err + } + return user, token, nil +} + +func (s *AuthService) Login(email, password string) (*models.User, string, error) { + email = strings.TrimSpace(strings.ToLower(email)) + var user models.User + if err := s.DB.Where("email = ?", email).First(&user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, "", errors.New("invalid credentials") + } + return nil, "", err + } + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return nil, "", errors.New("invalid credentials") + } + token, err := auth.CreateToken(user.ID, user.Email, user.Role, 24*time.Hour) + if err != nil { + return nil, "", err + } + return &user, token, nil +} \ No newline at end of file diff --git a/KimmyXYC/internal/services/chat_service.go b/KimmyXYC/internal/services/chat_service.go new file mode 100644 index 0000000..ada1c43 --- /dev/null +++ b/KimmyXYC/internal/services/chat_service.go @@ -0,0 +1,123 @@ +package services + +import ( + "context" + "errors" + "strings" + "time" + + "gorm.io/gorm" + + "AIBackend/internal/models" + "AIBackend/internal/provider" +) + +type ChatService struct { + DB *gorm.DB + LLM provider.LLMProvider + MaxTurns int // number of last messages to keep in context +} + +// NewChatService 创建一个 ChatService,使用提供的数据库句柄和 LLM 提供者,并将 MaxTurns 初始化为 10。 +func NewChatService(db *gorm.DB, llm provider.LLMProvider) *ChatService { + return &ChatService{DB: db, LLM: llm, MaxTurns: 10} +} + +// EnsureConversation ensures conversation exists (and belongs to user), creating if needed. +func (s *ChatService) EnsureConversation(userID uint, convID uint, model string, title string) (*models.Conversation, error) { + if convID != 0 { + var conv models.Conversation + if err := s.DB.Where("id = ? AND user_id = ?", convID, userID).First(&conv).Error; err != nil { + return nil, err + } + return &conv, nil + } + conv := &models.Conversation{UserID: userID, Title: title, Model: model} + if conv.Title == "" { + conv.Title = "New Chat" + } + if err := s.DB.Create(conv).Error; err != nil { + return nil, err + } + return conv, nil +} + +// ListConversations returns user's conversations. +func (s *ChatService) ListConversations(userID uint) ([]models.Conversation, error) { + var convs []models.Conversation + if err := s.DB.Where("user_id = ?", userID).Order("updated_at desc").Find(&convs).Error; err != nil { + return nil, err + } + return convs, nil +} + +// GetMessages returns messages for a conversation if owned by user. +func (s *ChatService) GetMessages(userID, convID uint) ([]models.Message, error) { + var conv models.Conversation + if err := s.DB.Where("id = ? AND user_id = ?", convID, userID).First(&conv).Error; err != nil { + return nil, err + } + var msgs []models.Message + if err := s.DB.Where("conversation_id = ?", convID).Order("id asc").Find(&msgs).Error; err != nil { + return nil, err + } + return msgs, nil +} + +// SendMessage adds a user message, streams assistant reply via callback, and persists the assistant message. +func (s *ChatService) SendMessage(ctx context.Context, userID uint, convID uint, model string, userText string, stream func(chunk string) error) (uint, string, error) { + userText = strings.TrimSpace(userText) + if userText == "" { + return 0, "", errors.New("message content required") + } + conv, err := s.EnsureConversation(userID, convID, model, "") + if err != nil { + return 0, "", err + } + // Save user message + um := &models.Message{ConversationID: conv.ID, Role: "user", Content: userText} + if err := s.DB.Create(um).Error; err != nil { + return 0, "", err + } + // Load recent messages for context + var msgs []models.Message + s.DB.Where("conversation_id = ?", conv.ID).Order("id desc").Limit(s.MaxTurns * 2).Find(&msgs) + // reverse to chronological + for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { + msgs[i], msgs[j] = msgs[j], msgs[i] + } + llmMsgs := make([]provider.ChatMessage, 0, len(msgs)) + for _, m := range msgs { + llmMsgs = append(llmMsgs, provider.ChatMessage{Role: m.Role, Content: m.Content}) + } + // Stream assistant reply + ctx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + ch, err := s.LLM.ChatCompletionStream(ctx, conv.Model, llmMsgs) + if err != nil { + return 0, "", err + } + assistantContent := strings.Builder{} + for chunk := range ch { + if chunk.Err != nil { + return conv.ID, "", chunk.Err + } + if chunk.Content != "" { + assistantContent.WriteString(chunk.Content) + if stream != nil { + if err := stream(chunk.Content); err != nil { + return conv.ID, "", err + } + } + } + if chunk.Done { + break + } + } + // Save assistant message + am := &models.Message{ConversationID: conv.ID, Role: "assistant", Content: assistantContent.String()} + if err := s.DB.Create(am).Error; err != nil { + return conv.ID, "", err + } + return conv.ID, am.Content, nil +} \ No newline at end of file diff --git a/KimmyXYC/pkg/auth/token.go b/KimmyXYC/pkg/auth/token.go new file mode 100644 index 0000000..cece833 --- /dev/null +++ b/KimmyXYC/pkg/auth/token.go @@ -0,0 +1,67 @@ +package auth + +import ( + "errors" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var defaultSecret = []byte("dev-secret-change-me") + +// jwtSecret 返回用于签名 JWT 的密钥字节切片。 +// 它优先使用环境变量 JWT_SECRET 的非空值,若未设置或为空则返回包级默认密钥 defaultSecret. +func jwtSecret() []byte { + if s := os.Getenv("JWT_SECRET"); s != "" { + return []byte(s) + } + return defaultSecret +} + +// Claims represents JWT claims for a user session. +type Claims struct { + UserID uint `json:"user_id"` + Email string `json:"email"` + Role string `json:"role"` + jwt.RegisteredClaims +} + +// CreateToken 为给定用户生成并签名一个 JWT。 +// +// 生成的令牌包含用户标识(UserID)、邮箱(Email)、角色(Role)以及标准注册声明: +// 设置过期时间为当前时间加上 ttl,设置签发时间为当前时间。令牌使用 HS256 签名并由内部密钥签名。 +// 返回签名后的 JWT 字符串;签名过程中发生错误则返回该错误。 +func CreateToken(userID uint, email, role string, ttl time.Duration) (string, error) { + claims := Claims{ + UserID: userID, + Email: email, + Role: role, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return t.SignedString(jwtSecret()) +} + +// ParseToken 解析并验证给定的 JWT 字符串,成功时返回其中的 Claims。 +// 如果解析过程出错会返回该错误;当 token 验证未通过时返回 "invalid token" 错误; +// 当提取到的声明不能断言为 *Claims 时返回 "invalid claims" 错误。 +func ParseToken(token string) (*Claims, error) { + tok, err := jwt.ParseWithClaims(token, &Claims{}, func(t *jwt.Token) (interface{}, error) { + return jwtSecret(), nil + }) + if err != nil { + return nil, err + } + if !tok.Valid { + return nil, errors.New("invalid token") + } + claims, ok := tok.Claims.(*Claims) + if !ok { + return nil, errors.New("invalid claims") + } + return claims, nil +} \ No newline at end of file diff --git a/KimmyXYC/pkg/middleware/auth.go b/KimmyXYC/pkg/middleware/auth.go new file mode 100644 index 0000000..9d0d73a --- /dev/null +++ b/KimmyXYC/pkg/middleware/auth.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + + "AIBackend/pkg/auth" +) + +// Allowed models by role (exported for reuse) +var AllowedModelsByRole = map[string][]string{ + "free": {"mock-mini", "gpt-4o-mini"}, + "pro": {"mock-mini", "mock-pro", "gpt-4o-mini", "gpt-4o"}, + "admin": {"mock-mini", "mock-pro", "mock-admin", "gpt-4o-mini", "gpt-4o", "gpt-4.1"}, +} + +// CheckModelAccess 验证给定角色是否被允许使用指定模型。 +// 当 model 为空字符串时视为允许;否则在 AllowedModelsByRole 中查找该角色的允许列表并进行精确匹配。 +// 返回 `true` 表示角色被允许使用该模型,`false` 表示不允许(包括角色不存在于映射时)。 +func CheckModelAccess(role, model string) bool { + if model == "" { + return true + } + list := AllowedModelsByRole[role] + for _, m := range list { + if m == model { + return true + } + } + return false +} + +// AuthRequired 返回一个 Gin 中间件,用于验证 Authorization Bearer JWT 并在成功时将用户信息存入上下文。 +// 如果请求缺少或格式错误的 Bearer 令牌,或令牌解析失败,响应 401 并中止请求。 +// 成功时在上下文中设置 "user_id"、"user_email" 和 "user_role" 三个键,然后继续处理链。 +func AuthRequired() gin.HandlerFunc { + return func(c *gin.Context) { + h := c.GetHeader("Authorization") + if h == "" || !strings.HasPrefix(h, "Bearer ") { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing bearer token"}) + return + } + token := strings.TrimPrefix(h, "Bearer ") + claims, err := auth.ParseToken(token) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + return + } + c.Set("user_id", claims.UserID) + c.Set("user_email", claims.Email) + c.Set("user_role", claims.Role) + c.Next() + } +} + +// ModelAccess 在存在 "model" 查询参数时根据上下文中的用户角色强制模型访问控制。 +// 如果上下文中未设置 `user_role` 或为空,则视为 "free" 角色。 +// 当查询参数 `model` 缺失时,跳过此中间件的访问检查以便后续处理器自行验证。 +// 若访问被拒绝,中间件会以 403 状态并返回 JSON 错误信息终止请求。 +func ModelAccess() gin.HandlerFunc { + return func(c *gin.Context) { + role, _ := c.Get("user_role") + roleStr := "free" + if r, ok := role.(string); ok && r != "" { + roleStr = r + } + reqModel := c.Query("model") + if reqModel == "" { + // body may contain model; handler should validate with CheckModelAccess + c.Next() + return + } + if !CheckModelAccess(roleStr, reqModel) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "model access denied for role"}) + return + } + c.Next() + } +} \ No newline at end of file diff --git a/KimmyXYC/web/js/api.js b/KimmyXYC/web/js/api.js new file mode 100644 index 0000000..4d7bb90 --- /dev/null +++ b/KimmyXYC/web/js/api.js @@ -0,0 +1,186 @@ +import { getToken } from './state.js'; + +/** + * 发起带自动令牌注入与统一响应解析的 HTTP 请求并返回解析后的响应内容。 + * + * 发送请求时会在请求头中添加 `Accept: application/json`,若存在 token 则添加 `Authorization: Bearer `。 + * + * @param {string} path - 请求的 URL 或相对路径。 + * @param {Object} [options] - 请求选项。 + * @param {string} [options.method='GET'] - HTTP 方法。 + * @param {Object} [options.headers={}] - 额外请求头,会与默认头合并(默认头可被覆盖)。 + * @param {string|Blob|FormData|URLSearchParams|ReadableStream|undefined} [options.body] - 请求体。 + * @returns {any} 如果响应 Content-Type 包含 `application/json` 则返回解析后的 JSON,否则返回响应文本。 + * @throws {Error} 当响应状态不是 ok 时,抛出包含服务端错误信息或状态文本的 Error。 + */ +async function api(path, { method = 'GET', headers = {}, body = undefined } = {}) { + const h = { 'Accept': 'application/json', ...headers }; + const token = getToken(); + if (token) h['Authorization'] = `Bearer ${token}`; + const res = await fetch(path, { method, headers: h, body }); + if (!res.ok) { + let errText = await res.text().catch(() => ''); + try { const j = JSON.parse(errText); errText = j.error || errText; } catch {} + throw new Error(errText || `${res.status} ${res.statusText}`); + } + const ct = res.headers.get('Content-Type') || ''; + if (ct.includes('application/json')) return res.json(); + return res.text(); +} + +/** + * 注册新用户并返回服务器的响应数据。 + * @param {string} email - 用户邮箱。 + * @param {string} password - 用户密码。 + * @param {string} [role='free'] - 用户角色,常见取值包括 'free'、'pro'、'admin';默认为 'free'。 + * @returns {any} 服务器返回的解析后响应(通常为 JSON,可能包含用户信息和/或认证令牌)。 + */ +export async function register(email, password, role = 'free') { + return api('/api/auth/register', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ email, password, role }), + }); +} + +/** + * 使用邮箱和密码对用户进行登录。 + * @param {string} email - 要登录的用户邮箱地址。 + * @param {string} password - 用户密码。 + * @returns {any} 服务器返回的解析结果(通常包含认证信息和用户数据)。 + */ +export async function login(email, password) { + return api('/api/auth/login', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ email, password }), + }); +} + +/** + * 获取当前已认证用户的信息。 + * @returns {Object} 当前用户的详细信息对象。 + */ +export async function me() { + return api('/api/me'); +} + +/** + * 获取当前用户的会话列表。 + * @returns {Object[]} 包含会话对象的数组,每个对象表示一个会话。 + */ +export async function listConversations() { + return api('/api/conversations'); +} + +/** + * 获取指定对话的消息列表。 + * @param {string|number} convId - 要检索消息的对话 ID。 + * @returns {Array} 该对话的消息数组。 + */ +export async function getMessages(convId) { + return api(`/api/conversations/${convId}/messages`); +} + +/** + * 向服务器发送一条聊天消息并返回服务器响应。 + * + * @param {Object} options - 发送选项。 + * @param {number} [options.conversation_id=0] - 目标会话的 ID。 + * @param {string} [options.model='mock-mini'] - 使用的模型名称。 + * @param {string} [options.message=''] - 要发送的消息文本。 + * @returns {any} 服务器返回的已解析响应对象。 + */ +export async function sendChat({ conversation_id = 0, model = 'mock-mini', message = '' }) { + return api('/api/chat', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ conversation_id, model, message, stream: false }), + }); +} + +/** + * 以流式方式发送聊天消息并按事件逐块处理服务器返回的数据。 + * + * @param {{conversation_id?: number, model?: string, message?: string}} params - 请求参数。 + * @param {number} [params.conversation_id=0] - 初始会话 ID,若为 0 则表示新会话。 + * @param {string} [params.model='mock-mini'] - 要使用的模型名称。 + * @param {string} [params.message=''] - 要发送的消息内容。 + * @param {{onChunk?: (chunk: string) => void, onDone?: (result: {conversation_id: number}) => void}} [options] - 回调配置。 + * @param {(chunk: string) => void} [options.onChunk] - 在接收到非 `done` 事件的数据段时被调用,参数为该数据字符串。 + * @param {(result: {conversation_id: number}) => void} [options.onDone] - 在接收到 `done` 事件或流结束时被调用,参数包含最新的 `conversation_id`。 + * @returns {{conversation_id: number}} 包含发送完成后(或服务器指示完成时)最终的会话 ID。 + */ +export async function chatStream({ conversation_id = 0, model = 'mock-mini', message = '' }, { onChunk, onDone } = {}) { + const res = await fetch('/api/chat?stream=1', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${getToken()}`, + }, + body: JSON.stringify({ conversation_id, model, message, stream: true }), + }); + if (!res.ok || !res.body) { + let t = await res.text().catch(() => ''); + try { const j = JSON.parse(t); t = j.error || t; } catch {} + throw new Error(t || `${res.status} ${res.statusText}`); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let convId = conversation_id; + + const flushEvents = () => { + let idx; + while ((idx = buffer.indexOf('\n\n')) >= 0) { + const evt = buffer.slice(0, idx); + buffer = buffer.slice(idx + 2); + const lines = evt.split('\n'); + let eventName = 'message'; + const dataLines = []; + for (const line of lines) { + if (line.startsWith('event:')) eventName = line.slice(6).trim(); + else if (line.startsWith('data:')) dataLines.push(line.slice(5).replace(/^\s*/, '')); + } + const data = dataLines.join('\n'); + if (eventName === 'done') { + try { + const obj = JSON.parse(data); + if (obj.conversation_id) convId = obj.conversation_id; + } catch {} + if (onDone) onDone({ conversation_id: convId }); + } else { + if (onChunk) onChunk(data); + } + } + }; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + flushEvents(); + } + // flush any remainder + flushEvents(); + return { conversation_id: convId }; +} + +export const AllowedModelsByRole = { + free: ['mock-mini', 'gpt-4o-mini'], + pro: ['mock-mini', 'mock-pro', 'gpt-4o-mini', 'gpt-4o'], + admin: ['mock-mini', 'mock-pro', 'mock-admin', 'gpt-4o-mini', 'gpt-4o', 'gpt-4.1'], +}; + +/** + * 检查指定角色是否允许使用某个模型。 + * @param {string} role - 用户角色名称(例如 'free'、'pro'、'admin')。 + * @param {string|undefined|null} model - 要验证的模型名称;若为假值则视为允许。 + * @returns {boolean} `true` 如果 model 为假值 或者该角色的允许列表包含该模型,`false` 否则. + */ +export function roleAllowsModel(role, model) { + if (!model) return true; + const list = AllowedModelsByRole[role] || []; + return list.includes(model); +} \ No newline at end of file diff --git a/KimmyXYC/web/js/auth.js b/KimmyXYC/web/js/auth.js new file mode 100644 index 0000000..1c3d9d6 --- /dev/null +++ b/KimmyXYC/web/js/auth.js @@ -0,0 +1,51 @@ +import { setToken, setUser } from './state.js'; +import { login, register, me } from './api.js'; + +/** + * 初始化登录与注册表单的交互行为并在成功认证后更新本地认证状态。 + * + * 在页面中查找具有特定 id 的登录/注册表单与相关输入和错误显示元素,绑定提交事件:提交时调用相应的 API(login 或 register),保存返回的 token,获取并存储当前用户信息,认证成功后调用可选回调。 + * @param {Object} [options] - 可选配置对象。 + * @param {Function} [options.onAuthenticated] - 在成功完成认证并更新用户状态后调用的回调函数(无参数)。 + */ +export function initAuthUI({ onAuthenticated } = {}) { + const loginForm = document.getElementById('login-form'); + const loginEmail = document.getElementById('login-email'); + const loginPassword = document.getElementById('login-password'); + const loginError = document.getElementById('login-error'); + + const regForm = document.getElementById('register-form'); + const regEmail = document.getElementById('register-email'); + const regPassword = document.getElementById('register-password'); + const regRole = document.getElementById('register-role'); + const regError = document.getElementById('register-error'); + + loginForm.addEventListener('submit', async (e) => { + e.preventDefault(); + loginError.textContent = ''; + try { + const resp = await login(loginEmail.value.trim(), loginPassword.value); + setToken(resp.token); + // Get me to store role/email consistently + const profile = await me(); + setUser({ email: profile.user_email, role: profile.user_role, id: profile.user_id }); + if (onAuthenticated) onAuthenticated(); + } catch (err) { + loginError.textContent = err.message || '登录失败'; + } + }); + + regForm.addEventListener('submit', async (e) => { + e.preventDefault(); + regError.textContent = ''; + try { + const resp = await register(regEmail.value.trim(), regPassword.value, regRole.value); + setToken(resp.token); + const profile = await me(); + setUser({ email: profile.user_email, role: profile.user_role, id: profile.user_id }); + if (onAuthenticated) onAuthenticated(); + } catch (err) { + regError.textContent = err.message || '注册失败'; + } + }); +} \ No newline at end of file diff --git a/KimmyXYC/web/js/chat.js b/KimmyXYC/web/js/chat.js new file mode 100644 index 0000000..8c6faa4 --- /dev/null +++ b/KimmyXYC/web/js/chat.js @@ -0,0 +1,150 @@ +import { chatStream, sendChat, listConversations, getMessages, roleAllowsModel } from './api.js'; +import { getUser } from './state.js'; + +/** + * 初始化聊天界面:展示用户信息、加载会话列表与消息,并绑定发送、创建会话和模型选择等交互。 + * + * 初始化后会: + * - 在界面上显示当前用户的邮箱与角色并根据角色提示模型使用权限; + * - 加载并渲染会话列表,支持切换会话以加载对应消息; + * - 支持新建会话(首次发送时由后端创建)并清空消息视图; + * - 处理消息发送,支持非流式一次性响应和流式增量渲染助理回复; + * - 在发送失败时将错误以助理消息形式显示在聊天窗口。 + */ +export function initChatUI() { + const userInfoEmail = document.getElementById('user-email'); + const userInfoRole = document.getElementById('user-role'); + const modelSelect = document.getElementById('model-select'); + const streamToggle = document.getElementById('stream-toggle'); + const convList = document.getElementById('conv-list'); + const newChatBtn = document.getElementById('new-chat-btn'); + const messagesEl = document.getElementById('messages'); + const chatForm = document.getElementById('chat-form'); + const chatInput = document.getElementById('chat-input'); + const sendHint = document.getElementById('send-hint'); + + const user = getUser(); + userInfoEmail.textContent = user?.email || ''; + userInfoRole.textContent = user?.role || 'free'; + + let currentConv = 0; + let sending = false; + + const updateModelHint = () => { + const model = modelSelect.value; + const allowed = roleAllowsModel(user?.role || 'free', model); + if (!allowed) { + sendHint.textContent = `当前角色无权使用 ${model},尝试发送将被后端拒绝`; + } else { + sendHint.textContent = ''; + } + }; + modelSelect.addEventListener('change', updateModelHint); + updateModelHint(); + + const scrollToBottom = () => { + messagesEl.scrollTop = messagesEl.scrollHeight; + }; + + const fmtTime = (iso) => { + try { return new Date(iso).toLocaleString(); } catch { return ''; } + }; + + const renderMessage = (m) => { + const div = document.createElement('div'); + div.className = `message ${m.role}`; + const meta = document.createElement('div'); + meta.className = 'meta'; + meta.textContent = `${m.role}`; + const content = document.createElement('div'); + content.className = 'content'; + content.textContent = m.content || ''; + div.appendChild(meta); + div.appendChild(content); + messagesEl.appendChild(div); + scrollToBottom(); + return content; // return content node for streaming update + }; + + const clearMessages = () => { messagesEl.innerHTML = ''; }; + + async function loadConversations() { + convList.innerHTML = ''; + try { + const resp = await listConversations(); + const convs = resp.conversations || []; + for (const c of convs) { + const li = document.createElement('li'); + li.dataset.id = c.id; + li.className = (c.id === currentConv) ? 'active' : ''; + const title = c.title || `对话 #${c.id}`; + li.innerHTML = `
${title}
${c.model || ''}`; + li.addEventListener('click', async () => { + currentConv = c.id; + document.querySelectorAll('#conv-list li').forEach(x => x.classList.remove('active')); + li.classList.add('active'); + await loadMessages(c.id); + }); + convList.appendChild(li); + } + } catch (err) { + console.error('加载会话失败', err); + } + } + + async function loadMessages(convId) { + clearMessages(); + if (!convId) return; + try { + const resp = await getMessages(convId); + const msgs = resp.messages || []; + for (const m of msgs) renderMessage(m); + } catch (err) { + console.error('加载消息失败', err); + } + } + + newChatBtn.addEventListener('click', () => { + currentConv = 0; // backend will create on first send + clearMessages(); + chatInput.focus(); + }); + + chatForm.addEventListener('submit', async (e) => { + e.preventDefault(); + if (sending) return; + const text = chatInput.value.trim(); + if (!text) return; + sending = true; + try { + const model = modelSelect.value || 'mock-mini'; + // render user message immediately + renderMessage({ role: 'user', content: text }); + chatInput.value = ''; + + const doStream = streamToggle.checked; + if (!doStream) { + const r = await sendChat({ conversation_id: currentConv, model, message: text }); + currentConv = r.conversation_id || currentConv; + renderMessage({ role: 'assistant', content: r.reply || '' }); + } else { + let assistantNode = renderMessage({ role: 'assistant', content: '' }); + await chatStream( + { conversation_id: currentConv, model, message: text }, + { + onChunk: (chunk) => { assistantNode.textContent += chunk; scrollToBottom(); }, + onDone: ({ conversation_id }) => { if (conversation_id) currentConv = conversation_id; }, + } + ); + } + await loadConversations(); + } catch (err) { + renderMessage({ role: 'assistant', content: `错误:${err.message || err}` }); + } finally { + sending = false; + } + }); + + // Load initial + loadConversations(); +} \ No newline at end of file diff --git a/KimmyXYC/web/js/main.js b/KimmyXYC/web/js/main.js new file mode 100644 index 0000000..200d7ff --- /dev/null +++ b/KimmyXYC/web/js/main.js @@ -0,0 +1,65 @@ +import { isLoggedIn, clearToken, clearUser, setUser } from './state.js'; +import { me } from './api.js'; +import { initAuthUI } from './auth.js'; +import { initChatUI } from './chat.js'; + +/** + * 隐藏所有具有 `view` 类的元素并显示指定 id 的视图元素。 + * @param {string} id - 要显示的视图元素的 DOM id。 */ +function show(id) { + document.querySelectorAll('.view').forEach(v => v.classList.add('hidden')); + document.getElementById(id).classList.remove('hidden'); +} + +/** + * 显示主应用视图并初始化聊天界面。 + * + * 调用后页面切换至应用主视图并启动聊天相关的 UI 初始化流程。 + */ +async function enterApp() { + show('app-view'); + initChatUI(); +} + +/** + * 显示认证视图并初始化认证界面;认证成功后切换到主应用界面。 + */ +async function enterAuth() { + show('auth-view'); + initAuthUI({ onAuthenticated: enterApp }); +} + +/** + * 初始化并引导应用:注册登出处理器,检查会话并根据令牌状态切换到认证界面或主界面。 + * + * 如果存在登出按钮,注册其点击处理以清除会话并刷新页面;如果没有登录则进入认证流程;若已登录则验证令牌、设置当前用户并进入主界面;在令牌验证失败时清除会话并重新进入认证流程。 + */ +async function bootstrap() { + const logoutBtn = document.getElementById('logout-btn'); + if (logoutBtn) { + logoutBtn.addEventListener('click', () => { + clearToken(); + clearUser(); + location.reload(); + }); + } + + if (!isLoggedIn()) { + await enterAuth(); + return; + } + + // Validate token and fetch profile + try { + const profile = await me(); + setUser({ email: profile.user_email, role: profile.user_role, id: profile.user_id }); + await enterApp(); + } catch (err) { + console.warn('Token invalid, returning to auth', err); + clearToken(); + clearUser(); + await enterAuth(); + } +} + +window.addEventListener('DOMContentLoaded', bootstrap); \ No newline at end of file diff --git a/KimmyXYC/web/js/state.js b/KimmyXYC/web/js/state.js new file mode 100644 index 0000000..2624742 --- /dev/null +++ b/KimmyXYC/web/js/state.js @@ -0,0 +1,61 @@ +const TOKEN_KEY = 'aib_token'; +const USER_KEY = 'aib_user'; + +/** + * 获取存储的认证令牌字符串。 + * @returns {string} 存储的令牌字符串;若不存在则返回空字符串。 + */ +export function getToken() { + return localStorage.getItem(TOKEN_KEY) || ''; +} + +/** + * 将提供的令牌保存到 localStorage 中(仅在令牌为真值时)。 + * @param {string} t - 要保存的令牌;若为假值(例如 `''`, `null`, `undefined`)则不执行任何操作。 + */ +export function setToken(t) { + if (t) localStorage.setItem(TOKEN_KEY, t); +} + +/** + * 从 localStorage 中移除用于存储认证 token 的项。 + */ +export function clearToken() { + localStorage.removeItem(TOKEN_KEY); +} + +/** + * 从 localStorage 中读取并解析用户数据。 + * + * 如果未找到数据或数据不是有效的 JSON,则返回 null。 + * @returns {Object|null} 解析后的用户对象;当不存在或解析失败时返回 `null`。 + */ +export function getUser() { + const raw = localStorage.getItem(USER_KEY); + if (!raw) return null; + try { return JSON.parse(raw); } catch { return null; } +} + +/** + * 将用户对象序列化并存储到 localStorage 中(仅当提供非空值时)。 + * @param {Object} u - 要持久化的用户数据对象;如果为 falsy(例如 null 或 undefined),则不执行任何操作。 + */ +export function setUser(u) { + if (u) localStorage.setItem(USER_KEY, JSON.stringify(u)); +} + +/** + * 从 localStorage 中删除与 USER_KEY 对应的用户数据。 + */ +export function clearUser() { + localStorage.removeItem(USER_KEY); +} + +/** + * 判断当前是否已登录(基于本地存储中是否存在认证 token)。 + * + * @returns {boolean} `true` if a token exists in storage, `false` otherwise. + */ +export function isLoggedIn() { + return !!getToken(); +} \ No newline at end of file