From 4388277446c2d23454a7c25147f50c1c9380c81f Mon Sep 17 00:00:00 2001 From: Diego Balseiro Date: Tue, 25 Nov 2025 19:54:00 -0500 Subject: [PATCH] Check for a valid token before asking for a new one --- cmd/client.go | 2 +- pkg/k8s/azure_keychain.go | 100 ++++++++++++++++++++++++++++++++++++++ pkg/k8s/keychains.go | 56 +++++++++++++++------ 3 files changed, 142 insertions(+), 16 deletions(-) create mode 100644 pkg/k8s/azure_keychain.go diff --git a/cmd/client.go b/cmd/client.go index 9c4fc81987..7eadb8673f 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -104,7 +104,7 @@ func newTransport(insecureSkipVerify bool) fnhttp.RoundTripCloser { func newCredentialsProvider(configPath string, t http.RoundTripper) oci.CredentialsProvider { additionalLoaders := append(k8s.GetOpenShiftDockerCredentialLoaders(), k8s.GetGoogleCredentialLoader()...) additionalLoaders = append(additionalLoaders, k8s.GetECRCredentialLoader()...) - additionalLoaders = append(additionalLoaders, k8s.GetACRCredentialLoader()...) + additionalLoaders = append(additionalLoaders, k8s.GetACRCredentialLoader(configPath)...) options := []creds.Opt{ creds.WithPromptForCredentials(prompt.NewPromptForCredentials(os.Stdin, os.Stdout, os.Stderr)), creds.WithPromptForCredentialStore(prompt.NewPromptForCredentialStore()), diff --git a/pkg/k8s/azure_keychain.go b/pkg/k8s/azure_keychain.go new file mode 100644 index 0000000000..aade1c9c42 --- /dev/null +++ b/pkg/k8s/azure_keychain.go @@ -0,0 +1,100 @@ +package k8s + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" +) + +type AzureToken struct { + Token string `json:"token"` // Azure access token + ExpiresOn int64 `json:"expires_on"` // The expiration time of the token in Unix format +} + +func GetAzureToken(configPath string, registry string) (string, error) { + // Retrieves an Azure token for the specified registry. + // If a valid cached token exists, it is returned. Otherwise, a new token is fetched and cached. + cachedToken, err := ReadAzureToken(configPath, registry) + if err != nil { + return "", fmt.Errorf("failed to read Azure token: %w", err) + } + + if cachedToken == "" { + azCredential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return "", fmt.Errorf("failed to create default Azure credentials: %v", err) + } + // Define the default scope for Azure token requests + defaultScope := "https://management.azure.com/.default" + aztoken, err := azCredential.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{defaultScope}}) + if err != nil { + return "", fmt.Errorf("failed to get Azure access token: %v", err) + } + SaveAzureToken(registry, aztoken) + return aztoken.Token, nil + } + + return cachedToken, nil +} + +func TokenFile(configPath string, registry string) string { + return filepath.Join(configPath, fmt.Sprintf("azure-token-%s.json", registry)) +} + +func ReadAzureToken(configPath string, registry string) (string, error) { + // Returns an empty string if the token does not exist or is expired. + tokenFile := TokenFile(configPath, registry) + if _, err := os.Stat(tokenFile); err != nil { + if os.IsNotExist(err) { + return "", nil // No token file found, return empty token + } + return "", fmt.Errorf("failed to stat Azure token file: %w", err) + } + + tokenBytes, err := os.ReadFile(tokenFile) + if err != nil { + return "", fmt.Errorf("failed to read Azure token file: %w", err) + } + + var token AzureToken + if err := json.Unmarshal(tokenBytes, &token); err != nil { + return "", fmt.Errorf("failed to unmarshal Azure token: %w", err) + } + + if time.Now().After(time.Unix(0, token.ExpiresOn)) { + // Token has expired, return empty token + return "", nil + } + + return token.Token, nil +} + +func SaveAzureToken(configPath string, registry string, token azcore.AccessToken) error { + // Saves an Azure token to the file system for caching. + tokenBytes, err := json.Marshal(AzureToken{ + Token: token.Token, + ExpiresOn: token.ExpiresOn.Unix(), + }) + if err != nil { + return fmt.Errorf("failed to marshal Azure token: %w", err) + } + + tokenFile := TokenFile(configPath, registry) + // Ensure the directory for the token file exists + if err := os.MkdirAll(filepath.Dir(tokenFile), 0755); err != nil { + return fmt.Errorf("failed to create directory for Azure token file: %w", err) + } + + if err := os.WriteFile(tokenFile, tokenBytes, 0600); err != nil { + return fmt.Errorf("failed to write Azure token file: %w", err) + } + + return nil +} diff --git a/pkg/k8s/keychains.go b/pkg/k8s/keychains.go index bc3487cbdd..9447dde748 100644 --- a/pkg/k8s/keychains.go +++ b/pkg/k8s/keychains.go @@ -1,16 +1,12 @@ package k8s import ( - "context" "fmt" "strings" "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/google" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "knative.dev/func/pkg/creds" "knative.dev/func/pkg/oci" ) @@ -49,29 +45,59 @@ func GetECRCredentialLoader() []creds.CredentialsCallback { return []creds.CredentialsCallback{} // TODO: Implement ECR credentials loader } -func GetACRCredentialLoader() []creds.CredentialsCallback { +func GetACRCredentialLoader(configPath string) []creds.CredentialsCallback { return []creds.CredentialsCallback{ func(registry string) (oci.Credentials, error) { if !strings.HasSuffix(registry, ".azurecr.io") { return oci.Credentials{}, nil } - // TODO: Save token somewhere and check expiration before asking for a new one - - azCredential, err := azidentity.NewDefaultAzureCredential(nil) - if err != nil { - return oci.Credentials{}, fmt.Errorf("Failed to create default Azure credentials: %v", err) - } - - token, err := azCredential.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}}) + token, err := GetAzureToken(configPath, registry) if err != nil { - return oci.Credentials{}, fmt.Errorf("Failed to get Azure access token: %v", err) + return oci.Credentials{}, fmt.Errorf("failed to get Azure access token: %v", err) } return oci.Credentials{ Username: "00000000-0000-0000-0000-000000000000", - Password: token.Token, + Password: token, }, nil + /* + if token == nil { + azCredential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return oci.Credentials{}, fmt.Errorf("failed to create default Azure credentials: %v", err) + } + + token, err = azCredential.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}}) + if err != nil { + return oci.Credentials{}, fmt.Errorf("failed to get Azure access token: %v", err) + } + + } + + // check if there is a valid token already generated for this target + err := config.CreatePaths() + if err != nil { + return oci.Credentials{}, fmt.Errorf("error checking for generated tokens: %v", err) + } + configPath := config.Dir() + if configPath == "" { + return oci.Credentials{}, fmt.Errorf("could not determine config path") + } + tokenFile := filepath.Join(configPath, "az_tokens", registry+".token") + + azCredential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return oci.Credentials{}, fmt.Errorf("failed to create default Azure credentials: %v", err) + } + + token, err := azCredential.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com/.default"}}) + if err != nil { + return oci.Credentials{}, fmt.Errorf("failed to get Azure access token: %v", err) + } + + // save token for future use + */ }, } }