Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions cmd/controller/controllers/volume_state_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ func (c *VolumeStateController) fetchInitialStorageState(ctx context.Context, vo
return nil
}

instanceIDs := make([]string, len(nodes))
instanceIDs := make([]string, 0, len(nodes))
instanceIDToNodeName := make(map[string]string, len(nodes))
for i, node := range nodes {
for _, node := range nodes {
instanceID := extractInstanceIDFromProviderID(node.Spec.ProviderID)
instanceIDs[i] = instanceID
if instanceID == "" {
c.log.WithField("provider_id", node.Spec.ProviderID).Warn("could not extract instance id from provider id")
continue
}
instanceIDs = append(instanceIDs, instanceID)
instanceIDToNodeName[instanceID] = node.Name
}

Expand Down Expand Up @@ -114,11 +118,15 @@ func (c *VolumeStateController) runRefreshLoop(ctx context.Context, volumeIndex
continue
}

instanceIDs := make([]string, len(nodes))
instanceIDs := make([]string, 0, len(nodes))
instanceIDToNodeName := make(map[string]string, len(nodes))
for i, node := range nodes {
for _, node := range nodes {
instanceID := extractInstanceIDFromProviderID(node.Spec.ProviderID)
instanceIDs[i] = instanceID
if instanceID == "" {
c.log.WithField("provider_id", node.Spec.ProviderID).Warn("could not extract instance id from provider id")
continue
}
instanceIDs = append(instanceIDs, instanceID)
instanceIDToNodeName[instanceID] = node.Name
}

Expand Down Expand Up @@ -162,10 +170,10 @@ func extractInstanceIDFromProviderID(providerID string) string {
}

// GCP format: gce://project-id/zone/instance-name
if strings.HasPrefix(providerID, "gce://") {
parts := strings.Split(providerID, "/")
if len(parts) >= 4 {
return parts[len(parts)-1]
if instanceID, ok := strings.CutPrefix(providerID, "gce://"); ok {
parts := strings.Split(instanceID, "/")
if len(parts) == 3 {
return instanceID
}
}

Expand Down
7 changes: 6 additions & 1 deletion cmd/controller/controllers/volume_state_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ func TestExtractInstanceIDFromProviderID(t *testing.T) {
{
name: "GCP format",
providerID: "gce://my-project/us-central1-a/instance-name",
want: "instance-name",
want: "my-project/us-central1-a/instance-name",
},
{
name: "GCP format with missing zone",
providerID: "gce://my-project/instance-name",
want: "",
},
{
name: "Azure format",
Expand Down
4 changes: 0 additions & 4 deletions pkg/cloudprovider/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ type Provider struct {
// AWS clients
ec2Client *ec2.Client

// Cached storage state
storageStateMu sync.RWMutex
storageState *types.StorageState

// Cached network state
networkStateMu sync.RWMutex
networkState *types.NetworkState
Expand Down
11 changes: 3 additions & 8 deletions pkg/cloudprovider/aws/storage_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) (
p.log.Debug("refreshing storage state")

state := &types.StorageState{
Domain: "amazonaws.com",
Provider: types.TypeAWS,
InstanceVolumes: make(map[string][]types.Volume),
Domain: "amazonaws.com",
Provider: types.TypeAWS,
}

instanceVolumes, err := p.fetchInstanceVolumes(ctx, instanceIds...)
Expand All @@ -26,11 +25,7 @@ func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) (
}
state.InstanceVolumes = instanceVolumes

p.storageStateMu.Lock()
defer p.storageStateMu.Unlock()
p.storageState = state

return p.storageState, nil
return state, nil
}

// fetchInstanceVolumes retrieves instance volumes from https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_Volume.html
Expand Down
16 changes: 14 additions & 2 deletions pkg/cloudprovider/gcp/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gcp

import (
"context"
"errors"
"fmt"
"sync"

Expand All @@ -17,6 +18,7 @@ type Provider struct {
// GCP clients
networksClient *compute.NetworksClient
subnetworksClient *compute.SubnetworksClient
disksClient *compute.DisksClient

// Cached network state
networkStateMu sync.RWMutex
Expand All @@ -43,11 +45,17 @@ func NewProvider(ctx context.Context, cfg types.ProviderConfig) (types.Provider,
return nil, fmt.Errorf("creating subnetworks client: %w", err)
}

disksClient, err := compute.NewDisksRESTClient(ctx, clientOptions...)
if err != nil {
return nil, fmt.Errorf("creating disks client: %w", err)
}

p := &Provider{
log: log,
cfg: cfg,
networksClient: networksClient,
subnetworksClient: subnetworksClient,
disksClient: disksClient,
}

log.With("project", cfg.GCPProjectID).Info("gcp provider initialized")
Expand All @@ -71,10 +79,14 @@ func (p *Provider) Close() error {
errs = append(errs, fmt.Errorf("closing subnetworks client: %w", err))
}
}
if p.disksClient != nil {
if err := p.disksClient.Close(); err != nil {
errs = append(errs, fmt.Errorf("closing disks client: %w", err))
}
}

if len(errs) > 0 {
return fmt.Errorf("errors closing GCP provider: %v", errs)
return fmt.Errorf("errors closing GCP provider: %w", errors.Join(errs...))
}

return nil
}
131 changes: 130 additions & 1 deletion pkg/cloudprovider/gcp/storage_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,140 @@ package gcp

import (
"context"
"errors"
"fmt"
"math"
"path"
"strings"

"cloud.google.com/go/compute/apiv1/computepb"
"github.com/samber/lo"
"google.golang.org/api/iterator"

"github.com/castai/kvisor/pkg/cloudprovider/types"
)

// GetStorageState returns the storage state for the given GCP instances.
//
// The instance IDs are passed through to fetchInstanceVolumes, and the
// resulting per-instance volumes are wrapped in a freshly allocated
// StorageState tagged with the GCP provider type and the "googleapis.com"
// domain. If fetching the volumes fails, the error is wrapped and a nil state
// is returned.
func (p *Provider) GetStorageState(ctx context.Context, instanceIds ...string) (*types.StorageState, error) {
	p.log.Debug("refreshing storage state")

	instanceVolumes, err := p.fetchInstanceVolumes(ctx, instanceIds...)
	if err != nil {
		return nil, fmt.Errorf("fetching volumes: %w", err)
	}

	return &types.StorageState{
		Domain:          "googleapis.com",
		Provider:        types.TypeGCP,
		InstanceVolumes: instanceVolumes,
	}, nil
}

// fetchInstanceVolumes lists the disks attached to the given instances and
// groups them by instance ID.
//
// Disks are retrieved with the aggregated disks list API
// (https://docs.cloud.google.com/compute/docs/reference/rest/v1/disks/aggregatedList)
// using a filter on each disk's users. Instance IDs that cannot be converted
// into an instance URL are logged and skipped. When no usable instance ID
// remains, an empty map is returned without calling the API. On a listing
// error the returned map is nil.
func (p *Provider) fetchInstanceVolumes(ctx context.Context, instanceIds ...string) (map[string][]types.Volume, error) {
	instanceVolumes := make(map[string][]types.Volume, len(instanceIds))

	// Disks reference their users by full instance URL, so index the requested
	// instances by URL to map API results back to the caller's instance IDs.
	instanceIDsByURL := make(map[string]string, len(instanceIds))
	for _, instanceID := range instanceIds {
		url := buildInstanceUrlFromId(instanceID)
		if url == "" {
			p.log.WithField("instance_id", instanceID).Warn("could not build instance url")
			continue
		}
		instanceIDsByURL[url] = instanceID
	}

	// An empty filter would match every disk in the project, so skip the API
	// call entirely when there is nothing valid to look up.
	if len(instanceIDsByURL) == 0 {
		return instanceVolumes, nil
	}

	filter := buildDisksUsedByInstanceFilter(lo.Keys(instanceIDsByURL))

	req := &computepb.AggregatedListDisksRequest{
		Project: p.cfg.GCPProjectID,
		Filter:  &filter,
	}

	it := p.disksClient.AggregatedList(ctx, req)
	for result, err := range it.All() {
		if errors.Is(err, iterator.Done) {
			break
		}

		if err != nil {
			return nil, fmt.Errorf("listing disks: %w", err)
		}

		// The generated getters are nil-safe, unlike direct field access on a
		// scope entry that carries no disk list.
		for _, disk := range result.Value.GetDisks() {
			if disk.GetName() == "" {
				p.log.Error("disk missing name, skipping")
				continue
			}

			// The volume description does not depend on which instance uses
			// the disk, so build it once per disk.
			volume := types.Volume{
				VolumeID:    disk.GetName(),
				VolumeState: strings.ToLower(disk.GetStatus()),
				Encrypted:   true, // GCP disks are encrypted by default
			}

			// Type and zone are reported as resource URLs; keep the last
			// path segment only.
			if disk.GetType() != "" {
				volume.VolumeType = path.Base(disk.GetType())
			}

			if disk.GetZone() != "" {
				volume.Zone = path.Base(disk.GetZone())
			}

			if disk.GetSizeGb() > 0 {
				// Size is in GB, convert to bytes
				volume.SizeBytes = disk.GetSizeGb() * 1024 * 1024 * 1024
			}

			if disk.GetProvisionedIops() > 0 {
				volume.IOPS = safeInt64ToInt32(disk.GetProvisionedIops())
			}

			if disk.GetProvisionedThroughput() > 0 {
				// Throughput is in MB/s, convert to bytes/s
				volume.ThroughputBytes = safeInt64ToInt32(disk.GetProvisionedThroughput() * 1024 * 1024)
			}

			// A disk may be attached to several instances; record it under
			// every requested instance that uses it.
			for _, instanceURL := range disk.GetUsers() {
				instanceID, ok := instanceIDsByURL[instanceURL]
				if !ok {
					continue
				}
				instanceVolumes[instanceID] = append(instanceVolumes[instanceID], volume)
			}
		}
	}

	return instanceVolumes, nil
}

// buildInstanceUrlFromId converts an instance ID of the form
// "project/zone/instance-name" into a full GCP instance URL.
//
// It returns an empty string when the ID does not consist of exactly three
// non-empty, slash-separated segments, so that malformed IDs never produce a
// URL with a missing project, zone or instance name.
func buildInstanceUrlFromId(instanceId string) string {
	parts := strings.Split(instanceId, "/")
	if len(parts) != 3 {
		return ""
	}
	for _, part := range parts {
		if part == "" {
			return ""
		}
	}
	return fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/zones/%s/instances/%s", parts[0], parts[1], parts[2])
}

// buildDisksUsedByInstanceFilter renders a GCP API list filter that matches
// disks used by any of the given instance URLs. Each URL becomes a quoted
// `(users:"...")` clause and the clauses are combined with " OR ". An empty
// input yields an empty filter string.
func buildDisksUsedByInstanceFilter(instanceUrls []string) string {
	var filter strings.Builder
	for idx, instanceURL := range instanceUrls {
		if idx > 0 {
			filter.WriteString(" OR ")
		}
		fmt.Fprintf(&filter, `(users:%q)`, instanceURL)
	}
	return filter.String()
}

// safeInt64ToInt32 converts val to int32, saturating at the int32 bounds
// instead of wrapping around when val is out of range.
func safeInt64ToInt32(val int64) int32 {
	switch {
	case val > math.MaxInt32:
		return math.MaxInt32
	case val < math.MinInt32:
		return math.MinInt32
	}
	return int32(val) // nolint:gosec
}
4 changes: 4 additions & 0 deletions pkg/cloudprovider/gcp/test/env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ GCP_PROJECT_ID=
# Required: Your network (VPC) name to test
NETWORK_NAME=

# Required: Comma-separated GCP instance IDs (format: project/zone/instance-name) to test volume listing
# Example: my-gcp-project/us-east4-a/my-gcp-pool-3556b234,my-gcp-project/us-east4-c/my-gcp-pool-a7579587
GCP_INSTANCE_IDS=

# Optional: Path to service account key file
# If not set, will use GOOGLE_APPLICATION_CREDENTIALS or default credentials
GCP_CREDENTIALS_FILE=
Expand Down
51 changes: 51 additions & 0 deletions pkg/cloudprovider/gcp/test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package integration_test
import (
"context"
"os"
"strings"
"testing"

"github.com/joho/godotenv"
Expand Down Expand Up @@ -84,3 +85,53 @@ func TestRefreshNetworkState(t *testing.T) {
}
// t.Logf(" Service Ranges: %+v", state.ServiceRanges)
}

// TestGetStorageState calls GetStorageState for the instances listed in the
// GCP_INSTANCE_IDS environment variable (comma-separated) and logs the volumes
// reported for each of them. The test fails if the variable is unset, if the
// call errors, or if any listed instance has no entry in the result.
func TestGetStorageState(t *testing.T) {
	cfg := getTestConfig(t)
	ctx := t.Context()

	provider, err := gcp.NewProvider(ctx, cfg)
	if err != nil {
		t.Fatalf("NewProvider failed: %v", err)
	}

	// Use a checked assertion so an unexpected implementation fails the test
	// with a message instead of panicking.
	p, ok := provider.(*gcp.Provider)
	if !ok {
		t.Fatalf("NewProvider returned %T, want *gcp.Provider", provider)
	}
	// Release the underlying GCP clients once the test finishes.
	t.Cleanup(func() {
		if err := p.Close(); err != nil {
			t.Logf("closing provider: %v", err)
		}
	})

	instanceIDsStr := os.Getenv("GCP_INSTANCE_IDS")
	if instanceIDsStr == "" {
		t.Fatal("GCP_INSTANCE_IDS not set")
	}

	instanceIDs := strings.Split(instanceIDsStr, ",")
	for i := range instanceIDs {
		instanceIDs[i] = strings.TrimSpace(instanceIDs[i])
	}

	state, err := p.GetStorageState(ctx, instanceIDs...)
	if err != nil {
		t.Fatalf("GetStorageState failed: %v", err)
	}

	for _, instanceID := range instanceIDs {
		t.Logf("Testing instance: %s", instanceID)

		volumes, ok := state.InstanceVolumes[instanceID]
		if !ok {
			t.Fatalf("No volumes found for instance %s", instanceID)
		}

		t.Logf("Found %d volumes attached to instance %s:", len(volumes), instanceID)
		for _, v := range volumes {
			t.Logf("  Volume:")
			t.Logf("    VolumeID: %s", v.VolumeID)
			t.Logf("    VolumeType: %s", v.VolumeType)
			t.Logf("    VolumeState: %s", v.VolumeState)
			t.Logf("    SizeBytes: %d", v.SizeBytes)
			t.Logf("    Zone: %s", v.Zone)
			t.Logf("    Encrypted: %v", v.Encrypted)
			t.Logf("    IOPS: %d", v.IOPS)
			t.Logf("    ThroughputBytes: %d B/s", v.ThroughputBytes)
		}
	}
}