diff --git a/README.md b/README.md index 52f32c8..6478a0e 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,50 @@ values := []any{"aztec", "nuke", "", 2, 10} (given "customdata" is configured with `filter.WithNestedJSONB("customdata", "password", "playerCount")`) +## Order By Support + +In addition to filtering, this package also supports converting MongoDB-style sort objects into PostgreSQL ORDER BY clauses using the `ConvertOrderBy` method: + +```go +// Convert a sort object to an ORDER BY clause +sortInput := []byte(`{"playerCount": -1, "name": 1}`) +orderBy, err := converter.ConvertOrderBy(sortInput) +if err != nil { + // handle error +} +fmt.Println(orderBy) // "playerCount" DESC, "name" ASC + +db.Query("SELECT * FROM games ORDER BY " + orderBy) +``` + +### Sort Direction Values: +- `1`: Ascending (ASC) +- `-1`: Descending (DESC) + +### Return value +The `ConvertOrderBy` method returns a string that can be directly used in an SQL ORDER BY clause. When the input is an empty object or `nil`, it returns an empty string. Keep in mind that the method does not add the `ORDER BY` keyword itself; you need to include it in your SQL query. + +### JSONB Field Sorting: +For JSONB fields, the package generates sophisticated ORDER BY clauses that handle both numeric and text sorting: + +```go +// With WithNestedJSONB("metadata", "created_at"): +sortInput := []byte(`{"score": -1}`) +orderBy, err := converter.ConvertOrderBy(sortInput) +// Generates: (CASE WHEN jsonb_typeof(metadata->'score') = 'number' THEN (metadata->>'score')::numeric END) DESC NULLS LAST, metadata->>'score' DESC NULLS LAST +``` + +This ensures proper sorting whether the JSONB field contains numeric or text values. + +> [!TIP] +> Always add an `, id ASC` to your ORDER BY clause to ensure a consistent order (where `id` is your primary key). +> ```go +> if orderBy != "" { +> orderBy += ", " +> } +> orderBy += "id ASC" +> ``` + ## Difference with MongoDB - The MongoDB query filters don't have the option to compare fields with each other. This package adds the `$field` operator to compare fields with each other. diff --git a/filter/converter.go b/filter/converter.go index ed89c33..c81b3af 100644 --- a/filter/converter.go +++ b/filter/converter.go @@ -208,7 +208,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string // `column != ANY(...)` does not work, so we need to do `NOT column = ANY(...)` instead. neg = "NOT " } - inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key), paramIndex)) + inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key, true), paramIndex)) paramIndex++ if c.arrayDriver != nil { v[operator] = c.arrayDriver(v[operator]) @@ -245,7 +245,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string // // EXISTS (SELECT 1 FROM unnest("foo") AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1)) // - inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key), c.placeholderName, innerConditions)) + inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key, true), c.placeholderName, innerConditions)) } values = append(values, innerValues...) case "$field": @@ -254,7 +254,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string return "", nil, fmt.Errorf("invalid value for $field operator (must be string): %v", v[operator]) } - inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key), c.columnName(vv))) + inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key, true), c.columnName(vv, true))) default: value := v[operator] isNumericOperator := false @@ -274,8 +274,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string return "", nil, fmt.Errorf("invalid value for %s operator (must be object with $field key only): %v", operator, value) } - left := c.columnName(key) - right := c.columnName(field) + left := c.columnName(key, true) + right := c.columnName(field, true) if isNumericOperator { if c.isNestedColumn(key) { @@ -304,9 +304,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string } if isNumericOperator && isNumeric(value) && c.isNestedColumn(key) { - inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key), op, paramIndex)) + inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key, true), op, paramIndex)) } else { - inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex)) + inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key, true), op, paramIndex)) } paramIndex++ values = append(values, value) @@ -329,9 +329,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string } } if isNestedColumn { - conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key))) + conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key, true))) } else { - conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key))) + conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key, true))) } default: // Prevent cryptic errors like: @@ -341,9 +341,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string } if isNumeric(value) && c.isNestedColumn(key) { // If the value is numeric and the column is a nested JSONB column, we need to cast the column to numeric. - conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key), paramIndex)) + conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key, true), paramIndex)) } else { - conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key), paramIndex)) + conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key, true), paramIndex)) } paramIndex++ values = append(values, value) @@ -358,7 +358,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string return result, values, nil } -func (c *Converter) columnName(column string) string { +func (c *Converter) columnName(column string, jsonFieldAsText bool) string { if column == c.placeholderName { return fmt.Sprintf(`%q::text`, column) } @@ -370,7 +370,10 @@ func (c *Converter) columnName(column string) string { return fmt.Sprintf("%q", column) } } - return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column) + if jsonFieldAsText { + return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column) + } + return fmt.Sprintf(`%q->'%s'`, c.nestedColumn, column) } func (c *Converter) isColumnAllowed(column string) bool { @@ -404,3 +407,77 @@ func (c *Converter) isNestedColumn(column string) bool { } return true } + +// ConvertOrderBy converts a JSON object with field names and sort directions +// into a PostgreSQL ORDER BY clause. The JSON object should have keys with values +// of 1 (ASC) or -1 (DESC). +// +// For JSONB fields, it generates clauses that handle both numeric and text sorting. +// +// Example: {"playerCount": -1, "name": 1} -> "playerCount DESC, name ASC" +func (c *Converter) ConvertOrderBy(query []byte) (string, error) { + keyValues, err := objectInOrder(query) + if err != nil { + return "", err + } + + parts := make([]string, 0, len(keyValues)) + + for _, kv := range keyValues { + key, value := kv.Key, kv.Value + + if !isValidPostgresIdentifier(key) { + return "", fmt.Errorf("invalid column name: %s", key) + } + if !c.isColumnAllowed(key) { + return "", ColumnNotAllowedError{Column: key} + } + + // Convert value to number for direction + var direction string + switch v := value.(type) { + case json.Number: + if num, err := v.Int64(); err == nil { + switch num { + case 1: + direction = "ASC" + case -1: + direction = "DESC" + default: + return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value) + } + } else { + return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value) + } + case float64: + switch v { + case 1: + direction = "ASC" + case -1: + direction = "DESC" + default: + return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value) + } + default: + return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value) + } + + var fieldClause string + if c.isNestedColumn(key) { + // For JSONB fields, handle both numeric and text sorting. + // We need to use the raw JSONB reference for jsonb_typeof, but columnName() for the actual sorting + fieldClause = fmt.Sprintf("(CASE WHEN jsonb_typeof(%s) = 'number' THEN (%s)::numeric END) %s NULLS LAST, %s %s NULLS LAST", c.columnName(key, false), c.columnName(key, true), direction, c.columnName(key, true), direction) + } else { + // Regular field. + fieldClause = fmt.Sprintf(`%s %s NULLS LAST`, c.columnName(key, true), direction) + } + + parts = append(parts, fieldClause) + } + + if len(parts) == 0 { + return "", nil + } + + return strings.Join(parts, ", "), nil +} diff --git a/filter/converter_test.go b/filter/converter_test.go index 9916744..ea14524 100644 --- a/filter/converter_test.go +++ b/filter/converter_test.go @@ -641,3 +641,134 @@ func TestConverter_AccessControl(t *testing.T) { t.Run("nested but disallow password, disallow", f(`{"password": "hacks"}`, no("password"), filter.WithNestedJSONB("meta", "created_at"), filter.WithDisallowColumns("password"))) } + +func TestConverter_ConvertOrderBy(t *testing.T) { + tests := []struct { + name string + options []filter.Option + input string + expected string + err error + }{ + { + "single field ascending", + []filter.Option{filter.WithAllowAllColumns()}, + `{"playerCount": 1}`, + `"playerCount" ASC NULLS LAST`, + nil, + }, + { + "single field descending", + []filter.Option{filter.WithAllowAllColumns()}, + `{"playerCount": -1}`, + `"playerCount" DESC NULLS LAST`, + nil, + }, + { + "multiple fields", + []filter.Option{filter.WithAllowAllColumns()}, + `{"playerCount": -1, "name": 1}`, + `"playerCount" DESC NULLS LAST, "name" ASC NULLS LAST`, + nil, + }, + { + "nested JSONB single field ascending", + []filter.Option{filter.WithNestedJSONB("customdata", "created_at")}, + `{"map": 1}`, + `(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST`, + nil, + }, + { + "nested JSONB single field descending", + []filter.Option{filter.WithNestedJSONB("customdata", "created_at")}, + `{"map": -1}`, + `(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`, + nil, + }, + { + "nested JSONB multiple fields", + []filter.Option{filter.WithNestedJSONB("customdata", "created_at")}, + `{"map": 1, "bar": -1}`, + `(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'bar') = 'number' THEN ("customdata"->>'bar')::numeric END) DESC NULLS LAST, "customdata"->>'bar' DESC NULLS LAST`, + nil, + }, + { + "mixed nested and regular fields", + []filter.Option{filter.WithNestedJSONB("customdata", "created_at")}, + `{"created_at": 1, "map": -1}`, + `"created_at" ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`, + nil, + }, + { + "field name with spaces", + []filter.Option{filter.WithAllowAllColumns()}, + `{"my_field": 1}`, + `"my_field" ASC NULLS LAST`, + nil, + }, + { + "empty object", + []filter.Option{filter.WithAllowAllColumns()}, + `{}`, + ``, + nil, + }, + { + "invalid field name for SQL injection", + []filter.Option{filter.WithAllowAllColumns()}, + `{"my field": 1}`, + ``, + fmt.Errorf("invalid column name: my field"), + }, + { + "invalid direction value", + []filter.Option{filter.WithAllowAllColumns()}, + `{"playerCount": 2}`, + ``, + fmt.Errorf("invalid order direction for field playerCount: 2 (must be 1 or -1)"), + }, + { + "invalid direction string", + []filter.Option{filter.WithAllowAllColumns()}, + `{"playerCount": "asc"}`, + ``, + fmt.Errorf("invalid order direction for field playerCount: asc (must be 1 or -1)"), + }, + { + "disallowed column", + []filter.Option{filter.WithAllowColumns("name")}, + `{"playerCount": 1}`, + ``, + filter.ColumnNotAllowedError{Column: "playerCount"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + converter, err := filter.NewConverter(tt.options...) + if err != nil { + t.Fatalf("Failed to create converter: %v", err) + } + + result, err := converter.ConvertOrderBy([]byte(tt.input)) + + if tt.err != nil { + if err == nil { + t.Fatalf("Expected error %v, got nil", tt.err) + } + if err.Error() != tt.err.Error() { + t.Errorf("Expected error %v, got %v", tt.err, err) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} diff --git a/filter/errors.go b/filter/errors.go index 480e173..cae98d6 100644 --- a/filter/errors.go +++ b/filter/errors.go @@ -12,3 +12,12 @@ type ColumnNotAllowedError struct { func (e ColumnNotAllowedError) Error() string { return fmt.Sprintf("column not allowed: %s", e.Column) } + +type InvalidOrderDirectionError struct { + Field string + Value any +} + +func (e InvalidOrderDirectionError) Error() string { + return fmt.Sprintf("invalid order direction for field %s: %v (must be 1 or -1)", e.Field, e.Value) +} diff --git a/filter/util.go b/filter/util.go index e94f99c..830519e 100644 --- a/filter/util.go +++ b/filter/util.go @@ -1,5 +1,11 @@ package filter +import ( + "bytes" + "encoding/json" + "fmt" +) + func isNumeric(v any) bool { // json.Unmarshal returns float64 for all numbers // so we only need to check for float64. @@ -71,3 +77,55 @@ func isValidPostgresIdentifier(s string) bool { return true } + +func objectInOrder(b []byte) ([]struct { + Key string + Value any +}, error) { + dec := json.NewDecoder(bytes.NewReader(b)) + + // expect { + tok, err := dec.Token() + if err != nil { + return nil, err + } + if d, ok := tok.(json.Delim); !ok || d != '{' { + return nil, fmt.Errorf("expected object, got %v", tok) + } + + var result []struct { + Key string + Value any + } + + for dec.More() { + // key + tok, err := dec.Token() + if err != nil { + return nil, err + } + key, ok := tok.(string) + if !ok { + return nil, fmt.Errorf("expected string key, got %v", tok) + } + + // value + var v any + if err := dec.Decode(&v); err != nil { + return nil, err + } + + result = append(result, struct { + Key string + Value any + }{Key: key, Value: v}) + } + + // consume } + _, err = dec.Token() + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/fuzz/fuzz_test.go b/fuzz/fuzz_test.go index 8bbed36..0372f08 100644 --- a/fuzz/fuzz_test.go +++ b/fuzz/fuzz_test.go @@ -112,3 +112,95 @@ func FuzzConverter(f *testing.F) { } }) } + +func FuzzConverterOrderBy(f *testing.F) { + tcs := []string{ + `{"level": 1}`, + `{"level": -1}`, + `{"name": 1, "level": -1}`, + `{"created_at": 1}`, + `{"guild_id": -1}`, + `{"pet": 1, "level": -1}`, + `{"class": 1, "level": -1, "name": 1}`, + `{}`, + `{"invalid_direction": 2}`, + `{"invalid_string": "asc"}`, + `{"field_with_spaces": 1}`, + `{"level": 1.0}`, + `{"level": -1.0}`, + `{"level": 0}`, + `{"field_name": -1}`, + `{"validField": 1}`, + `{"user_id": -1, "created_at": 1}`, + } + for _, tc := range tcs { + f.Add(tc, true) + f.Add(tc, false) + } + + f.Fuzz(func(t *testing.T, in string, jsonb bool) { + options := []filter.Option{ + filter.WithAllowAllColumns(), + filter.WithArrayDriver(pq.Array), + } + if jsonb { + options = append(options, filter.WithNestedJSONB("meta", "created_at")) + } + c, _ := filter.NewConverter(options...) + orderBy, err := c.ConvertOrderBy([]byte(in)) + if err == nil && orderBy != "" { + // Test that the generated ORDER BY clause is valid SQL syntax + sql := "SELECT * FROM test ORDER BY " + orderBy + j, err := pg_query.ParseToJSON(sql) + if err != nil { + // If the SQL is invalid, this might indicate a bug in the ConvertOrderBy function + // Log it but don't fail the fuzz test since this helps us find edge cases + t.Logf("Invalid SQL generated for input %q: %q -> error: %v", in, orderBy, err) + return + } + + t.Log("Input:", in, "-> ORDER BY:", orderBy) + + // Parse the JSON to ensure it contains valid ORDER BY structure + var q struct { + Stmts []struct { + Stmt struct { + SelectStmt struct { + FromClause []struct { + RangeVar struct { + Relname string `json:"relname"` + } `json:"RangeVar"` + } `json:"fromClause"` + SortClause []any `json:"sortClause"` + } `json:"SelectStmt"` + } `json:"stmt"` + } `json:"stmts"` + } + if err := json.Unmarshal([]byte(j), &q); err != nil { + t.Fatal(err) + } + + if len(q.Stmts) != 1 { + t.Fatal("Expected exactly 1 statement, got", len(q.Stmts)) + } + + if len(q.Stmts[0].Stmt.SelectStmt.FromClause) != 1 { + t.Fatal("Expected exactly 1 from clause, got", len(q.Stmts[0].Stmt.SelectStmt.FromClause)) + } + + if q.Stmts[0].Stmt.SelectStmt.FromClause[0].RangeVar.Relname != "test" { + t.Fatal("Expected table name 'test', got", q.Stmts[0].Stmt.SelectStmt.FromClause[0].RangeVar.Relname) + } + + // Verify we have sort clauses when ORDER BY is not empty + if len(q.Stmts[0].Stmt.SelectStmt.SortClause) == 0 { + t.Fatal("Expected sort clauses for ORDER BY:", orderBy) + } + + // Check for SQL injection attempts + if strings.Contains(j, "CommentStmt") { + t.Fatal("CommentStmt found - potential SQL injection in:", orderBy) + } + } + }) +} diff --git a/integration/postgres_test.go b/integration/postgres_test.go index 5ada588..541a7e6 100644 --- a/integration/postgres_test.go +++ b/integration/postgres_test.go @@ -632,3 +632,101 @@ func TestIntegration_Logic(t *testing.T) { }) } } + +func TestIntegration_OrderBy(t *testing.T) { + db := setupPQ(t) + + createPlayersTable(t, db) + + tests := []struct { + name string + orderBy string + expectedOrder []int // Expected player IDs in order + converterOpts []filter.Option + }{ + { + "single field ascending", + `{"level": 1}`, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + []filter.Option{filter.WithAllowAllColumns()}, + }, + { + "single field descending", + `{"level": -1}`, + []int{10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + []filter.Option{filter.WithAllowAllColumns()}, + }, + { + "multiple fields", + `{"class": 1, "level": -1}`, + []int{3, 8, 5, 2, 9, 6, 10, 7, 4, 1}, // dog, mage (desc level), rogue (desc level), warrior (desc level) + []filter.Option{filter.WithAllowAllColumns()}, + }, + { + "jsonb field ascending", + `{"guild_id": 1}`, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + []filter.Option{filter.WithNestedJSONB("metadata", "name", "level", "class")}, + }, + { + "jsonb field descending", + `{"guild_id": -1}`, + []int{10, 9, 7, 8, 6, 5, 4, 3, 1, 2}, // 60, 60, 50, 50, 40, 40, 30, 30, 20, 20 with secondary text sort + []filter.Option{filter.WithNestedJSONB("metadata", "name", "level", "class")}, + }, + { + "mixed regular and jsonb fields", + `{"pet": 1, "level": -1}`, + []int{8, 6, 4, 2, 7, 5, 3, 1, 10, 9}, // "cat" (desc level), then "dog" (desc level), then null/missing pets + []filter.Option{filter.WithNestedJSONB("metadata", "name", "level", "class")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := filter.NewConverter(tt.converterOpts...) + if err != nil { + t.Fatal(err) + } + + orderBy, err := c.ConvertOrderBy([]byte(tt.orderBy)) + if err != nil { + t.Fatal(err) + } + + t.Logf("Generated ORDER BY: %s", orderBy) + + rows, err := db.Query(` + SELECT id + FROM players + ORDER BY ` + orderBy + `; + `) + if err != nil { + t.Fatal(err) + } + + var playerIDs []int + for rows.Next() { + var id int + if err := rows.Scan(&id); err != nil { + t.Fatal(err) + } + playerIDs = append(playerIDs, id) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + + if len(playerIDs) != len(tt.expectedOrder) { + t.Fatalf("expected %d players, got %d", len(tt.expectedOrder), len(playerIDs)) + } + + for i, expectedID := range tt.expectedOrder { + if playerIDs[i] != expectedID { + t.Fatalf("at position %d: expected player ID %d, got %d\nExpected order: %v\nActual order: %v", + i, expectedID, playerIDs[i], tt.expectedOrder, playerIDs) + } + } + }) + } +}