Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions internal/handlers/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func getUserProj(ctx context.Context, user, project string) (string, string, int

// Create a new embeddings
func postProjEmbeddingsFunc(ctx context.Context, input *models.PostProjEmbeddingsRequest) (*models.UploadProjEmbeddingsResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Get the database connection pool from the context
pool, err := GetDBPool(ctx)
Expand Down Expand Up @@ -258,6 +262,11 @@ func getProjEmbeddingsFunc(ctx context.Context, input *models.GetProjEmbeddingsR
}

func deleteProjEmbeddingsFunc(ctx context.Context, input *models.DeleteProjEmbeddingsRequest) (*models.DeleteProjEmbeddingsResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if user and project exist
_, _, _, err := getUserProj(ctx, input.UserHandle, input.ProjectHandle)
if err != nil {
Expand Down Expand Up @@ -355,6 +364,11 @@ func getDocEmbeddingsFunc(ctx context.Context, input *models.GetDocEmbeddingsReq
}

func deleteDocEmbeddingsFunc(ctx context.Context, input *models.DeleteEmbeddingsByDocIDRequest) (*models.DeleteEmbeddingsByDocIDResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if user and project exist
_, _, _, err := getUserProj(ctx, input.UserHandle, input.ProjectHandle)
if err != nil {
Expand Down
24 changes: 24 additions & 0 deletions internal/handlers/llm_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ func getEncryptionKey() *crypto.EncryptionKey {
// === Sharing LLM Service Definitions ===

func putDefinitionFunc(ctx context.Context, input *models.PutDefinitionRequest) (*models.UploadDefinitionResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

if input.DefinitionHandle != input.Body.DefinitionHandle {
return nil, huma.Error400BadRequest(fmt.Sprintf("definition handle in URL (\"%s\") does not match definition handle in body (\"%s\")", input.DefinitionHandle, input.Body.DefinitionHandle))
}
Expand Down Expand Up @@ -234,6 +239,10 @@ func getUserDefinitionsFunc(ctx context.Context, input *models.GetUserDefinition
}

func deleteDefinitionFunc(ctx context.Context, input *models.DeleteDefinitionRequest) (*models.DeleteDefinitionResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if user exists
u, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle})
Expand Down Expand Up @@ -400,6 +409,11 @@ func getDefinitionSharedUsersFunc(ctx context.Context, input *models.GetDefiniti

// Create a llm service instance (with a handle being present in the URL)
func putInstanceFunc(ctx context.Context, input *models.PutInstanceRequest) (*models.UploadInstanceResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

if input.InstanceHandle != input.Body.InstanceHandle {
return nil, huma.Error400BadRequest(fmt.Sprintf("instance handle in URL (\"%s\") does not match instance handle in body (\"%s\")", input.InstanceHandle, input.Body.InstanceHandle))
}
Expand Down Expand Up @@ -485,6 +499,11 @@ func postInstanceFunc(ctx context.Context, input *models.PostInstanceRequest) (*

// Create a llm service instance based on a definition
func postInstanceFromDefinitionFunc(ctx context.Context, input *models.PostInstanceFromDefinitionRequest) (*models.UploadInstanceResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

if input.UserHandle != input.Body.UserHandle {
return nil, huma.Error400BadRequest(fmt.Sprintf("user handle in URL (\"%s\") does not match user handle in body (\"%s\")", input.UserHandle, input.Body.UserHandle))
}
Expand Down Expand Up @@ -756,6 +775,11 @@ func getUserInstancesFunc(ctx context.Context, input *models.GetUserInstancesReq
}

func deleteInstanceFunc(ctx context.Context, input *models.DeleteInstanceRequest) (*models.DeleteInstanceResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if user exists
u, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle})
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions internal/handlers/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ const (

// Create a new project
func putProjectFunc(ctx context.Context, input *models.PutProjectRequest) (*models.UploadProjectResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

if input.ProjectHandle != input.Body.ProjectHandle {
return nil, huma.Error400BadRequest(fmt.Sprintf("project handle in URL (%s) does not match project handle in body (%s)", input.ProjectHandle, input.Body.ProjectHandle))
}
Expand Down Expand Up @@ -354,6 +359,11 @@ func getProjectFunc(ctx context.Context, input *models.GetProjectRequest) (*mode
}

func deleteProjectFunc(ctx context.Context, input *models.DeleteProjectRequest) (*models.DeleteProjectResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if user exists
if _, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle}); err != nil {
return nil, err
Expand Down
10 changes: 10 additions & 0 deletions internal/handlers/similars.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (

// Define handler functions for each route
func getSimilarFunc(ctx context.Context, input *models.GetSimilarRequest) (*models.SimilarResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if only one of input.MetadataField and input.MetadataValue are given
if input.MetadataPath != "" && input.MetadataValue == "" {
return nil, huma.Error400BadRequest("metadata_path is set but metadata_value is not")
Expand Down Expand Up @@ -111,6 +116,11 @@ func getSimilarFunc(ctx context.Context, input *models.GetSimilarRequest) (*mode
}

func postSimilarFunc(ctx context.Context, input *models.PostSimilarRequest) (*models.SimilarResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Check if only one of input.MetadataPath and input.MetadataValue are given
if input.MetadataPath != "" && input.MetadataValue == "" {
return nil, huma.Error400BadRequest("metadata_path is set but metadata_value is not")
Expand Down
198 changes: 198 additions & 0 deletions internal/handlers/system_user_restrictions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package handlers_test

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

// TestSystemUserRestrictions verifies that the _system user cannot send requests
// The _system user is a read-only account that can only own resources, not create them
func TestSystemUserRestrictions(t *testing.T) {
// Get the database connection pool from package variable
pool := connPool

// Create a mock key generator
mockKeyGen := new(MockKeyGen)
mockKeyGen.On("RandomKey", 32).Return("12345678901234567890123456789012", nil).Maybe()

// Start the server
err, shutDownServer := startTestServer(t, pool, mockKeyGen)
assert.NoError(t, err)
defer shutDownServer()

// Create a regular user (alice) for testing
aliceJSON := `{"user_handle": "alice", "name": "Alice Doe", "email": "alice@foo.bar"}`
aliceAPIKey, err := createUser(t, aliceJSON)
if err != nil {
t.Fatalf("Error creating user alice for testing: %v\n", err)
}

// Create API standard for LLM service testing
apiStandardJSON := `{"api_standard_handle": "openai", "description": "OpenAI Embeddings API", "key_method": "auth_bearer", "key_field": "Authorization" }`
_, err = createAPIStandard(t, apiStandardJSON, options.AdminKey)
if err != nil {
t.Fatalf("Error creating API standard openai for testing: %v\n", err)
}

// Create LLM Service for alice
instanceJSON := `{ "instance_handle": "embedding1", "endpoint": "https://api.foo.bar/v1/embed", "description": "An LLM Service just for testing", "api_standard": "openai", "model": "embed-test1", "dimensions": 5}`
_, err = createInstance(t, instanceJSON, "alice", aliceAPIKey)
if err != nil {
t.Fatalf("Error creating LLM service for testing: %v\n", err)
}

// Create project for alice
projectJSON := `{"project_handle": "test1", "description": "A test project", "instance_owner": "alice", "instance_handle": "embedding1"}`
_, err = createProject(t, projectJSON, "alice", aliceAPIKey)
if err != nil {
t.Fatalf("Error creating project alice/test1 for testing: %v\n", err)
}

// Define test cases - all should fail with 403 Forbidden
// We use admin authentication to bypass owner auth and test our validation layer
testCases := []struct {
name string
method string
path string
body string
expectStatus int
}{
{
name: "System user cannot create projects",
method: http.MethodPut,
path: "/v1/projects/_system/forbidden-project",
body: `{"project_handle": "forbidden-project", "description": "Should not be allowed"}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete projects",
method: http.MethodDelete,
path: "/v1/projects/_system/some-project",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot upload embeddings",
method: http.MethodPost,
path: "/v1/embeddings/_system/test1",
body: `{"embeddings": []}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete all project embeddings",
method: http.MethodDelete,
path: "/v1/embeddings/_system/test1",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete specific document embeddings",
method: http.MethodDelete,
path: "/v1/embeddings/_system/test1/doc-id-123",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot post similarity requests",
method: http.MethodPost,
path: "/v1/similars/_system/test1",
body: `{"vector": [1.0, 2.0, 3.0, 4.0, 5.0]}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot get similarity requests",
method: http.MethodGet,
path: "/v1/similars/_system/test1/some-text-id",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot update itself",
method: http.MethodPut,
path: "/v1/users/_system",
body: `{"user_handle": "_system", "name": "System", "email": "system@test.com"}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete itself",
method: http.MethodDelete,
path: "/v1/users/_system",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot create LLM service definitions",
method: http.MethodPut,
path: "/v1/llm-definitions/_system/test-def",
body: `{"definition_handle": "test-def", "api_standard": "openai", "model": "test", "dimensions": 5}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete LLM service definitions",
method: http.MethodDelete,
path: "/v1/llm-definitions/_system/openai-large",
body: "",
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot create LLM service instances",
method: http.MethodPut,
path: "/v1/llm-instances/_system/test-instance",
body: `{"instance_handle": "test-instance", "endpoint": "https://test.com", "api_standard": "openai", "model": "test", "dimensions": 5}`,
expectStatus: http.StatusForbidden,
},
{
name: "System user cannot delete LLM service instances",
method: http.MethodDelete,
path: "/v1/llm-instances/_system/test-instance",
body: "",
expectStatus: http.StatusForbidden,
},
}

// Run all test cases
// Note: We use admin authentication to bypass the owner authentication layer
// so that we can test our validation logic. In practice, both auth and validation
// will prevent _system from making requests.
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reqBody := io.Reader(nil)
if tc.body != "" {
reqBody = bytes.NewReader([]byte(tc.body))
}

requestURL := fmt.Sprintf("http://%v:%d%v", options.Host, options.Port, tc.path)
req, err := http.NewRequest(tc.method, requestURL, reqBody)
assert.NoError(t, err)
// Use admin key to bypass owner authentication and reach our validation layer
req.Header.Set("Authorization", "Bearer "+options.AdminKey)
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, tc.expectStatus, resp.StatusCode,
"Expected %s request to %s to return status %d, got %d",
tc.method, tc.path, tc.expectStatus, resp.StatusCode)

// Verify error message contains expected text
if resp.StatusCode == http.StatusForbidden {
respBody, err := io.ReadAll(resp.Body)
assert.NoError(t, err)

var errorResp struct {
Detail string `json:"detail"`
}
err = json.Unmarshal(respBody, &errorResp)
assert.NoError(t, err)
assert.Contains(t, errorResp.Detail, "cannot send requests",
"Error message should indicate _system user cannot send requests")
}
})
}
}
10 changes: 10 additions & 0 deletions internal/handlers/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (

// putUserFunc creates or updates a user
func putUserFunc(ctx context.Context, input *models.PutUserRequest) (*models.UploadUserResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

if input.UserHandle != input.Body.UserHandle {
return nil, huma.Error400BadRequest(fmt.Sprintf("user handle in URL (%s) does not match user handle in body (%v).", input.UserHandle, input.Body.UserHandle))
}
Expand Down Expand Up @@ -236,6 +241,11 @@ func getUserFunc(ctx context.Context, input *models.GetUserRequest) (*models.Get

// Delete a specific user
func deleteUserFunc(ctx context.Context, input *models.DeleteUserRequest) (*models.DeleteUserResponse, error) {
// Validate that _system user cannot send requests
if err := ValidateNotSystemUser(input.UserHandle); err != nil {
return nil, err
}

// Get the database connection pool from the context
pool, err := GetDBPool(ctx)
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions internal/handlers/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"

"github.com/danielgtaylor/huma/v2"
"github.com/mpilhlt/dhamps-vdb/internal/models"
"github.com/xeipuuv/gojsonschema"
)
Expand Down Expand Up @@ -97,3 +98,12 @@ func ValidateEmbeddingMetadataAgainstProjectSchema(metadata json.RawMessage, sch
}
return nil
}

// ValidateNotSystemUser checks if a user handle is "_system" and returns an error if so
// The _system user is read-only and cannot send requests
func ValidateNotSystemUser(userHandle string) error {
if userHandle == "_system" {
return huma.Error403Forbidden("_system user cannot send requests - this is a read-only account")
}
return nil
}
Loading
Loading