diff --git a/recipe/session/getClaimValue_test.go b/recipe/session/getClaimValue_test.go index ba66f832..157c59ac 100644 --- a/recipe/session/getClaimValue_test.go +++ b/recipe/session/getClaimValue_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" "github.com/supertokens/supertokens-golang/supertokens" "github.com/supertokens/supertokens-golang/test/unittesting" diff --git a/recipe/session/recipeImplementation.go b/recipe/session/recipeImplementation.go index cc5f7134..a8db7e5e 100644 --- a/recipe/session/recipeImplementation.go +++ b/recipe/session/recipeImplementation.go @@ -25,6 +25,7 @@ import ( "time" "github.com/MicahParks/keyfunc/v2" + "github.com/supertokens/supertokens-golang/recipe/multitenancy/multitenancymodels" "github.com/supertokens/supertokens-golang/recipe/session/claims" "github.com/supertokens/supertokens-golang/recipe/session/errors" @@ -153,7 +154,6 @@ func GetCombinedJWKS() (*keyfunc.JWKS, error) { } func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) sessmodels.RecipeInterface { - var result sessmodels.RecipeInterface createNewSession := func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, tenantId string, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { supertokens.LogDebugMessage("createNewSession: Started") @@ -174,7 +174,13 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ frontToken := BuildFrontToken(sessionResponse.Session.UserID, sessionResponse.AccessToken.Expiry, parsedJWT.Payload) session := sessionResponse.Session - sessionContainerInput := makeSessionContainerInput(sessionResponse.AccessToken.Token, session.Handle, session.UserID, session.TenantId, parsedJWT.Payload, result, frontToken, sessionResponse.AntiCsrfToken, nil, &sessionResponse.RefreshToken, true) + + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return nil, err + } + + sessionContainerInput := makeSessionContainerInput(sessionResponse.AccessToken.Token, session.Handle, session.UserID, session.TenantId, parsedJWT.Payload, recipe.RecipeImpl, frontToken, sessionResponse.AntiCsrfToken, nil, &sessionResponse.RefreshToken, true) return newSessionContainer(config, &sessionContainerInput), nil } @@ -283,7 +289,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ frontToken := BuildFrontToken(response.Session.UserID, response.Session.ExpiryTime, payload) session := response.Session - sessionContainerInput := makeSessionContainerInput(accessTokenStringForSession, session.Handle, session.UserID, session.TenantId, payload, result, frontToken, antiCsrfToken, nil, nil, !accessTokenNil) + recipeInstance, err := getRecipeInstanceOrThrowError() + if err != nil { + return nil, err + } + + sessionContainerInput := makeSessionContainerInput(accessTokenStringForSession, session.Handle, session.UserID, session.TenantId, payload, recipeInstance.RecipeImpl, frontToken, antiCsrfToken, nil, nil, !accessTokenNil) sessionContainer := newSessionContainer(config, &sessionContainerInput) return sessionContainer, nil @@ -314,7 +325,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ session := response.Session frontToken := BuildFrontToken(session.UserID, response.AccessToken.Expiry, responseToken.Payload) - sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, session.TenantId, responseToken.Payload, result, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true) + recipeInstance, err := getRecipeInstanceOrThrowError() + if err != nil { + return nil, err + } + + sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, session.TenantId, responseToken.Payload, recipeInstance.RecipeImpl, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true) sessionContainer := newSessionContainer(config, &sessionContainerInput) return sessionContainer, nil @@ -345,7 +361,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ } mergeIntoAccessTokenPayload := func(sessionHandle string, accessTokenPayloadUpdate map[string]interface{}, userContext supertokens.UserContext) (bool, error) { - sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext) + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return false, err + } + + sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext) if err != nil { return false, err } @@ -433,7 +454,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ } fetchAndSetClaim := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (bool, error) { - sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext) + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return false, err + } + + sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext) if err != nil { return false, err } @@ -444,16 +470,24 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ if err != nil { return false, err } - return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) + return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) } setClaimValue := func(sessionHandle string, claim *claims.TypeSessionClaim, value interface{}, userContext supertokens.UserContext) (bool, error) { + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return false, err + } accessTokenPayloadUpdate := claim.AddToPayload_internal(map[string]interface{}{}, value, userContext) - return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) + return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) } getClaimValue := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (sessmodels.GetClaimValueResult, error) { - sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext) + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return sessmodels.GetClaimValueResult{}, err + } + sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext) if err != nil { return sessmodels.GetClaimValueResult{}, err } @@ -471,10 +505,16 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ } removeClaim := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (bool, error) { + recipe, err := getRecipeInstanceOrThrowError() + if err != nil { + return false, err + } + accessTokenPayloadUpdate := claim.RemoveFromPayloadByMerge_internal(map[string]interface{}{}, userContext) - return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) + return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext) } - result = sessmodels.RecipeInterface{ + + return sessmodels.RecipeInterface{ CreateNewSession: &createNewSession, GetSession: &getSession, RefreshSession: &refreshSession, @@ -496,5 +536,4 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ RemoveClaim: &removeClaim, } - return result } diff --git a/recipe/session/sessionRequestFunctions.go b/recipe/session/sessionRequestFunctions.go index a714ac7d..aae758d5 100644 --- a/recipe/session/sessionRequestFunctions.go +++ b/recipe/session/sessionRequestFunctions.go @@ -22,7 +22,6 @@ import ( "strconv" "github.com/supertokens/supertokens-golang/recipe/session/claims" - "github.com/supertokens/supertokens-golang/recipe/session/errors" "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" "github.com/supertokens/supertokens-golang/supertokens"