Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/gateway/image_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func (h *ImagesHandler) dispatchImageToChannel(c *gin.Context,
rec.Status = usage.StatusSuccess
rec.ModelID = m.ID
rec.CreditCost = finalCost
rec.ImageCount = actualCount(result)

c.JSON(http.StatusOK, ImageGenResponse{
Created: time.Now().Unix(),
Expand Down
19 changes: 11 additions & 8 deletions internal/gateway/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ func (h *ImagesHandler) ImageGenerations(c *gin.Context) {
// 7) usage
rec.Status = usage.StatusSuccess
rec.CreditCost = cost
rec.ImageCount = len(res.SignedURLs)

// 8) DAO 回写 credit_cost(Runner 已经 MarkSuccess,这里只补 credit_cost)
if h.DAO != nil {
Expand Down Expand Up @@ -363,14 +364,14 @@ func (h *ImagesHandler) ImageTask(c *gin.Context) {
}

c.JSON(http.StatusOK, gin.H{
"task_id": t.TaskID,
"status": t.Status,
"conversation_id": t.ConversationID,
"created": t.CreatedAt.Unix(),
"finished_at": nullableUnix(t.FinishedAt),
"error": t.Error,
"credit_cost": t.CreditCost,
"data": data,
"task_id": t.TaskID,
"status": t.Status,
"conversation_id": t.ConversationID,
"created": t.CreatedAt.Unix(),
"finished_at": nullableUnix(t.FinishedAt),
"error": t.Error,
"credit_cost": t.CreditCost,
"data": data,
})
}

Expand Down Expand Up @@ -496,6 +497,7 @@ func (h *ImagesHandler) handleChatAsImage(c *gin.Context, rec *usage.Log, ak *ap

rec.Status = usage.StatusSuccess
rec.CreditCost = cost
rec.ImageCount = len(res.SignedURLs)
rec.DurationMs = int(time.Since(startAt).Milliseconds())

// 以 chat 响应返回(content 里内嵌 markdown 图片)。
Expand Down Expand Up @@ -809,6 +811,7 @@ func (h *ImagesHandler) ImageEdits(c *gin.Context) {

rec.Status = usage.StatusSuccess
rec.CreditCost = cost
rec.ImageCount = len(res.SignedURLs)
if h.DAO != nil {
_ = h.DAO.UpdateCost(c.Request.Context(), taskID, cost)
}
Expand Down
60 changes: 22 additions & 38 deletions internal/gateway/images_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ package gateway

import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -55,46 +51,14 @@ type ImageAccountResolver interface {
ProxyURL(ctx context.Context, accountID uint64) string
}

// imageProxySecret 进程级随机密钥,用于 HMAC 签名图片 URL。
// 进程重启后旧的签名 URL 全部失效,这是故意的(防止长期有效的 URL 泄漏)。
var imageProxySecret []byte

func init() {
imageProxySecret = make([]byte, 32)
if _, err := rand.Read(imageProxySecret); err != nil {
for i := range imageProxySecret {
imageProxySecret[i] = byte(i*31 + 7)
}
}
}

// ImageProxyTTL 单条签名 URL 的默认有效期(24h,够前端离线展示一段时间)。
const ImageProxyTTL = 24 * time.Hour

// BuildImageProxyURL 生成代理 URL。返回绝对 path(不含 host),调用方可以直接拼或交给前端同 origin 使用。
//
// 默认 ttl=24h。前端展示一张历史图片,最多走一次上游获取 bytes,之后浏览器缓存即可。
func BuildImageProxyURL(taskID string, idx int, ttl time.Duration) string {
if ttl <= 0 {
ttl = ImageProxyTTL
}
expMs := time.Now().Add(ttl).UnixMilli()
sig := computeImgSig(taskID, idx, expMs)
return fmt.Sprintf("/p/img/%s/%d?exp=%d&sig=%s", taskID, idx, expMs, sig)
}

func computeImgSig(taskID string, idx int, expMs int64) string {
mac := hmac.New(sha256.New, imageProxySecret)
fmt.Fprintf(mac, "%s|%d|%d", taskID, idx, expMs)
return hex.EncodeToString(mac.Sum(nil))[:24]
}

func verifyImgSig(taskID string, idx int, expMs int64, sig string) bool {
if expMs < time.Now().UnixMilli() {
return false
}
want := computeImgSig(taskID, idx, expMs)
return hmac.Equal([]byte(sig), []byte(want))
return image.BuildProxyURL(taskID, idx, ttl)
}

// ImageProxy 按签名代理下载上游图片。无需 API Key,只靠 URL 签名校验。
Expand All @@ -113,12 +77,17 @@ func (h *ImagesHandler) ImageProxy(c *gin.Context) {
c.AbortWithStatus(http.StatusBadRequest)
return
}
thumbKB, err := strconv.Atoi(c.DefaultQuery("thumb_kb", "0"))
if err != nil || thumbKB < 0 || thumbKB > 64 {
c.AbortWithStatus(http.StatusBadRequest)
return
}
expMs, err := strconv.ParseInt(expStr, 10, 64)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
if !verifyImgSig(taskID, idx, expMs, sig) {
if !image.VerifyImgSig(taskID, idx, expMs, sig) {
c.AbortWithStatus(http.StatusForbidden)
return
}
Expand Down Expand Up @@ -179,6 +148,9 @@ func (h *ImagesHandler) ImageProxy(c *gin.Context) {
// 按需放大:若 task 上打了 upscale 标记,先走进程内 LRU,命中则直接返回。
// 未命中再拉原图,放大成 PNG 后写入缓存。
scale := image.ValidateUpscale(t.Upscale)
if thumbKB > 0 {
scale = ""
}
cacheKey := ""
if scale != "" {
cacheKey = fmt.Sprintf("%s|%d|%s", taskID, idx, scale)
Expand All @@ -200,6 +172,18 @@ func (h *ImagesHandler) ImageProxy(c *gin.Context) {
if ct == "" {
ct = "image/png"
}
if thumbKB > 0 {
thumbBytes, thumbCT, err := image.MakeThumbJPEG(body, thumbKB*1024)
if err != nil {
logger.L().Warn("image proxy thumb",
zap.Error(err), zap.String("task_id", taskID),
zap.Int("thumb_kb", thumbKB))
} else {
body = thumbBytes
ct = thumbCT
c.Header("X-Thumb-KB", strconv.Itoa(thumbKB))
}
}

if scale != "" {
// 并发闸:避免 4K 请求风暴把 CPU 打满影响生图主流程
Expand Down
11 changes: 9 additions & 2 deletions internal/image/admin_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package image

import (
"strconv"
"time"

"github.com/gin-gonic/gin"

Expand Down Expand Up @@ -46,16 +47,22 @@ func (h *AdminHandler) List(c *gin.Context) {
return
}

// 把 result_urls JSON bytes 解成可读字符串数组后输出
// 把 result_urls JSON bytes 解成代理 URL 后输出
type rowOut struct {
AdminTaskRow
ResultURLsParsed []string `json:"result_urls_parsed"`
}
out := make([]rowOut, 0, len(rows))
for _, r := range rows {
// 生成代理 URL 而不是直接返回上游 URL
fids := r.DecodeFileIDs()
urls := make([]string, 0, len(fids))
for i := range fids {
urls = append(urls, BuildProxyURL(r.TaskID, i, 24*time.Hour))
}
out = append(out, rowOut{
AdminTaskRow: r,
ResultURLsParsed: r.DecodeResultURLs(),
ResultURLsParsed: urls,
})
}

Expand Down
33 changes: 30 additions & 3 deletions internal/image/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -111,20 +112,46 @@ SELECT id, task_id, user_id, key_id, model_id, account_id, prompt, n, size, upsc
return &t, nil
}

// MyTaskFilter 当前用户图片任务筛选条件。
type MyTaskFilter struct {
Status string
Keyword string
CreatedAt *time.Time
CreatedTo *time.Time
}

// ListByUser 按用户分页。
func (d *DAO) ListByUser(ctx context.Context, userID uint64, limit, offset int) ([]Task, error) {
func (d *DAO) ListByUser(ctx context.Context, userID uint64, f MyTaskFilter, limit, offset int) ([]Task, error) {
if limit <= 0 {
limit = 20
}
where := []string{"user_id = ?"}
args := []interface{}{userID}
if f.Status != "" {
where = append(where, "status = ?")
args = append(args, f.Status)
}
if f.Keyword != "" {
where = append(where, "prompt LIKE ?")
args = append(args, "%"+f.Keyword+"%")
}
if f.CreatedAt != nil {
where = append(where, "created_at >= ?")
args = append(args, *f.CreatedAt)
}
if f.CreatedTo != nil {
where = append(where, "created_at <= ?")
args = append(args, *f.CreatedTo)
}
var out []Task
err := d.db.SelectContext(ctx, &out, `
SELECT id, task_id, user_id, key_id, model_id, account_id, prompt, n, size, upscale, status,
conversation_id, file_ids, result_urls, error, estimated_credit, credit_cost,
created_at, started_at, finished_at
FROM image_tasks
WHERE user_id = ?
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id DESC
LIMIT ? OFFSET ?`, userID, limit, offset)
LIMIT ? OFFSET ?`, append(args, limit, offset)...)
return out, err
}

Expand Down
81 changes: 61 additions & 20 deletions internal/image/me_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,44 +22,50 @@ func NewMeHandler(dao *DAO) *MeHandler { return &MeHandler{dao: dao} }

// taskView 是对外返回的视图结构,解码 JSON 列 + 隐藏内部字段。
type taskView struct {
ID uint64 `json:"id"`
TaskID string `json:"task_id"`
UserID uint64 `json:"user_id"`
ModelID uint64 `json:"model_id"`
AccountID uint64 `json:"account_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Upscale string `json:"upscale,omitempty"`
Status string `json:"status"`
ConversationID string `json:"conversation_id,omitempty"`
Error string `json:"error,omitempty"`
CreditCost int64 `json:"credit_cost"`
ImageURLs []string `json:"image_urls"`
FileIDs []string `json:"file_ids,omitempty"`
CreatedAt time.Time `json:"created_at"`
ID uint64 `json:"id"`
TaskID string `json:"task_id"`
UserID uint64 `json:"user_id"`
ModelID uint64 `json:"model_id"`
AccountID uint64 `json:"account_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Upscale string `json:"upscale,omitempty"`
Status string `json:"status"`
ConversationID string `json:"conversation_id,omitempty"`
Error string `json:"error,omitempty"`
CreditCost int64 `json:"credit_cost"`
ImageURLs []string `json:"image_urls"`
FileIDs []string `json:"file_ids,omitempty"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
}

func toView(t *Task) taskView {
urls := t.DecodeResultURLs()
fids := t.DecodeFileIDs()
for i, id := range fids {
fids[i] = strings.TrimPrefix(id, "sed:")
}

// 生成代理 URL 而不是直接返回上游 URL(防止 403)
urls := make([]string, 0, len(fids))
for i := range fids {
urls = append(urls, BuildProxyURL(t.TaskID, i, 24*time.Hour))
}

return taskView{
ID: t.ID, TaskID: t.TaskID, UserID: t.UserID, ModelID: t.ModelID,
AccountID: t.AccountID, Prompt: t.Prompt, N: t.N, Size: t.Size,
Upscale: t.Upscale,
Status: t.Status, ConversationID: t.ConversationID, Error: t.Error,
Status: t.Status, ConversationID: t.ConversationID, Error: t.Error,
CreditCost: t.CreditCost, ImageURLs: urls, FileIDs: fids,
CreatedAt: t.CreatedAt, StartedAt: t.StartedAt, FinishedAt: t.FinishedAt,
}
}

// GET /api/me/images/tasks
// 查询参数:limit(默认 20,上限 100), offset
// 查询参数:limit(默认 20,上限 100), offset, status, keyword, start_at, end_at
func (h *MeHandler) List(c *gin.Context) {
uid := middleware.UserID(c)
if uid == 0 {
Expand All @@ -77,7 +83,27 @@ func (h *MeHandler) List(c *gin.Context) {
if offset < 0 {
offset = 0
}
tasks, err := h.dao.ListByUser(c.Request.Context(), uid, limit, offset)
filter := MyTaskFilter{
Status: strings.TrimSpace(c.Query("status")),
Keyword: strings.TrimSpace(c.Query("keyword")),
}
if startAt := strings.TrimSpace(c.Query("start_at")); startAt != "" {
tm, err := parseFilterTime(startAt)
if err != nil {
resp.Fail(c, resp.CodeBadRequest, "start_at 格式错误,期望 2006-01-02 15:04:05")
return
}
filter.CreatedAt = &tm
}
if endAt := strings.TrimSpace(c.Query("end_at")); endAt != "" {
tm, err := parseFilterTime(endAt)
if err != nil {
resp.Fail(c, resp.CodeBadRequest, "end_at 格式错误,期望 2006-01-02 15:04:05")
return
}
filter.CreatedTo = &tm
}
tasks, err := h.dao.ListByUser(c.Request.Context(), uid, filter, limit, offset)
if err != nil {
resp.Internal(c, err.Error())
return
Expand All @@ -89,6 +115,21 @@ func (h *MeHandler) List(c *gin.Context) {
resp.OK(c, gin.H{"items": items, "limit": limit, "offset": offset})
}

func parseFilterTime(s string) (time.Time, error) {
loc := time.Local
layouts := []string{
"2006-01-02 15:04:05",
time.RFC3339,
"2006-01-02",
}
for _, layout := range layouts {
if t, err := time.ParseInLocation(layout, s, loc); err == nil {
return t, nil
}
}
return time.Time{}, errors.New("invalid time")
}

// GET /api/me/images/tasks/:id
func (h *MeHandler) Get(c *gin.Context) {
uid := middleware.UserID(c)
Expand Down
Loading