Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions internal/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions internal/database/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ LIMIT 1;
-- name: GetAllProjects :many
SELECT projects."owner", projects."project_handle"
FROM projects
ORDER BY "owner" ASC, "project_handle" ASC;
ORDER BY "owner" ASC, "project_handle" ASC LIMIT $1 OFFSET $2;

-- name: CountAllProjects :one
SELECT COUNT(*)
Expand Down Expand Up @@ -279,7 +279,7 @@ ON definitions."definition_id" = definitions_shared_with."definition_id"
WHERE definitions."owner" = $1
AND definitions."definition_handle" = $2
AND definitions_shared_with."user_handle" != '*'
ORDER BY "user_handle" ASC;
ORDER BY "user_handle" ASC LIMIT $3 OFFSET $4;

-- name: GetAccessibleDefinitionsByUser :many
-- Get all definitions accessible to a user (owned + shared + _system)
Expand Down Expand Up @@ -490,7 +490,7 @@ JOIN instances
ON instances."instance_id" = instances_shared_with."instance_id"
WHERE instances."owner" = $1
AND instances."instance_handle" = $2
ORDER BY "user_handle" ASC;
ORDER BY "user_handle" ASC LIMIT $3 OFFSET $4;

-- name: RetrieveSharedInstance :one
-- Get single instance, but only if it is shared with requesting user
Expand Down
6 changes: 5 additions & 1 deletion internal/handlers/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ func sanityCheckFunc(ctx context.Context, input *models.SanityCheckRequest) (*mo
queries := database.New(pool)

// Get all projects with their metadata schemes
projects, err := queries.GetAllProjects(ctx)
// Use maxAdminQueryLimit since sanity check needs to validate all projects
projects, err := queries.GetAllProjects(ctx, database.GetAllProjectsParams{
Limit: maxAdminQueryLimit,
Offset: 0,
})
if err != nil {
return nil, huma.Error500InternalServerError(fmt.Sprintf("unable to get projects. %v", err))
}
Expand Down
17 changes: 13 additions & 4 deletions internal/handlers/llm_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,12 @@ func getDefinitionSharedUsersFunc(ctx context.Context, input *models.GetDefiniti
}
queries := database.New(pool)

// Get shared users
// Get shared users - use input parameters for user-facing pagination
sharedUsers, err := queries.GetSharedUsersForDefinition(ctx, database.GetSharedUsersForDefinitionParams{
Owner: input.UserHandle,
DefinitionHandle: input.DefinitionHandle,
Limit: int32(input.Limit),
Offset: int32(input.Offset),
})
if err != nil {
if err.Error() == "no rows in result set" {
Expand Down Expand Up @@ -539,10 +541,12 @@ func postInstanceFromDefinitionFunc(ctx context.Context, input *models.PostInsta
// Check if user has access to the definition (either owner or shared)
if !definition.IsPublic && definition.Owner != ctx.Value(auth.AuthUserKey).(string) {
hasAccess := false
// Check if shared with user
// Check if shared with user - use maxSharedUsersPerQuery for authorization check
sharedUsers, err := queries.GetSharedUsersForDefinition(ctx, database.GetSharedUsersForDefinitionParams{
Owner: input.Body.DefinitionOwner,
DefinitionHandle: input.Body.DefinitionHandle,
Limit: maxSharedUsersPerQuery,
Offset: 0,
})
if err != nil && err.Error() != "no rows in result set" {
return huma.Error500InternalServerError(fmt.Sprintf("unable to retrieve shared users for definition %s/%s: %v", input.Body.DefinitionOwner, input.Body.DefinitionHandle, err))
Expand Down Expand Up @@ -648,7 +652,7 @@ func getInstanceFunc(ctx context.Context, input *models.GetInstanceRequest) (*mo
if authUserHandle, ok := ctx.Value(auth.AuthUserKey).(string); ok {
acessibleInstances, err := queries.GetAccessibleInstancesByUser(ctx, database.GetAccessibleInstancesByUserParams{
Owner: authUserHandle,
Limit: 999,
Limit: maxSharedUsersPerQuery,
Offset: 0,
})
if err != nil && err != pgx.ErrNoRows {
Expand Down Expand Up @@ -877,9 +881,12 @@ func unshareInstanceFunc(ctx context.Context, input *models.UnshareInstanceReque
}

// Check if target user exists and is currently shared
// Use maxSharedUsersPerQuery for internal check to ensure we scan all shared users
sharedUsers, err := queries.GetSharedUsersForInstance(ctx, database.GetSharedUsersForInstanceParams{
Owner: input.UserHandle,
InstanceHandle: input.InstanceHandle,
Limit: maxSharedUsersPerQuery,
Offset: 0,
})
if err != nil {
if err.Error() == "no rows in result set" {
Expand Down Expand Up @@ -915,10 +922,12 @@ func getInstanceSharedUsersFunc(ctx context.Context, input *models.GetInstanceSh
}
queries := database.New(pool)

// Get shared users
// Get shared users - use input parameters for user-facing pagination
sharedUsers, err := queries.GetSharedUsersForInstance(ctx, database.GetSharedUsersForInstanceParams{
Owner: input.UserHandle,
InstanceHandle: input.InstanceHandle,
Limit: int32(input.Limit),
Offset: int32(input.Offset),
})
if err != nil {
if err.Error() == "no rows in result set" {
Expand Down
72 changes: 72 additions & 0 deletions internal/handlers/pagination_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package handlers_test

import (
"encoding/json"
"fmt"
"io"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestPaginationForGetUsers(t *testing.T) {
// Get the database connection pool from package variable
pool := connPool

// Create a mock key generator
mockKeyGen := new(MockKeyGen)
mockKeyGen.On("RandomKey", 32).Return("12345678901234567890123456789012", nil).Maybe()

// Start the server
err, shutDownServer := startTestServer(t, pool, mockKeyGen)
assert.NoError(t, err)
defer shutDownServer()

fmt.Printf("\nRunning pagination tests ...\n\n")

// Test pagination: limit=1, offset=0 (should get first user)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d/v1/users?limit=1&offset=0", options.Port), nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+options.AdminKey)

client := &http.Client{}
resp, err := client.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode)

bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)

var userList []string
err = json.Unmarshal(bodyBytes, &userList)
assert.NoError(t, err)
assert.Equal(t, 1, len(userList), "Expected exactly 1 user with limit=1")

fmt.Printf("First page (limit=1, offset=0): %v\n", userList)

// Test getting all users with high limit
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d/v1/users?limit=100", options.Port), nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+options.AdminKey)

resp3, err := client.Do(req)
assert.NoError(t, err)
defer resp3.Body.Close()

assert.Equal(t, http.StatusOK, resp3.StatusCode)

bodyBytes, err = io.ReadAll(resp3.Body)
assert.NoError(t, err)

err = json.Unmarshal(bodyBytes, &userList)
assert.NoError(t, err)
// We should have at least the _system user
assert.GreaterOrEqual(t, len(userList), 1, "Expected at least 1 user")

fmt.Printf("All users (limit=100): %d users found\n", len(userList))

fmt.Printf("\nPagination tests completed successfully!\n\n")
}
8 changes: 6 additions & 2 deletions internal/handlers/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ const (
// maxSharedUsersPerQuery is the maximum number of shared users to retrieve in a single query
// This prevents memory issues when a project is shared with many users
maxSharedUsersPerQuery = 1000

// maxAdminQueryLimit is the maximum limit for admin operations that need to scan all records
// Used for operations like sanity checks that validate all data in the database
maxAdminQueryLimit = 999999
)

// Create a new project
Expand Down Expand Up @@ -277,7 +281,7 @@ func getProjectFunc(ctx context.Context, input *models.GetProjectRequest) (*mode
sharedUsers = append(sharedUsers, models.SharedUser{UserHandle: "*", Role: "reader"})
}
// Iterate all shared users
userRows, err := queries.GetUsersByProject(ctx, database.GetUsersByProjectParams{Owner: input.UserHandle, ProjectHandle: input.ProjectHandle, Limit: 999, Offset: 0})
userRows, err := queries.GetUsersByProject(ctx, database.GetUsersByProjectParams{Owner: input.UserHandle, ProjectHandle: input.ProjectHandle, Limit: maxSharedUsersPerQuery, Offset: 0})
if err != nil {
return nil, huma.Error500InternalServerError(fmt.Sprintf("unable to get authorized reader accounts for %s's project %s. %v", input.UserHandle, input.ProjectHandle, err))
}
Expand All @@ -303,7 +307,7 @@ func getProjectFunc(ctx context.Context, input *models.GetProjectRequest) (*mode
if llmRow.Owner == requestingUser.(string) {
accessRole = "owner"
} else {
sharedUsers, err := queries.GetSharedUsersForInstance(ctx, database.GetSharedUsersForInstanceParams{Owner: llmRow.Owner, InstanceHandle: llmRow.InstanceHandle})
sharedUsers, err := queries.GetSharedUsersForInstance(ctx, database.GetSharedUsersForInstanceParams{Owner: llmRow.Owner, InstanceHandle: llmRow.InstanceHandle, Limit: maxSharedUsersPerQuery, Offset: 0})
if err != nil {
return nil, huma.Error500InternalServerError(fmt.Sprintf("unable to get shared users for LLM Service Instance %s owned by %s. %v", llmRow.InstanceHandle, llmRow.Owner, err))
}
Expand Down
8 changes: 6 additions & 2 deletions internal/handlers/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ func getUserFunc(ctx context.Context, input *models.GetUserRequest) (*models.Get

// Get projects the user is a member of
projects := models.ProjectMemberships{}
ps, err := queries.GetProjectsByUser(ctx, database.GetProjectsByUserParams{UserHandle: input.UserHandle})
ps, err := queries.GetProjectsByUser(ctx, database.GetProjectsByUserParams{
UserHandle: input.UserHandle,
Limit: maxSharedUsersPerQuery,
Offset: 0,
})
if err != nil {
if err.Error() == "no rows in result set" {
fmt.Printf("Warning: No LLM Services registered for user %s.", input.UserHandle)
Expand All @@ -182,7 +186,7 @@ func getUserFunc(ctx context.Context, input *models.GetUserRequest) (*models.Get
imemberships := models.InstanceMemberships{}
instances, err := queries.GetAccessibleInstancesByUser(ctx, database.GetAccessibleInstancesByUserParams{
Owner: input.UserHandle,
Limit: 999,
Limit: maxSharedUsersPerQuery,
Offset: 0,
})
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/models/llm_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ type UnshareDefinitionResponse struct {
type GetDefinitionSharedUsersRequest struct {
UserHandle string `json:"user_handle" path:"user_handle" maxLength:"20" minLength:"3" example:"_system" doc:"Definition owner handle"`
DefinitionHandle string `json:"definition_handle" path:"definition_handle" maxLength:"20" minLength:"3" example:"openai-large" doc:"Definition handle"`
Limit int `json:"limit,omitempty" query:"limit" minimum:"1" maximum:"200" example:"20" default:"100" doc:"Maximum number of users to return"`
Offset int `json:"offset,omitempty" query:"offset" minimum:"0" example:"0" default:"0" doc:"Offset into the list of users"`
}

type GetDefinitionSharedUsersResponse struct {
Expand Down Expand Up @@ -367,6 +369,8 @@ type UnshareInstanceResponse struct {
type GetInstanceSharedUsersRequest struct {
UserHandle string `json:"user_handle" path:"user_handle" maxLength:"20" minLength:"3" example:"alice" doc:"Instance owner handle"`
InstanceHandle string `json:"instance_handle" path:"instance_handle" maxLength:"20" minLength:"3" example:"my-openai" doc:"Instance handle"`
Limit int `json:"limit,omitempty" query:"limit" minimum:"1" maximum:"200" example:"20" default:"100" doc:"Maximum number of users to return"`
Offset int `json:"offset,omitempty" query:"offset" minimum:"0" example:"0" default:"0" doc:"Offset into the list of users"`
}

type GetInstanceSharedUsersResponse struct {
Expand Down