33from sqlalchemy .orm import relationship
44from copy import deepcopy
55
6- from .._db import Session
6+ from .._db import Session , WriteSession , default_session
77from delphi .epidata .common .logger import get_structured_logger
88
99from typing import Set , Optional , List
@@ -25,7 +25,7 @@ def _default_date_now():
2525class User (Base ):
2626 __tablename__ = "api_user"
2727 id = Column (Integer , primary_key = True , autoincrement = True )
28- roles = relationship ("UserRole" , secondary = association_table )
28+ roles = relationship ("UserRole" , secondary = association_table , lazy = "joined" ) # last arg does an eager load of this property from foreign tables
2929 api_key = Column (String (50 ), unique = True , nullable = False )
3030 email = Column (String (320 ), unique = True , nullable = False )
3131 created = Column (Date , default = _default_date_now )
@@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None:
3535 self .api_key = api_key
3636 self .email = email
3737
38- @staticmethod
39- def list_users () -> List ["User" ]:
40- with Session () as session :
41- return session .query (User ).all ()
42-
4338 @property
4439 def as_dict (self ):
4540 return {
4641 "id" : self .id ,
4742 "api_key" : self .api_key ,
4843 "email" : self .email ,
49- "roles" : User . get_user_roles ( self .id ),
44+ "roles" : set ( role . name for role in self .roles ),
5045 "created" : self .created ,
5146 "last_time_used" : self .last_time_used
5247 }
5348
54- @staticmethod
55- def get_user_roles (user_id : int ) -> Set [str ]:
56- with Session () as session :
57- user = session .query (User ).filter (User .id == user_id ).first ()
58- return set ([role .name for role in user .roles ])
59-
6049 def has_role (self , required_role : str ) -> bool :
61- return required_role in User . get_user_roles ( self .id )
50+ return required_role in set ( role . name for role in self .roles )
6251
6352 @staticmethod
6453 def _assign_roles (user : "User" , roles : Optional [Set [str ]], session ) -> None :
65- # NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`...
66- # that is the responsibility of the caller!
6754 get_structured_logger ("api_user_models" ).info ("setting roles" , roles = roles , user_id = user .id , api_key = user .api_key )
6855 db_user = session .query (User ).filter (User .id == user .id ).first ()
6956 # TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`?
57+ # or even use this as a bound method instead of a static??
58+ # same goes for `update_user()` and `delete_user()` below...
7059 if roles :
71- roles_to_assign = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
72- db_user .roles = roles_to_assign
60+ db_user .roles = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
7361 else :
7462 db_user .roles = []
63+ session .commit ()
64+ # retrieve the newly updated User object
65+ return session .query (User ).filter (User .id == user .id ).first ()
7566
7667 @staticmethod
68+ @default_session (Session )
7769 def find_user (* , # asterisk forces explicit naming of all arguments when calling this method
78- user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
70+ session ,
71+ user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
7972 ) -> "User" :
8073 # NOTE: be careful, using multiple arguments could match multiple users, but this will return only one!
81- with Session () as session :
82- user = (
83- session .query (User )
84- .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
85- .first ()
86- )
74+ user = (
75+ session .query (User )
76+ .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
77+ .first ()
78+ )
8779 return user if user else None
8880
8981 @staticmethod
90- def create_user (api_key : str , email : str , user_roles : Optional [Set [str ]] = None ) -> "User" :
82+ @default_session (WriteSession )
83+ def create_user (api_key : str , email : str , session , user_roles : Optional [Set [str ]] = None ) -> "User" :
9184 get_structured_logger ("api_user_models" ).info ("creating user" , api_key = api_key )
92- with Session () as session :
93- new_user = User (api_key = api_key , email = email )
94- # TODO: we may need to populate 'created' field/column here, if the default
95- # specified above gets bound to the time of when that line of python was evaluated.
96- session .add (new_user )
97- session .commit ()
98- User ._assign_roles (new_user , user_roles , session )
99- session .commit ()
100- return new_user
85+ new_user = User (api_key = api_key , email = email )
86+ session .add (new_user )
87+ session .commit ()
88+ return User ._assign_roles (new_user , user_roles , session )
10189
10290 @staticmethod
91+ @default_session (WriteSession )
10392 def update_user (
10493 user : "User" ,
10594 email : Optional [str ],
10695 api_key : Optional [str ],
107- roles : Optional [Set [str ]]
96+ roles : Optional [Set [str ]],
97+ session
10898 ) -> "User" :
10999 get_structured_logger ("api_user_models" ).info ("updating user" , user_id = user .id , new_api_key = api_key )
110- with Session () as session :
111- user = User .find_user (user_id = user .id )
112- if user :
113- update_stmt = (
114- update (User )
115- .where (User .id == user .id )
116- .values (api_key = api_key , email = email )
117- )
118- session .execute (update_stmt )
119- User ._assign_roles (user , roles , session )
120- session .commit ()
121- return user
100+ user = User .find_user (user_id = user .id , session = session )
101+ if not user :
102+ raise Exception ('user not found' )
103+ update_stmt = (
104+ update (User )
105+ .where (User .id == user .id )
106+ .values (api_key = api_key , email = email )
107+ )
108+ session .execute (update_stmt )
109+ return User ._assign_roles (user , roles , session )
122110
123111 @staticmethod
124- def delete_user (user_id : int ) -> None :
112+ @default_session (WriteSession )
113+ def delete_user (user_id : int , session ) -> None :
125114 get_structured_logger ("api_user_models" ).info ("deleting user" , user_id = user_id )
126- with Session () as session :
127- session .execute (delete (User ).where (User .id == user_id ))
128- session .commit ()
115+ session .execute (delete (User ).where (User .id == user_id ))
116+ session .commit ()
129117
130118
131119class UserRole (Base ):
@@ -134,23 +122,23 @@ class UserRole(Base):
134122 name = Column (String (50 ), unique = True )
135123
136124 @staticmethod
137- def create_role (name : str ) -> None :
125+ @default_session (WriteSession )
126+ def create_role (name : str , session ) -> None :
138127 get_structured_logger ("api_user_models" ).info ("creating user role" , role = name )
139- with Session () as session :
140- session .execute (
141- f"""
128+ # TODO: check role doesnt already exist
129+ session .execute (f"""
142130 INSERT INTO user_role (name)
143131 SELECT '{ name } '
144132 WHERE NOT EXISTS
145133 (SELECT *
146134 FROM user_role
147135 WHERE name='{ name } ')
148- """
149- )
150- session . commit ()
136+ """ )
137+ session . commit ( )
138+ return session . query ( UserRole ). filter ( UserRole . name == name ). first ()
151139
152140 @staticmethod
153- def list_all_roles ():
154- with Session () as session :
155- roles = session .query (UserRole ).all ()
141+ @ default_session ( Session )
142+ def list_all_roles ( session ) :
143+ roles = session .query (UserRole ).all ()
156144 return [role .name for role in roles ]
0 commit comments