Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recipe/session/getClaimValue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
"github.com/supertokens/supertokens-golang/supertokens"
"github.com/supertokens/supertokens-golang/test/unittesting"
Expand Down
63 changes: 51 additions & 12 deletions recipe/session/recipeImplementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/MicahParks/keyfunc/v2"

"github.com/supertokens/supertokens-golang/recipe/multitenancy/multitenancymodels"
"github.com/supertokens/supertokens-golang/recipe/session/claims"
"github.com/supertokens/supertokens-golang/recipe/session/errors"
Expand Down Expand Up @@ -153,7 +154,6 @@ func GetCombinedJWKS() (*keyfunc.JWKS, error) {
}

func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) sessmodels.RecipeInterface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the best way to solve this problem is to convert the return type and var result to *sessmodels.RecipeInterface

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that would resolve it, cause the Functions override function is by value:

type OverrideStruct struct {
	Functions     func(originalImplementation RecipeInterface) RecipeInterface
	APIs          func(originalImplementation APIInterface) APIInterface
	OpenIdFeature *openidmodels.OverrideStruct
}

var result sessmodels.RecipeInterface

createNewSession := func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, tenantId string, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
supertokens.LogDebugMessage("createNewSession: Started")
Expand All @@ -174,7 +174,13 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ

frontToken := BuildFrontToken(sessionResponse.Session.UserID, sessionResponse.AccessToken.Expiry, parsedJWT.Payload)
session := sessionResponse.Session
sessionContainerInput := makeSessionContainerInput(sessionResponse.AccessToken.Token, session.Handle, session.UserID, session.TenantId, parsedJWT.Payload, result, frontToken, sessionResponse.AntiCsrfToken, nil, &sessionResponse.RefreshToken, true)

recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return nil, err
}

sessionContainerInput := makeSessionContainerInput(sessionResponse.AccessToken.Token, session.Handle, session.UserID, session.TenantId, parsedJWT.Payload, recipe.RecipeImpl, frontToken, sessionResponse.AntiCsrfToken, nil, &sessionResponse.RefreshToken, true)
return newSessionContainer(config, &sessionContainerInput), nil
}

Expand Down Expand Up @@ -283,7 +289,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
frontToken := BuildFrontToken(response.Session.UserID, response.Session.ExpiryTime, payload)
session := response.Session

sessionContainerInput := makeSessionContainerInput(accessTokenStringForSession, session.Handle, session.UserID, session.TenantId, payload, result, frontToken, antiCsrfToken, nil, nil, !accessTokenNil)
recipeInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
return nil, err
}

sessionContainerInput := makeSessionContainerInput(accessTokenStringForSession, session.Handle, session.UserID, session.TenantId, payload, recipeInstance.RecipeImpl, frontToken, antiCsrfToken, nil, nil, !accessTokenNil)
sessionContainer := newSessionContainer(config, &sessionContainerInput)

return sessionContainer, nil
Expand Down Expand Up @@ -314,7 +325,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
session := response.Session
frontToken := BuildFrontToken(session.UserID, response.AccessToken.Expiry, responseToken.Payload)

sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, session.TenantId, responseToken.Payload, result, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true)
recipeInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
return nil, err
}

sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, session.TenantId, responseToken.Payload, recipeInstance.RecipeImpl, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true)
sessionContainer := newSessionContainer(config, &sessionContainerInput)

return sessionContainer, nil
Expand Down Expand Up @@ -345,7 +361,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
}

mergeIntoAccessTokenPayload := func(sessionHandle string, accessTokenPayloadUpdate map[string]interface{}, userContext supertokens.UserContext) (bool, error) {
sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext)
recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return false, err
}

sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -433,7 +454,12 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
}

fetchAndSetClaim := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (bool, error) {
sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext)
recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return false, err
}

sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext)
if err != nil {
return false, err
}
Expand All @@ -444,16 +470,24 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
if err != nil {
return false, err
}
return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
}

setClaimValue := func(sessionHandle string, claim *claims.TypeSessionClaim, value interface{}, userContext supertokens.UserContext) (bool, error) {
recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return false, err
}
accessTokenPayloadUpdate := claim.AddToPayload_internal(map[string]interface{}{}, value, userContext)
return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
}

getClaimValue := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (sessmodels.GetClaimValueResult, error) {
sessionInfo, err := (*result.GetSessionInformation)(sessionHandle, userContext)
recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return sessmodels.GetClaimValueResult{}, err
}
sessionInfo, err := (*recipe.RecipeImpl.GetSessionInformation)(sessionHandle, userContext)
if err != nil {
return sessmodels.GetClaimValueResult{}, err
}
Expand All @@ -471,10 +505,16 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
}

removeClaim := func(sessionHandle string, claim *claims.TypeSessionClaim, userContext supertokens.UserContext) (bool, error) {
recipe, err := getRecipeInstanceOrThrowError()
if err != nil {
return false, err
}

accessTokenPayloadUpdate := claim.RemoveFromPayloadByMerge_internal(map[string]interface{}{}, userContext)
return (*result.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
return (*recipe.RecipeImpl.MergeIntoAccessTokenPayload)(sessionHandle, accessTokenPayloadUpdate, userContext)
}
result = sessmodels.RecipeInterface{

return sessmodels.RecipeInterface{
CreateNewSession: &createNewSession,
GetSession: &getSession,
RefreshSession: &refreshSession,
Expand All @@ -496,5 +536,4 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
RemoveClaim: &removeClaim,
}

return result
}
1 change: 0 additions & 1 deletion recipe/session/sessionRequestFunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"strconv"

"github.com/supertokens/supertokens-golang/recipe/session/claims"

"github.com/supertokens/supertokens-golang/recipe/session/errors"
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
"github.com/supertokens/supertokens-golang/supertokens"
Expand Down