diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ff1c1b8865..acdd0d18b2 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -716,6 +716,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -755,31 +756,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -788,32 +789,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_model", @@ -828,17 +829,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 652adcac70..ff58fa9eb2 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -18239,6 +18239,7 @@ type UsageLogMutation struct { id *int64 request_id *string model *string + upstream_model *string input_tokens *int addinput_tokens *int output_tokens *int @@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() { m.model = nil } +// SetUpstreamModel sets the "upstream_model" field. +func (m *UsageLogMutation) SetUpstreamModel(s string) { + m.upstream_model = &s +} + +// UpstreamModel returns the value of the "upstream_model" field in the mutation. +func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) { + v := m.upstream_model + if v == nil { + return + } + return *v, true +} + +// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err) + } + return oldValue.UpstreamModel, nil +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (m *UsageLogMutation) ClearUpstreamModel() { + m.upstream_model = nil + m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{} +} + +// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation. +func (m *UsageLogMutation) UpstreamModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldUpstreamModel] + return ok +} + +// ResetUpstreamModel resets all changes to the "upstream_model" field. +func (m *UsageLogMutation) ResetUpstreamModel() { + m.upstream_model = nil + delete(m.clearedFields, usagelog.FieldUpstreamModel) +} + // SetGroupID sets the "group_id" field. func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 33) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string { if m.model != nil { fields = append(fields, usagelog.FieldModel) } + if m.upstream_model != nil { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case usagelog.FieldModel: return m.Model() + case usagelog.FieldUpstreamModel: + return m.UpstreamModel() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestID(ctx) case usagelog.FieldModel: return m.OldModel(ctx) + case usagelog.FieldUpstreamModel: + return m.OldUpstreamModel(ctx) case usagelog.FieldGroupID: return m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetModel(v) return nil + case usagelog.FieldUpstreamModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamModel(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { // mutation. func (m *UsageLogMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usagelog.FieldUpstreamModel) { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UsageLogMutation) ClearField(name string) error { switch name { + case usagelog.FieldUpstreamModel: + m.ClearUpstreamModel() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldModel: m.ResetModel() return nil + case usagelog.FieldUpstreamModel: + m.ResetUpstreamModel() + return nil case usagelog.FieldGroupID: m.ResetGroupID() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index b8facf362b..2401e5538b 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -821,92 +821,96 @@ func init() { return nil } }() + // usagelogDescUpstreamModel is the schema descriptor for upstream_model field. + usagelogDescUpstreamModel := usagelogFields[5].Descriptor() + // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[7].Descriptor() + usagelogDescInputTokens := usagelogFields[8].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[8].Descriptor() + usagelogDescOutputTokens := usagelogFields[9].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[13].Descriptor() + usagelogDescInputCost := usagelogFields[14].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[14].Descriptor() + usagelogDescOutputCost := usagelogFields[15].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + usagelogDescCacheReadCost := usagelogFields[17].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[17].Descriptor() + usagelogDescTotalCost := usagelogFields[18].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[18].Descriptor() + usagelogDescActualCost := usagelogFields[19].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + usagelogDescRateMultiplier := usagelogFields[20].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[21].Descriptor() + usagelogDescBillingType := usagelogFields[22].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[22].Descriptor() + usagelogDescStream := usagelogFields[23].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[25].Descriptor() + usagelogDescUserAgent := usagelogFields[26].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[26].Descriptor() + usagelogDescIPAddress := usagelogFields[27].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[27].Descriptor() + usagelogDescImageCount := usagelogFields[28].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[28].Descriptor() + usagelogDescImageSize := usagelogFields[29].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[29].Descriptor() + usagelogDescMediaType := usagelogFields[30].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[31].Descriptor() + usagelogDescCreatedAt := usagelogFields[32].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index dcca1a0ad2..f66ea34fed 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field { field.String("model"). MaxLen(100). NotEmpty(), + // upstream_model: 实际发往上游的模型名(经过模型映射后)。 + // NULL 表示无映射(与 model 相同)。 + field.String("upstream_model"). + MaxLen(100). + Optional(). + Nillable(), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index f6968d0d97..014851c99e 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -32,6 +32,8 @@ type UsageLog struct { RequestID string `json:"request_id,omitempty"` // Model holds the value of the "model" field. Model string `json:"model,omitempty"` + // UpstreamModel holds the value of the "upstream_model" field. + UpstreamModel *string `json:"upstream_model,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. @@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Model = value.String } + case usagelog.FieldUpstreamModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) + } else if value.Valid { + _m.UpstreamModel = new(string) + *_m.UpstreamModel = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -477,6 +486,11 @@ func (_m *UsageLog) String() string { builder.WriteString("model=") builder.WriteString(_m.Model) builder.WriteString(", ") + if v := _m.UpstreamModel; v != nil { + builder.WriteString("upstream_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index ba97b84376..789407e71f 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -24,6 +24,8 @@ const ( FieldRequestID = "request_id" // FieldModel holds the string denoting the model field in the database. FieldModel = "model" + // FieldUpstreamModel holds the string denoting the upstream_model field in the database. + FieldUpstreamModel = "upstream_model" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -135,6 +137,7 @@ var Columns = []string{ FieldAccountID, FieldRequestID, FieldModel, + FieldUpstreamModel, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -179,6 +182,8 @@ var ( RequestIDValidator func(string) error // ModelValidator is a validator for the "model" field. It is called by the builders before save. ModelValidator func(string) error + // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + UpstreamModelValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModel, opts...).ToFunc() } +// ByUpstreamModel orders the results by the upstream_model field. +func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index af96033559..5f341976e9 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) } +// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. +func UpstreamModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) } +// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. +func UpstreamModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field. +func UpstreamModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelIn applies the In predicate on the "upstream_model" field. +func UpstreamModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field. +func UpstreamModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelGT applies the GT predicate on the "upstream_model" field. +func UpstreamModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v)) +} + +// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field. +func UpstreamModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v)) +} + +// UpstreamModelLT applies the LT predicate on the "upstream_model" field. +func UpstreamModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v)) +} + +// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field. +func UpstreamModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v)) +} + +// UpstreamModelContains applies the Contains predicate on the "upstream_model" field. +func UpstreamModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v)) +} + +// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field. +func UpstreamModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v)) +} + +// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field. +func UpstreamModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v)) +} + +// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field. +func UpstreamModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel)) +} + +// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field. +func UpstreamModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel)) +} + +// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field. +func UpstreamModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v)) +} + +// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field. +func UpstreamModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index e0285a5ef2..26be5dcb2d 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { return _c } +// SetUpstreamModel sets the "upstream_model" field. +func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { + _c.mutation.SetUpstreamModel(v) + return _c +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetUpstreamModel(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _c.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldModel, field.TypeString, value) _node.Model = value } + if value, ok := _c.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + _node.UpstreamModel = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { return u } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldUpstreamModel, v) + return u +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUpstreamModel) + return u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldUpstreamModel) + return u +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index b46e5b56e5..b7c4632c10 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 205ccd65b8..28fa66a6b9 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -550,6 +550,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { } return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), + UpstreamModel: l.UpstreamModel, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index d9ccda2d1f..fd650cb263 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -377,6 +377,9 @@ type UsageLog struct { type AdminUsageLog struct { UsageLog + // UpstreamModel 实际发往上游的模型名(仅管理员可见,nil 表示与 Model 相同) + UpstreamModel *string `json:"upstream_model,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e514..be047570db 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -108,6 +108,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -140,12 +141,12 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -167,12 +168,15 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) requestIDArg = requestID } + upstreamModel := nullString(log.UpstreamModel) + args := []any{ log.UserID, log.APIKeyID, log.AccountID, requestIDArg, log.Model, + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -2481,6 +2485,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -2521,6 +2526,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -2625,6 +2631,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 7d82b4d0c7..c020b74fd3 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id log.InputTokens, @@ -112,6 +113,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.AccountID, log.RequestID, log.Model, + log.UpstreamModel, sqlmock.AnyArg(), sqlmock.AnyArg(), log.InputTokens, @@ -318,6 +320,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, "gpt-5", // model + sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id 1, // input_tokens @@ -367,6 +370,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: true, String: "req-2"}, "gpt-5", + sql.NullString{}, // upstream_model sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, @@ -406,6 +410,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{}, // upstream_model sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2193bc0400..feb70227c5 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1692,15 +1692,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, firstTokenMs = streamRes.firstTokenMs } - return &ForwardResult{ + fr := &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, // 使用映射模型用于计费和日志 + Model: originalModel, Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, - }, nil + } + if billingModel != originalModel { + fr.UpstreamModel = billingModel + } + return fr, nil } func isSignatureRelatedError(respBody []byte) bool { @@ -2276,17 +2280,21 @@ handleSuccess: imageCount = 1 } - return &ForwardResult{ + fr := &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, + Model: originalModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, ImageCount: imageCount, ImageSize: imageSize, - }, nil + } + if billingModel != originalModel { + fr.UpstreamModel = billingModel + } + return fr, nil } func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool { diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 6096383896..4a539d51cf 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -501,7 +501,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { result, err := svc.Forward(context.Background(), c, account, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "claude-sonnet-4-5", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel @@ -553,7 +554,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } // TestStreamUpstreamResponse_UsageAndFirstToken diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5dcda1de43..8d4c99243c 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -751,7 +751,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc rateLimitService: &RateLimitService{}, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 12, result.Usage.InputTokens) @@ -778,7 +778,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp } svc := &GatewayService{} - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "requires apikey token") @@ -803,7 +803,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest } account := newAnthropicAPIKeyAccountForTest() - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "upstream request failed") @@ -836,7 +836,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo httpUpstream: upstream, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "empty response") diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 080de063df..1a42132357 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -456,7 +456,8 @@ type ClaudeUsage struct { type ForwardResult struct { RequestID string Usage ClaudeUsage - Model string + Model string // 用户请求的原始模型名 + UpstreamModel string // 实际发往上游的模型名(空字符串表示与 Model 相同) Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -3937,6 +3938,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { passthroughBody := parsed.Body + passthroughOriginalModel := parsed.Model passthroughModel := parsed.Model if passthroughModel != "" { if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { @@ -3945,7 +3947,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A passthroughModel = mappedModel } } - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughOriginalModel, passthroughModel, parsed.Stream, startTime) } // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. @@ -4453,7 +4455,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } } - return &ForwardResult{ + fr := &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 @@ -4461,7 +4463,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, - }, nil + } + if reqModel != originalModel { + fr.UpstreamModel = reqModel + } + return fr, nil } func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( @@ -4469,6 +4475,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( c *gin.Context, account *Account, body []byte, + originalModel string, reqModel string, reqStream bool, startTime time.Time, @@ -4673,15 +4680,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( usage = &ClaudeUsage{} } - return &ForwardResult{ + fr := &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, - Model: reqModel, + Model: originalModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, - }, nil + } + if reqModel != originalModel { + fr.UpstreamModel = reqModel + } + return fr, nil } func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( @@ -6712,7 +6723,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else if result.MediaType == "prompt" { cost = &CostBreakdown{} } else if result.ImageCount > 0 { - // 图片生成计费 + // 图片生成计费:使用上游模型名(映射后的模型名)查找价格 + billingModel := result.Model + if result.UpstreamModel != "" { + billingModel = result.UpstreamModel + } var groupConfig *ImagePriceConfig if apiKey.Group != nil { groupConfig = &ImagePriceConfig{ @@ -6721,9 +6736,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { - // Token 计费 + // Token 计费:使用上游模型名(映射后的模型名)查找价格 + billingModel := result.Model + if result.UpstreamModel != "" { + billingModel = result.UpstreamModel + } tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6733,7 +6752,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -6758,12 +6777,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() + var upstreamModel *string + if result.UpstreamModel != "" { + upstreamModel = &result.UpstreamModel + } usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, + UpstreamModel: upstreamModel, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, @@ -6889,6 +6913,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * var cost *CostBreakdown + // 使用上游模型名(映射后的模型名)查找价格 + billingModelLC := result.Model + if result.UpstreamModel != "" { + billingModelLC = result.UpstreamModel + } + // 根据请求类型选择计费方式 if result.ImageCount > 0 { // 图片生成计费 @@ -6900,7 +6930,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModelLC, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ @@ -6912,7 +6942,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + cost, err = s.billingService.CalculateCostWithLongContext(billingModelLC, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -6933,12 +6963,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() + var upstreamModelLC *string + if result.UpstreamModel != "" { + upstreamModelLC = &result.UpstreamModel + } usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, + UpstreamModel: upstreamModelLC, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 54068f2b2f..913ca1fa91 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -240,7 +240,8 @@ type OpenAIForwardResult struct { BillingModel string // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // Nil means the request did not specify a recognized tier. - ServiceTier *string + ServiceTier *string + UpstreamModel string // 实际发往上游的模型名(空字符串表示与 Model 相同) // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. ReasoningEffort *string @@ -2162,7 +2163,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) serviceTier := extractOpenAIServiceTier(reqBody) - return &OpenAIForwardResult{ + fr := &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, @@ -2172,7 +2173,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if mappedModel != originalModel { + fr.UpstreamModel = mappedModel + } + return fr, nil } func (s *OpenAIGatewayService) forwardOpenAIPassthrough( @@ -3765,6 +3770,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec billingModel := result.Model if result.BillingModel != "" { billingModel = result.BillingModel + } else if result.UpstreamModel != "" { + billingModel = result.UpstreamModel } serviceTier := "" if result.ServiceTier != nil { @@ -3785,6 +3792,10 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + var upstreamModel *string + if result.UpstreamModel != "" { + upstreamModel = &result.UpstreamModel + } usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, @@ -3792,6 +3803,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec RequestID: result.RequestID, Model: billingModel, ServiceTier: result.ServiceTier, + UpstreamModel: upstreamModel, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 52bb8590d0..c69628d885 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2303,7 +2303,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( clientDisconnected, ) - return &OpenAIForwardResult{ + fr := &OpenAIForwardResult{ RequestID: responseID, Usage: *usage, Model: originalModel, @@ -2314,7 +2314,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if mappedModel != originalModel { + fr.UpstreamModel = mappedModel + } + return fr, nil } // ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 @@ -2920,7 +2924,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( clientDisconnected, ) } - return &OpenAIForwardResult{ + fr := &OpenAIForwardResult{ RequestID: responseID, Usage: usage, Model: originalModel, @@ -2931,7 +2935,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ResponseHeaders: lease.HandshakeHeaders(), Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, - }, nil + } + if mappedModel != "" && mappedModel != originalModel { + fr.UpstreamModel = mappedModel + } + return fr, nil } } } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index a7464956bb..af12d203eb 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -99,7 +99,8 @@ type UsageLog struct { RequestID string Model string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". - ServiceTier *string + ServiceTier *string + UpstreamModel *string // 实际发往上游的模型名(nil 表示与 Model 相同,无映射) // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. ReasoningEffort *string diff --git a/backend/migrations/071_add_upstream_model.sql b/backend/migrations/071_add_upstream_model.sql new file mode 100644 index 0000000000..60c0cbbcf9 --- /dev/null +++ b/backend/migrations/071_add_upstream_model.sql @@ -0,0 +1 @@ +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 72f7c01057..178e0d4d05 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -25,8 +25,13 @@ {{ row.account?.name || '-' }} -