diff --git a/cmd/controller/controllers/volume_state_controller.go b/cmd/controller/controllers/volume_state_controller.go index 4c72f465..28d79d0b 100644 --- a/cmd/controller/controllers/volume_state_controller.go +++ b/cmd/controller/controllers/volume_state_controller.go @@ -65,11 +65,15 @@ func (c *VolumeStateController) fetchInitialStorageState(ctx context.Context, vo return nil } - instanceIDs := make([]string, len(nodes)) + instanceIDs := make([]string, 0, len(nodes)) instanceIDToNodeName := make(map[string]string, len(nodes)) - for i, node := range nodes { + for _, node := range nodes { instanceID := extractInstanceIDFromProviderID(node.Spec.ProviderID) - instanceIDs[i] = instanceID + if instanceID == "" { + c.log.WithField("provider_id", node.Spec.ProviderID).Warn("could not extract instance id from provider id") + continue + } + instanceIDs = append(instanceIDs, instanceID) instanceIDToNodeName[instanceID] = node.Name } @@ -114,11 +118,15 @@ func (c *VolumeStateController) runRefreshLoop(ctx context.Context, volumeIndex continue } - instanceIDs := make([]string, len(nodes)) + instanceIDs := make([]string, 0, len(nodes)) instanceIDToNodeName := make(map[string]string, len(nodes)) - for i, node := range nodes { + for _, node := range nodes { instanceID := extractInstanceIDFromProviderID(node.Spec.ProviderID) - instanceIDs[i] = instanceID + if instanceID == "" { + c.log.WithField("provider_id", node.Spec.ProviderID).Warn("could not extract instance id from provider id") + continue + } + instanceIDs = append(instanceIDs, instanceID) instanceIDToNodeName[instanceID] = node.Name } @@ -162,10 +170,10 @@ func extractInstanceIDFromProviderID(providerID string) string { } // GCP format: gce://project-id/zone/instance-name - if strings.HasPrefix(providerID, "gce://") { - parts := strings.Split(providerID, "/") - if len(parts) >= 4 { - return parts[len(parts)-1] + if instanceID, ok := strings.CutPrefix(providerID, "gce://"); ok { + parts := strings.Split(instanceID, "/") + if len(parts) == 3 { + return instanceID } } diff --git a/cmd/controller/controllers/volume_state_controller_test.go b/cmd/controller/controllers/volume_state_controller_test.go index c048650b..11fdf01b 100644 --- a/cmd/controller/controllers/volume_state_controller_test.go +++ b/cmd/controller/controllers/volume_state_controller_test.go @@ -87,7 +87,12 @@ func TestExtractInstanceIDFromProviderID(t *testing.T) { { name: "GCP format", providerID: "gce://my-project/us-central1-a/instance-name", - want: "instance-name", + want: "my-project/us-central1-a/instance-name", + }, + { + name: "GCP format with missing zone", + providerID: "gce://my-project/instance-name", + want: "", }, { name: "Azure format", diff --git a/pkg/cloudprovider/aws/provider.go b/pkg/cloudprovider/aws/provider.go index 6db7fdd2..da5fb8ff 100644 --- a/pkg/cloudprovider/aws/provider.go +++ b/pkg/cloudprovider/aws/provider.go @@ -17,10 +17,6 @@ type Provider struct { // AWS clients ec2Client *ec2.Client - // Cached storage state - storageStateMu sync.RWMutex - storageState *types.StorageState - // Cached network state networkStateMu sync.RWMutex networkState *types.NetworkState diff --git a/pkg/cloudprovider/aws/storage_state.go b/pkg/cloudprovider/aws/storage_state.go index d017d8dd..d0629c25 100644 --- a/pkg/cloudprovider/aws/storage_state.go +++ b/pkg/cloudprovider/aws/storage_state.go @@ -15,9 +15,8 @@ func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) ( p.log.Debug("refreshing storage state") state := &types.StorageState{ - Domain: "amazonaws.com", - Provider: types.TypeAWS, - InstanceVolumes: make(map[string][]types.Volume), + Domain: "amazonaws.com", + Provider: types.TypeAWS, } instanceVolumes, err := p.fetchInstanceVolumes(ctx, instanceIds...) @@ -26,11 +25,7 @@ func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) ( } state.InstanceVolumes = instanceVolumes - p.storageStateMu.Lock() - defer p.storageStateMu.Unlock() - p.storageState = state - - return p.storageState, nil + return state, nil } // fetchInstanceVolumes retrieves instance volumes from https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_Volume.html diff --git a/pkg/cloudprovider/gcp/provider.go b/pkg/cloudprovider/gcp/provider.go index 01e93157..2d304026 100644 --- a/pkg/cloudprovider/gcp/provider.go +++ b/pkg/cloudprovider/gcp/provider.go @@ -2,6 +2,7 @@ package gcp import ( "context" + "errors" "fmt" "sync" @@ -17,6 +18,7 @@ type Provider struct { // GCP clients networksClient *compute.NetworksClient subnetworksClient *compute.SubnetworksClient + disksClient *compute.DisksClient // Cached network state networkStateMu sync.RWMutex @@ -43,11 +45,17 @@ func NewProvider(ctx context.Context, cfg types.ProviderConfig) (types.Provider, return nil, fmt.Errorf("creating subnetworks client: %w", err) } + disksClient, err := compute.NewDisksRESTClient(ctx, clientOptions...) + if err != nil { + return nil, fmt.Errorf("creating disks client: %w", err) + } + p := &Provider{ log: log, cfg: cfg, networksClient: networksClient, subnetworksClient: subnetworksClient, + disksClient: disksClient, } log.With("project", cfg.GCPProjectID).Info("gcp provider initialized") @@ -71,10 +79,14 @@ func (p *Provider) Close() error { errs = append(errs, fmt.Errorf("closing subnetworks client: %w", err)) } } + if p.disksClient != nil { + if err := p.disksClient.Close(); err != nil { + errs = append(errs, fmt.Errorf("closing disks client: %w", err)) + } + } if len(errs) > 0 { - return fmt.Errorf("errors closing GCP provider: %v", errs) + return fmt.Errorf("errors closing GCP provider: %w", errors.Join(errs...)) } - return nil } diff --git a/pkg/cloudprovider/gcp/storage_state.go b/pkg/cloudprovider/gcp/storage_state.go index 83cbba2b..0ed4d70d 100644 --- a/pkg/cloudprovider/gcp/storage_state.go +++ b/pkg/cloudprovider/gcp/storage_state.go @@ -2,11 +2,140 @@ package gcp import ( "context" + "errors" "fmt" + "math" + "path" + "strings" + + "cloud.google.com/go/compute/apiv1/computepb" + "github.com/samber/lo" + "google.golang.org/api/iterator" "github.com/castai/kvisor/pkg/cloudprovider/types" ) func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) (*types.StorageState, error) { - return nil, fmt.Errorf("GetStorageState not yet implemented for GCP") + p.log.Debug("refreshing storage state") + + state := &types.StorageState{ + Domain: "googleapis.com", + Provider: types.TypeGCP, + } + + instanceVolumes, err := p.fetchInstanceVolumes(ctx, instanceIds...) + if err != nil { + return nil, fmt.Errorf("fetching volumes: %w", err) + } + state.InstanceVolumes = instanceVolumes + + return state, nil +} + +// fetchInstanceVolumes retrieves instance volumes from https://docs.cloud.google.com/compute/docs/reference/rest/v1/disks/aggregatedList +func (p *Provider) fetchInstanceVolumes(ctx context.Context, instanceIds ...string) (map[string][]types.Volume, error) { + instanceVolumes := make(map[string][]types.Volume, len(instanceIds)) + + if len(instanceIds) == 0 { + return instanceVolumes, nil + } + + instanceUrlsMap := make(map[string]string, len(instanceIds)) + for _, instanceId := range instanceIds { + url := buildInstanceUrlFromId(instanceId) + if url == "" { + p.log.WithField("instance_id", instanceId).Warn("could not build instance url") + continue + } + instanceUrlsMap[url] = instanceId + } + + filter := buildDisksUsedByInstanceFilter(lo.Keys(instanceUrlsMap)) + + req := &computepb.AggregatedListDisksRequest{ + Project: p.cfg.GCPProjectID, + Filter: &filter, + } + + it := p.disksClient.AggregatedList(ctx, req) + for result, err := range it.All() { + if errors.Is(err, iterator.Done) { + break + } + + if err != nil { + return instanceVolumes, fmt.Errorf("listing disks: %w", err) + } + + for _, disk := range result.Value.Disks { + if disk.GetName() == "" { + p.log.Error("disk missing name, skipping") + continue + } + + for _, instanceUrl := range disk.Users { + instanceId, ok := instanceUrlsMap[instanceUrl] + if !ok { + continue + } + + volume := types.Volume{ + VolumeID: disk.GetName(), + VolumeState: strings.ToLower(disk.GetStatus()), + Encrypted: true, // GCP disks are encrypted by default + } + + if disk.GetType() != "" { + volume.VolumeType = path.Base(disk.GetType()) + } + + if disk.GetZone() != "" { + volume.Zone = path.Base(disk.GetZone()) + } + + if disk.GetSizeGb() > 0 { + // Size is in GB, convert to bytes + volume.SizeBytes = disk.GetSizeGb() * 1024 * 1024 * 1024 + } + + if disk.GetProvisionedIops() > 0 { + volume.IOPS = safeInt64ToInt32(disk.GetProvisionedIops()) + } + + if disk.GetProvisionedThroughput() > 0 { + // Throughput is in MB/s, convert to bytes/s + volume.ThroughputBytes = safeInt64ToInt32(disk.GetProvisionedThroughput() * 1024 * 1024) + } + + instanceVolumes[instanceId] = append(instanceVolumes[instanceId], volume) + } + } + } + + return instanceVolumes, nil +} + +// buildInstanceUrlFromId converts an instance ID (project/zone/instance-name) to a full GCP instance URL +func buildInstanceUrlFromId(instanceId string) string { + parts := strings.Split(instanceId, "/") + if len(parts) != 3 { + return "" + } + return fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/zones/%s/instances/%s", parts[0], parts[1], parts[2]) +} + +// buildDisksUsedByInstanceFilter builds a GCP API filter for disks attached to specific instances +func buildDisksUsedByInstanceFilter(instanceUrls []string) string { + conditions := make([]string, len(instanceUrls)) + for i, url := range instanceUrls { + conditions[i] = fmt.Sprintf(`(users:%q)`, url) + } + return strings.Join(conditions, " OR ") +} + +func safeInt64ToInt32(val int64) int32 { + if val > math.MaxInt32 { + return math.MaxInt32 + } + return int32(val) // nolint:gosec } diff --git a/pkg/cloudprovider/gcp/test/env.example b/pkg/cloudprovider/gcp/test/env.example index 47f6c999..169d22f3 100644 --- a/pkg/cloudprovider/gcp/test/env.example +++ b/pkg/cloudprovider/gcp/test/env.example @@ -4,6 +4,10 @@ GCP_PROJECT_ID= # Required: Your network (VPC) name to test NETWORK_NAME= +# Required: GCP Instance ID to test volume listing +# Example: my-gcp-project/us-east4-a/my-gcp-pool-3556b234,my-gcp-project/us-east4-c/my-gcp-pool-a7579587 +GCP_INSTANCE_IDS= + # Optional: Path to service account key file # If not set, will use GOOGLE_APPLICATION_CREDENTIALS or default credentials GCP_CREDENTIALS_FILE= diff --git a/pkg/cloudprovider/gcp/test/integration_test.go b/pkg/cloudprovider/gcp/test/integration_test.go index 4823176d..a0cb5da4 100644 --- a/pkg/cloudprovider/gcp/test/integration_test.go +++ b/pkg/cloudprovider/gcp/test/integration_test.go @@ -5,6 +5,7 @@ package integration_test import ( "context" "os" + "strings" "testing" "github.com/joho/godotenv" @@ -84,3 +85,53 @@ func TestRefreshNetworkState(t *testing.T) { } // t.Logf(" Service Ranges: %+v", state.ServiceRanges) } + +// TestGetStorageState calls GetStorageState and prints the results. +func TestGetStorageState(t *testing.T) { + cfg := getTestConfig(t) + ctx := t.Context() + + provider, err := gcp.NewProvider(ctx, cfg) + if err != nil { + t.Fatalf("NewProvider failed: %v", err) + } + + p := provider.(*gcp.Provider) + + instanceIDsStr := os.Getenv("GCP_INSTANCE_IDS") + if instanceIDsStr == "" { + t.Fatal("GCP_INSTANCE_IDS not set") + } + + instanceIDs := strings.Split(instanceIDsStr, ",") + for i := range instanceIDs { + instanceIDs[i] = strings.TrimSpace(instanceIDs[i]) + } + + state, err := p.GetStorageState(ctx, instanceIDs...) + if err != nil { + t.Fatalf("GetStorageState failed: %v", err) + } + + for _, instanceID := range instanceIDs { + t.Logf("Testing instance: %s", instanceID) + + volumes, ok := state.InstanceVolumes[instanceID] + if !ok { + t.Fatalf("No volumes found for instance %s", instanceID) + } + + t.Logf("Found %d volumes attached to instance %s:", len(volumes), instanceID) + for _, v := range volumes { + t.Logf(" Volume:") + t.Logf(" VolumeID: %s", v.VolumeID) + t.Logf(" VolumeType: %s", v.VolumeType) + t.Logf(" VolumeState: %s", v.VolumeState) + t.Logf(" SizeBytes: %d", v.SizeBytes) + t.Logf(" Zone: %s", v.Zone) + t.Logf(" Encrypted: %v", v.Encrypted) + t.Logf(" IOPS: %d", v.IOPS) + t.Logf(" ThroughputBytes: %d B/s", v.ThroughputBytes) + } + } +}