diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index 0477ff4..215ab1b 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -44,6 +44,10 @@ func getUserProj(ctx context.Context, user, project string) (string, string, int // Create a new embeddings func postProjEmbeddingsFunc(ctx context.Context, input *models.PostProjEmbeddingsRequest) (*models.UploadProjEmbeddingsResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } // Get the database connection pool from the context pool, err := GetDBPool(ctx) @@ -258,6 +262,11 @@ func getProjEmbeddingsFunc(ctx context.Context, input *models.GetProjEmbeddingsR } func deleteProjEmbeddingsFunc(ctx context.Context, input *models.DeleteProjEmbeddingsRequest) (*models.DeleteProjEmbeddingsResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if user and project exist _, _, _, err := getUserProj(ctx, input.UserHandle, input.ProjectHandle) if err != nil { @@ -355,6 +364,11 @@ func getDocEmbeddingsFunc(ctx context.Context, input *models.GetDocEmbeddingsReq } func deleteDocEmbeddingsFunc(ctx context.Context, input *models.DeleteEmbeddingsByDocIDRequest) (*models.DeleteEmbeddingsByDocIDResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if user and project exist _, _, _, err := getUserProj(ctx, input.UserHandle, input.ProjectHandle) if err != nil { diff --git a/internal/handlers/llm_services.go b/internal/handlers/llm_services.go index 12d8109..49060f7 100644 --- a/internal/handlers/llm_services.go +++ b/internal/handlers/llm_services.go @@ -30,6 +30,11 @@ func getEncryptionKey() *crypto.EncryptionKey { // === Sharing LLM Service Definitions === func putDefinitionFunc(ctx context.Context, input *models.PutDefinitionRequest) (*models.UploadDefinitionResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + if input.DefinitionHandle != input.Body.DefinitionHandle { return nil, huma.Error400BadRequest(fmt.Sprintf("definition handle in URL (\"%s\") does not match definition handle in body (\"%s\")", input.DefinitionHandle, input.Body.DefinitionHandle)) } @@ -234,6 +239,10 @@ func getUserDefinitionsFunc(ctx context.Context, input *models.GetUserDefinition } func deleteDefinitionFunc(ctx context.Context, input *models.DeleteDefinitionRequest) (*models.DeleteDefinitionResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } // Check if user exists u, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle}) @@ -400,6 +409,11 @@ func getDefinitionSharedUsersFunc(ctx context.Context, input *models.GetDefiniti // Create a llm service instance (with a handle being present in the URL) func putInstanceFunc(ctx context.Context, input *models.PutInstanceRequest) (*models.UploadInstanceResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + if input.InstanceHandle != input.Body.InstanceHandle { return nil, huma.Error400BadRequest(fmt.Sprintf("instance handle in URL (\"%s\") does not match instance handle in body (\"%s\")", input.InstanceHandle, input.Body.InstanceHandle)) } @@ -485,6 +499,11 @@ func postInstanceFunc(ctx context.Context, input *models.PostInstanceRequest) (* // Create a llm service instance based on a definition func postInstanceFromDefinitionFunc(ctx context.Context, input *models.PostInstanceFromDefinitionRequest) (*models.UploadInstanceResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + if input.UserHandle != input.Body.UserHandle { return nil, huma.Error400BadRequest(fmt.Sprintf("user handle in URL (\"%s\") does not match user handle in body (\"%s\")", input.UserHandle, input.Body.UserHandle)) } @@ -756,6 +775,11 @@ func getUserInstancesFunc(ctx context.Context, input *models.GetUserInstancesReq } func deleteInstanceFunc(ctx context.Context, input *models.DeleteInstanceRequest) (*models.DeleteInstanceResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if user exists u, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle}) if err != nil { diff --git a/internal/handlers/projects.go b/internal/handlers/projects.go index 87ba121..b297ca7 100644 --- a/internal/handlers/projects.go +++ b/internal/handlers/projects.go @@ -27,6 +27,11 @@ const ( // Create a new project func putProjectFunc(ctx context.Context, input *models.PutProjectRequest) (*models.UploadProjectResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + if input.ProjectHandle != input.Body.ProjectHandle { return nil, huma.Error400BadRequest(fmt.Sprintf("project handle in URL (%s) does not match project handle in body (%s)", input.ProjectHandle, input.Body.ProjectHandle)) } @@ -354,6 +359,11 @@ func getProjectFunc(ctx context.Context, input *models.GetProjectRequest) (*mode } func deleteProjectFunc(ctx context.Context, input *models.DeleteProjectRequest) (*models.DeleteProjectResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if user exists if _, err := getUserFunc(ctx, &models.GetUserRequest{UserHandle: input.UserHandle}); err != nil { return nil, err diff --git a/internal/handlers/similars.go b/internal/handlers/similars.go index 8180b21..5e67bf0 100644 --- a/internal/handlers/similars.go +++ b/internal/handlers/similars.go @@ -17,6 +17,11 @@ import ( // Define handler functions for each route func getSimilarFunc(ctx context.Context, input *models.GetSimilarRequest) (*models.SimilarResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if only one of input.MetadataField and input.MetadataValue are given if input.MetadataPath != "" && input.MetadataValue == "" { return nil, huma.Error400BadRequest("metadata_path is set but metadata_value is not") @@ -111,6 +116,11 @@ func getSimilarFunc(ctx context.Context, input *models.GetSimilarRequest) (*mode } func postSimilarFunc(ctx context.Context, input *models.PostSimilarRequest) (*models.SimilarResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Check if only one of input.MetadataPath and input.MetadataValue are given if input.MetadataPath != "" && input.MetadataValue == "" { return nil, huma.Error400BadRequest("metadata_path is set but metadata_value is not") diff --git a/internal/handlers/system_user_restrictions_test.go b/internal/handlers/system_user_restrictions_test.go new file mode 100644 index 0000000..1d80c0d --- /dev/null +++ b/internal/handlers/system_user_restrictions_test.go @@ -0,0 +1,198 @@ +package handlers_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestSystemUserRestrictions verifies that the _system user cannot send requests +// The _system user is a read-only account that can only own resources, not create them +func TestSystemUserRestrictions(t *testing.T) { + // Get the database connection pool from package variable + pool := connPool + + // Create a mock key generator + mockKeyGen := new(MockKeyGen) + mockKeyGen.On("RandomKey", 32).Return("12345678901234567890123456789012", nil).Maybe() + + // Start the server + err, shutDownServer := startTestServer(t, pool, mockKeyGen) + assert.NoError(t, err) + defer shutDownServer() + + // Create a regular user (alice) for testing + aliceJSON := `{"user_handle": "alice", "name": "Alice Doe", "email": "alice@foo.bar"}` + aliceAPIKey, err := createUser(t, aliceJSON) + if err != nil { + t.Fatalf("Error creating user alice for testing: %v\n", err) + } + + // Create API standard for LLM service testing + apiStandardJSON := `{"api_standard_handle": "openai", "description": "OpenAI Embeddings API", "key_method": "auth_bearer", "key_field": "Authorization" }` + _, err = createAPIStandard(t, apiStandardJSON, options.AdminKey) + if err != nil { + t.Fatalf("Error creating API standard openai for testing: %v\n", err) + } + + // Create LLM Service for alice + instanceJSON := `{ "instance_handle": "embedding1", "endpoint": "https://api.foo.bar/v1/embed", "description": "An LLM Service just for testing", "api_standard": "openai", "model": "embed-test1", "dimensions": 5}` + _, err = createInstance(t, instanceJSON, "alice", aliceAPIKey) + if err != nil { + t.Fatalf("Error creating LLM service for testing: %v\n", err) + } + + // Create project for alice + projectJSON := `{"project_handle": "test1", "description": "A test project", "instance_owner": "alice", "instance_handle": "embedding1"}` + _, err = createProject(t, projectJSON, "alice", aliceAPIKey) + if err != nil { + t.Fatalf("Error creating project alice/test1 for testing: %v\n", err) + } + + // Define test cases - all should fail with 403 Forbidden + // We use admin authentication to bypass owner auth and test our validation layer + testCases := []struct { + name string + method string + path string + body string + expectStatus int + }{ + { + name: "System user cannot create projects", + method: http.MethodPut, + path: "/v1/projects/_system/forbidden-project", + body: `{"project_handle": "forbidden-project", "description": "Should not be allowed"}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete projects", + method: http.MethodDelete, + path: "/v1/projects/_system/some-project", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot upload embeddings", + method: http.MethodPost, + path: "/v1/embeddings/_system/test1", + body: `{"embeddings": []}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete all project embeddings", + method: http.MethodDelete, + path: "/v1/embeddings/_system/test1", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete specific document embeddings", + method: http.MethodDelete, + path: "/v1/embeddings/_system/test1/doc-id-123", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot post similarity requests", + method: http.MethodPost, + path: "/v1/similars/_system/test1", + body: `{"vector": [1.0, 2.0, 3.0, 4.0, 5.0]}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot get similarity requests", + method: http.MethodGet, + path: "/v1/similars/_system/test1/some-text-id", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot update itself", + method: http.MethodPut, + path: "/v1/users/_system", + body: `{"user_handle": "_system", "name": "System", "email": "system@test.com"}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete itself", + method: http.MethodDelete, + path: "/v1/users/_system", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot create LLM service definitions", + method: http.MethodPut, + path: "/v1/llm-definitions/_system/test-def", + body: `{"definition_handle": "test-def", "api_standard": "openai", "model": "test", "dimensions": 5}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete LLM service definitions", + method: http.MethodDelete, + path: "/v1/llm-definitions/_system/openai-large", + body: "", + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot create LLM service instances", + method: http.MethodPut, + path: "/v1/llm-instances/_system/test-instance", + body: `{"instance_handle": "test-instance", "endpoint": "https://test.com", "api_standard": "openai", "model": "test", "dimensions": 5}`, + expectStatus: http.StatusForbidden, + }, + { + name: "System user cannot delete LLM service instances", + method: http.MethodDelete, + path: "/v1/llm-instances/_system/test-instance", + body: "", + expectStatus: http.StatusForbidden, + }, + } + + // Run all test cases + // Note: We use admin authentication to bypass the owner authentication layer + // so that we can test our validation logic. In practice, both auth and validation + // will prevent _system from making requests. + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reqBody := io.Reader(nil) + if tc.body != "" { + reqBody = bytes.NewReader([]byte(tc.body)) + } + + requestURL := fmt.Sprintf("http://%v:%d%v", options.Host, options.Port, tc.path) + req, err := http.NewRequest(tc.method, requestURL, reqBody) + assert.NoError(t, err) + // Use admin key to bypass owner authentication and reach our validation layer + req.Header.Set("Authorization", "Bearer "+options.AdminKey) + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.expectStatus, resp.StatusCode, + "Expected %s request to %s to return status %d, got %d", + tc.method, tc.path, tc.expectStatus, resp.StatusCode) + + // Verify error message contains expected text + if resp.StatusCode == http.StatusForbidden { + respBody, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + + var errorResp struct { + Detail string `json:"detail"` + } + err = json.Unmarshal(respBody, &errorResp) + assert.NoError(t, err) + assert.Contains(t, errorResp.Detail, "cannot send requests", + "Error message should indicate _system user cannot send requests") + } + }) + } +} diff --git a/internal/handlers/users.go b/internal/handlers/users.go index e11d02c..55958a3 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -17,6 +17,11 @@ import ( // putUserFunc creates or updates a user func putUserFunc(ctx context.Context, input *models.PutUserRequest) (*models.UploadUserResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + if input.UserHandle != input.Body.UserHandle { return nil, huma.Error400BadRequest(fmt.Sprintf("user handle in URL (%s) does not match user handle in body (%v).", input.UserHandle, input.Body.UserHandle)) } @@ -236,6 +241,11 @@ func getUserFunc(ctx context.Context, input *models.GetUserRequest) (*models.Get // Delete a specific user func deleteUserFunc(ctx context.Context, input *models.DeleteUserRequest) (*models.DeleteUserResponse, error) { + // Validate that _system user cannot send requests + if err := ValidateNotSystemUser(input.UserHandle); err != nil { + return nil, err + } + // Get the database connection pool from the context pool, err := GetDBPool(ctx) if err != nil { diff --git a/internal/handlers/validation.go b/internal/handlers/validation.go index 0291a72..643853f 100644 --- a/internal/handlers/validation.go +++ b/internal/handlers/validation.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "github.com/danielgtaylor/huma/v2" "github.com/mpilhlt/dhamps-vdb/internal/models" "github.com/xeipuuv/gojsonschema" ) @@ -97,3 +98,12 @@ func ValidateEmbeddingMetadataAgainstProjectSchema(metadata json.RawMessage, sch } return nil } + +// ValidateNotSystemUser checks if a user handle is "_system" and returns an error if so +// The _system user is read-only and cannot send requests +func ValidateNotSystemUser(userHandle string) error { + if userHandle == "_system" { + return huma.Error403Forbidden("_system user cannot send requests - this is a read-only account") + } + return nil +} diff --git a/internal/handlers/validation_unit_test.go b/internal/handlers/validation_unit_test.go index 74cf225..7d93bee 100644 --- a/internal/handlers/validation_unit_test.go +++ b/internal/handlers/validation_unit_test.go @@ -166,3 +166,44 @@ func TestValidateMetadataAgainstSchema(t *testing.T) { }) } } + +func TestValidateNotSystemUser(t *testing.T) { + tests := []struct { + name string + userHandle string + wantErr bool + errContains string + }{ + { + name: "Valid regular user", + userHandle: "alice", + wantErr: false, + }, + { + name: "Valid user with underscore", + userHandle: "alice_smith", + wantErr: false, + }, + { + name: "System user should be rejected", + userHandle: "_system", + wantErr: true, + errContains: "cannot send requests", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateNotSystemUser(tt.userHandle) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateNotSystemUser() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil && tt.errContains != "" { + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("ValidateNotSystemUser() error = %v, should contain %v", err.Error(), tt.errContains) + } + } + }) + } +}