diff --git a/gino/loader.py b/gino/loader.py index 9cec1114..b565c700 100644 --- a/gino/loader.py +++ b/gino/loader.py @@ -2,6 +2,7 @@ from sqlalchemy import select from sqlalchemy.schema import Column +from sqlalchemy.sql.elements import Label from .declarative import Model @@ -19,6 +20,8 @@ def get(cls, value): rv = AliasLoader(value) elif isinstance(value, Column): rv = ColumnLoader(value) + elif isinstance(value, Label): + rv = ColumnLoader(value.name) elif isinstance(value, tuple): rv = TupleLoader(value) elif callable(value): diff --git a/tests/test_loader.py b/tests/test_loader.py index 3d001ce1..30b6b2ae 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,8 +1,10 @@ import random from datetime import datetime -import pytest from async_generator import yield_, async_generator +import pytest +from sqlalchemy import select +from sqlalchemy.sql.functions import count from gino.loader import AliasLoader from .models import db, User, Team, Company @@ -12,11 +14,11 @@ @pytest.fixture @async_generator -async def user(bind, random_name): +async def user(bind): c = await Company.create() t1 = await Team.create(company_id=c.id) t2 = await Team.create(company_id=c.id, parent_id=t1.id) - u = await User.create(nickname=random_name, team_id=t2.id) + u = await User.create(team_id=t2.id) u.team = t2 t2.parent = t1 t2.company = c @@ -161,6 +163,47 @@ async def test_alias_loader_columns(user): assert u.id is not None +async def test_multiple_models_in_one_query(bind): + for _ in range(3): + await User.create() + + ua1 = User.alias() + ua2 = User.alias() + join_query = select([ua1, ua2]).where(ua1.id < ua2.id) + result = await join_query.gino.load((ua1.load('id'), ua2.load('id'))).all() + assert len(result) == 3 + for u1, u2 in result: + assert u1.id is not None + assert u2.id is not None + assert u1.id < u2.id + + +async def test_loader_with_aggregation(user): + count_col = count().label('count') + user_count = select( + [User.team_id, count_col] + ).group_by( + User.team_id + ).alias() + query = Team.outerjoin(user_count).select() + result = await query.gino.load( + (Team.id, Team.name, user_count.columns.team_id, count_col) + ).all() + assert len(result) == 2 + # team 1 doesn't have users, team 2 has 1 user + # third and forth columns are None for team 1 + for team_id, team_name, user_team_id, user_count in result: + if team_id == user.team_id: + assert team_name == user.team.name + assert user_team_id == user.team_id + assert user_count == 1 + else: + assert team_id is not None + assert team_name is not None + assert user_team_id is None + assert user_count is None + + async def test_adjacency_list_query_builder(user): group = Team.alias() u = await User.load(team=Team.load(parent=group.on(