@@ -854,8 +854,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
854854 }
855855
856856 var (
857- c * credentials.Credential
858- exists bool
857+ c * credentials.Credential
858+ resultCredential credentials.Credential
859+ exists bool
860+ refresh bool
859861 )
860862
861863 rm := runtimeWithLogger (callCtx , monitor , r .runtimeManager )
@@ -886,6 +888,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
886888 if ! exists || c .IsExpired () {
887889 // If the existing credential is expired, we need to provide it to the cred tool through the environment.
888890 if exists && c .IsExpired () {
891+ refresh = true
889892 credJSON , err := json .Marshal (c )
890893 if err != nil {
891894 return nil , fmt .Errorf ("failed to marshal credential: %w" , err )
@@ -916,39 +919,56 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
916919 continue
917920 }
918921
919- if err := json .Unmarshal ([]byte (* res .Result ), & c ); err != nil {
922+ if err := json .Unmarshal ([]byte (* res .Result ), & resultCredential ); err != nil {
920923 return nil , fmt .Errorf ("failed to unmarshal credential tool %s response: %w" , ref .Reference , err )
921924 }
922- c .ToolName = credName
923- c .Type = credentials .CredentialTypeTool
925+ resultCredential .ToolName = credName
926+ resultCredential .Type = credentials .CredentialTypeTool
927+
928+ if refresh {
929+ // If this is a credential refresh, we need to make sure we use the same context.
930+ resultCredential .Context = c .Context
931+ } else {
932+ // If it is a new credential, let the credential store determine the context.
933+ resultCredential .Context = ""
934+ }
924935
925936 isEmpty := true
926- for _ , v := range c .Env {
937+ for _ , v := range resultCredential .Env {
927938 if v != "" {
928939 isEmpty = false
929940 break
930941 }
931942 }
932943
933- if ! c .Ephemeral {
944+ if ! resultCredential .Ephemeral {
934945 // Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
935946 if (isGitHubTool (toolName ) && callCtx .Program .ToolSet [ref .ToolID ].Source .Repo != nil ) || credentialAlias != "" {
936947 if isEmpty {
937948 log .Warnf ("Not saving empty credential for tool %s" , toolName )
938- } else if err := r .credStore .Add (callCtx .Ctx , * c ); err != nil {
939- return nil , fmt .Errorf ("failed to add credential for tool %s: %w" , toolName , err )
949+ } else {
950+ if refresh {
951+ err = r .credStore .Refresh (callCtx .Ctx , resultCredential )
952+ } else {
953+ err = r .credStore .Add (callCtx .Ctx , resultCredential )
954+ }
955+ if err != nil {
956+ return nil , fmt .Errorf ("failed to save credential for tool %s: %w" , toolName , err )
957+ }
940958 }
941959 } else {
942960 log .Warnf ("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases." , toolName )
943961 }
944962 }
963+ } else {
964+ resultCredential = * c
945965 }
946966
947- if c .ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration .After (* c .ExpiresAt )) {
948- nearestExpiration = c .ExpiresAt
967+ if resultCredential .ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration .After (* resultCredential .ExpiresAt )) {
968+ nearestExpiration = resultCredential .ExpiresAt
949969 }
950970
951- for k , v := range c .Env {
971+ for k , v := range resultCredential .Env {
952972 env = append (env , fmt .Sprintf ("%s=%s" , k , v ))
953973 }
954974 }
0 commit comments