diff --git a/lib/passport/passport.serializer.ts b/lib/passport/passport.serializer.ts index 066095b5..e267fe91 100644 --- a/lib/passport/passport.serializer.ts +++ b/lib/passport/passport.serializer.ts @@ -1,16 +1,47 @@ +import { IncomingMessage } from 'http'; import * as passport from 'passport'; -export abstract class PassportSerializer { - abstract serializeUser(user: any, done: Function); - abstract deserializeUser(payload: any, done: Function); +export abstract class PassportSerializer< + UserType extends unknown = unknown, + PayloadType extends unknown = unknown, + RequestType extends IncomingMessage = IncomingMessage +> { + abstract serializeUser( + user: UserType, + req?: RequestType + ): Promise | PayloadType; + abstract deserializeUser( + payload: PayloadType, + req?: RequestType + ): Promise | UserType; constructor() { const passportInstance = this.getPassportInstance(); - passportInstance.serializeUser((user, done) => - this.serializeUser(user, done) + passportInstance.serializeUser( + async ( + req: RequestType, + user: UserType, + done: (err: unknown, payload?: PayloadType) => unknown + ) => { + try { + done(null, await this.serializeUser(user, req)); + } catch (err) { + done(err); + } + } ); - passportInstance.deserializeUser((payload, done) => - this.deserializeUser(payload, done) + passportInstance.deserializeUser( + async ( + req: RequestType, + payload: PayloadType, + done: (err: unknown, user?: UserType) => unknown + ) => { + try { + done(null, await this.deserializeUser(payload, req)); + } catch (err) { + done(err); + } + } ); }