From c9725213a1b09cedec0c140068888c24fbc40caa Mon Sep 17 00:00:00 2001 From: Michael Caulley Date: Mon, 2 Oct 2023 19:12:37 -0400 Subject: [PATCH] feat: allow aggregate ordering of through table edges --- entgql/internal/todo/ent.graphql | 1 + entgql/internal/todo/ent.resolvers.go | 1 + entgql/internal/todo/ent/gql_pagination.go | 27 +++++++- entgql/internal/todo/ent/schema/user.go | 5 +- entgql/internal/todo/todo_test.go | 78 ++++++++++++++++++++++ entgql/internal/todogotype/generated.go | 5 +- entgql/internal/todopulid/generated.go | 1 + entgql/internal/todouuid/generated.go | 1 + entgql/template.go | 9 ++- entgql/testdata/schema.graphql | 1 + entgql/testdata/schema_relay.graphql | 1 + 11 files changed, 124 insertions(+), 6 deletions(-) diff --git a/entgql/internal/todo/ent.graphql b/entgql/internal/todo/ent.graphql index 03c085a81..444f1d58a 100644 --- a/entgql/internal/todo/ent.graphql +++ b/entgql/internal/todo/ent.graphql @@ -1387,6 +1387,7 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } """ UserWhereInput is used for filtering User objects. diff --git a/entgql/internal/todo/ent.resolvers.go b/entgql/internal/todo/ent.resolvers.go index 00ac2b24d..671ccd23c 100644 --- a/entgql/internal/todo/ent.resolvers.go +++ b/entgql/internal/todo/ent.resolvers.go @@ -79,6 +79,7 @@ func (r *queryResolver) Users(ctx context.Context, after *entgql.Cursor[int], fi return r.client.User.Query(). Paginate(ctx, after, first, before, last, ent.WithUserFilter(where.Filter), + ent.WithUserOrder(orderBy), ) } diff --git a/entgql/internal/todo/ent/gql_pagination.go b/entgql/internal/todo/ent/gql_pagination.go index 89d25786a..2d4f94032 100644 --- a/entgql/internal/todo/ent/gql_pagination.go +++ b/entgql/internal/todo/ent/gql_pagination.go @@ -2492,7 +2492,7 @@ func (p *userPager) applyOrder(query *UserQuery) *UserQuery { query = query.Order(DefaultUserOrder.Field.toTerm(direction.OrderTermOption())) } switch p.order.Field.column { - case UserOrderFieldGroupsCount.column: + case UserOrderFieldGroupsCount.column, UserOrderFieldFriendsCount.column: default: if len(query.ctx.Fields) > 0 { query.ctx.AppendFieldOnce(p.order.Field.column) @@ -2507,7 +2507,7 @@ func (p *userPager) orderExpr(query *UserQuery) sql.Querier { direction = direction.Reverse() } switch p.order.Field.column { - case UserOrderFieldGroupsCount.column: + case UserOrderFieldGroupsCount.column, UserOrderFieldFriendsCount.column: query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) default: if len(query.ctx.Fields) > 0 { @@ -2595,6 +2595,25 @@ var ( } }, } + // UserOrderFieldFriendsCount orders by FRIENDS_COUNT. + UserOrderFieldFriendsCount = &UserOrderField{ + Value: func(u *User) (ent.Value, error) { + return u.Value("friends_count") + }, + column: "friends_count", + toTerm: func(opts ...sql.OrderTermOption) user.OrderOption { + return user.ByFriendsCount( + append(opts, sql.OrderSelectAs("friends_count"))..., + ) + }, + toCursor: func(u *User) Cursor { + cv, _ := u.Value("friends_count") + return Cursor{ + ID: u.ID, + Value: cv, + } + }, + } ) // String implement fmt.Stringer interface. @@ -2603,6 +2622,8 @@ func (f UserOrderField) String() string { switch f.column { case UserOrderFieldGroupsCount.column: str = "GROUPS_COUNT" + case UserOrderFieldFriendsCount.column: + str = "FRIENDS_COUNT" } return str } @@ -2621,6 +2642,8 @@ func (f *UserOrderField) UnmarshalGQL(v interface{}) error { switch str { case "GROUPS_COUNT": *f = *UserOrderFieldGroupsCount + case "FRIENDS_COUNT": + *f = *UserOrderFieldFriendsCount default: return fmt.Errorf("%s is not a valid UserOrderField", str) } diff --git a/entgql/internal/todo/ent/schema/user.go b/entgql/internal/todo/ent/schema/user.go index c37be6637..149786f49 100644 --- a/entgql/internal/todo/ent/schema/user.go +++ b/entgql/internal/todo/ent/schema/user.go @@ -55,7 +55,10 @@ func (User) Edges() []ent.Edge { ), edge.To("friends", User.Type). Through("friendships", Friendship.Type). - Annotations(entgql.RelayConnection()), + Annotations( + entgql.RelayConnection(), + entgql.OrderField("FRIENDS_COUNT"), + ), } } diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 7e2136c7b..a70209838 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -1989,6 +1989,84 @@ func TestMutation_ClearChildren(t *testing.T) { require.False(t, root.QueryChildren().ExistX(ctx)) } +func TestQuery_SortUserByFriendshipsCount(t *testing.T) { + ec := enttest.Open(t, dialect.SQLite, + fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()), + enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)), + ) + srv := handler.NewDefaultServer(gen.NewSchema(ec)) + srv.Use(entgql.Transactioner{TxOpener: ec}) + gqlc := client.New(srv) + + ctx := context.Background() + user := ec.User.Create().SetRequiredMetadata(map[string]any{}).SaveX(ctx) + friend := ec.User.Create().SetRequiredMetadata(map[string]any{}).AddFriends(user).SaveX(ctx) + secondFriend := ec.User.Create().SetRequiredMetadata(map[string]any{}).AddFriends(user, friend).SaveX(ctx) + thirdFried := ec.User.Create().SetRequiredMetadata(map[string]any{}).AddFriends(user).SaveX(ctx) + + require.True(t, user.QueryFriends().ExistX(ctx)) + require.True(t, friend.QueryFriends().ExistX(ctx)) + require.True(t, secondFriend.QueryFriends().ExistX(ctx)) + require.True(t, thirdFried.QueryFriends().ExistX(ctx)) + + var rsp struct { + Users struct { + Edges []struct { + Node struct { + ID string + Friends struct { + TotalCount int + } + } + } + } + } + + query := ` + query testThroughEdgeOrderByQuery($orderDirection: OrderDirection!) { + users (orderBy:{ field: FRIENDS_COUNT, direction: $orderDirection }) { + edges { + node { + id + friends { + totalCount + } + } + } + } + } + ` + + testCases := []struct { + direction string + expectedCountOrder []int + }{ + { + direction: "DESC", + expectedCountOrder: []int{3, 2, 2, 1}, + }, + { + direction: "ASC", + expectedCountOrder: []int{1, 2, 2, 3}, + }, + } + + for _, tc := range testCases { + t.Run(tc.direction, func(t *testing.T) { + err := gqlc.Post( + query, + &rsp, + client.Var("orderDirection", tc.direction)) + + require.NoError(t, err) + require.Len(t, rsp.Users.Edges, 4) + for i, edge := range rsp.Users.Edges { + require.Equal(t, edge.Node.Friends.TotalCount, tc.expectedCountOrder[i]) + } + }) + } +} + func TestMutation_ClearFriend(t *testing.T) { ec := enttest.Open(t, dialect.SQLite, fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()), diff --git a/entgql/internal/todogotype/generated.go b/entgql/internal/todogotype/generated.go index 7b7091bad..a4ea604b1 100644 --- a/entgql/internal/todogotype/generated.go +++ b/entgql/internal/todogotype/generated.go @@ -2796,6 +2796,7 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } """ UserWhereInput is used for filtering User objects. @@ -20473,7 +20474,7 @@ func (ec *executionContext) marshalOInt2áš–int(ctx context.Context, sel ast.Sele return res } -func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[string]any, error) { +func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[string]interface{}, error) { if v == nil { return nil, nil } @@ -20481,7 +20482,7 @@ func (ec *executionContext) unmarshalOMap2map(ctx context.Context, v any) (map[s return res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) marshalOMap2map(ctx context.Context, sel ast.SelectionSet, v map[string]any) graphql.Marshaler { +func (ec *executionContext) marshalOMap2map(ctx context.Context, sel ast.SelectionSet, v map[string]interface{}) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/entgql/internal/todopulid/generated.go b/entgql/internal/todopulid/generated.go index 356ba9d63..df896199c 100644 --- a/entgql/internal/todopulid/generated.go +++ b/entgql/internal/todopulid/generated.go @@ -2802,6 +2802,7 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } """ UserWhereInput is used for filtering User objects. diff --git a/entgql/internal/todouuid/generated.go b/entgql/internal/todouuid/generated.go index 8ebf13494..217e08e40 100644 --- a/entgql/internal/todouuid/generated.go +++ b/entgql/internal/todouuid/generated.go @@ -2803,6 +2803,7 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } """ UserWhereInput is used for filtering User objects. diff --git a/entgql/template.go b/entgql/template.go index 69f20dc3d..8c04b6167 100644 --- a/entgql/template.go +++ b/entgql/template.go @@ -452,12 +452,16 @@ func orderFields(n *gen.Type) ([]*OrderTerm, error) { }) } } + edgeNamesWithThroughTables := make(map[string]interface{}) for _, e := range n.Edges { name := strings.ToUpper(e.Name) switch ant, err := annotation(e.Annotations); { case err != nil: return nil, err case ant.Skip.Is(SkipOrderField), ant.OrderField == "": + case strings.HasSuffix(ant.OrderField, "_COUNT") && + edgeNamesWithThroughTables[strings.TrimSuffix(ant.OrderField, "_COUNT")] != nil: + // skip the through table annotations, annotations are applied on the `edge.To` edge instead case ant.OrderField == fmt.Sprintf("%s_COUNT", name): // Validate that the edge has a count ordering. if _, err := e.OrderCountName(); err != nil { @@ -471,7 +475,7 @@ func orderFields(n *gen.Type) ([]*OrderTerm, error) { Count: true, }) case strings.HasPrefix(ant.OrderField, name+"_"): - // Validate that the edge has a edge field ordering. + // Validate that the edge has an edge field ordering. if _, err := e.OrderFieldName(); err != nil { return nil, fmt.Errorf("entgql: invalid order field %s defined on edge %s.%s: %w", ant.OrderField, n.Name, e.Name, err) } @@ -493,6 +497,9 @@ func orderFields(n *gen.Type) ([]*OrderTerm, error) { default: return nil, fmt.Errorf("entgql: invalid order field defined on edge %s.%s", n.Name, e.Name) } + if e.Through != nil { + edgeNamesWithThroughTables[name] = true + } } return terms, nil } diff --git a/entgql/testdata/schema.graphql b/entgql/testdata/schema.graphql index 289e579b3..182dade7e 100644 --- a/entgql/testdata/schema.graphql +++ b/entgql/testdata/schema.graphql @@ -313,4 +313,5 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } diff --git a/entgql/testdata/schema_relay.graphql b/entgql/testdata/schema_relay.graphql index 8f0d69b91..7b1f7fc18 100644 --- a/entgql/testdata/schema_relay.graphql +++ b/entgql/testdata/schema_relay.graphql @@ -1335,6 +1335,7 @@ Properties by which User connections can be ordered. """ enum UserOrderField { GROUPS_COUNT + FRIENDS_COUNT } """ UserWhereInput is used for filtering User objects.