diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 76d9f0b5..40e1d705 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -77,7 +77,7 @@ jobs: build-args: | GO_LDFLAGS=-X 'github.com/NexusGPU/tensor-fusion/internal/version.BuildVersion=${{ needs.release.outputs.version }}' - publish_node_discovery_image: + publish_hypervisor_image: needs: - release if: needs.release.outputs.published == 'true' || github.event_name == 'workflow_dispatch' @@ -95,7 +95,7 @@ jobs: - id: meta uses: docker/metadata-action@v5 with: - images: tensorfusion/tensor-fusion-node-discovery + images: tensorfusion/tensor-fusion-hypervisor tags: ${{ github.event_name == 'workflow_dispatch' && steps.set_tag.outputs.tag || format('type=semver,pattern={{{{version}}}},value={0}', needs.release.outputs.version) }} - name: Login to DockerHub @@ -104,12 +104,14 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - name: Build and push node discovery + - name: Build and push hypervisor uses: docker/build-push-action@v6 with: context: . push: true - file: dockerfile/node-discovery.Dockerfile + file: dockerfile/hypervisor.Dockerfile tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} no-cache: true + build-args: | + GO_LDFLAGS=-X 'github.com/NexusGPU/tensor-fusion/internal/version.BuildVersion=${{ needs.release.outputs.version }}' diff --git a/.gitignore b/.gitignore index fc148c71..54b8a74d 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,13 @@ __debug* vendor logs -*.prof \ No newline at end of file +*.prof + +provider/build + +cmd/hypervisor/hypervisor +*.o + +_obj + +metrics.log \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 954d1d19..0c9c7fa9 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -21,15 +21,18 @@ ] }, { - "name": "Debug Discovery", + "name": "Debug Hypervisor", "type": "go", "request": "launch", "mode": "auto", + "console": "integratedTerminal", "env": { - "HOSTNAME": "mocknode", - "KUBECONFIG": "~/.kube/config", + "KUBECONFIG": "~/.kube/config-local-studio", + "HYPERVISOR_PORT": "8042", + "GPU_NODE_NAME": "ubuntu", }, - "program": "${workspaceFolder}/cmd/nodediscovery/main.go", + "cwd": "${workspaceFolder}", + "program": "${workspaceFolder}/cmd/hypervisor/main.go", }, { "name": "Debug Dev Env Operator", @@ -62,7 +65,8 @@ "ENABLE_WEBHOOKS": "false", "ENABLE_SCHEDULER": "true", "ENABLE_CR_CONTROLLER": "true", - "NVIDIA_OPERATOR_PROGRESSIVE_MIGRATION": "true" + "NVIDIA_OPERATOR_PROGRESSIVE_MIGRATION": "true", + "IMPERSONATE_SERVICE_ACCOUNT": "system:serviceaccount:tensor-fusion-sys:tensor-fusion-sys" }, "args": [ "--metrics-path", "${workspaceFolder}/logs/metrics.log", diff --git a/.vscode/settings.json b/.vscode/settings.json index 5be70139..84f7e43a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,14 +9,20 @@ "AMDCDNA", "AMDRDNA", "apierrors", + "apiextensions", "apimachinery", "apimachineryruntime", "apiruntime", + "apiserver", "apiutil", + "appsv", "automount", "AWSGPU", "batchv", "Biren", + "bubbletea", + "BUILDPLATFORM", + "buildx", "burstable", "Cambricon", "CDNA", @@ -24,6 +30,8 @@ "certgen", "certificaterequests", "certmanager", + "CFLAGS", + "charmbracelet", "clientcmd", "clientcmdapi", "clientgoscheme", @@ -45,27 +53,36 @@ "datanode", "deepcopy", "defaultbinder", + "deviceplugin", "dylib", "eastus", "envtest", "essd", "Eventf", + "eventhandlers", "evictable", "featuregate", "finalizer", "Finalizers", 
"frameworkruntime", + "fsnotify", "FULLTEXT", + "GOARCH", + "GOBIN", "goconst", "gocyclo", "goerrors", + "golangci", "golint", "Gomega", "gonic", + "GOPATH", "gopsutil", "gorm", "gosec", + "GPGPU", "gpuallocator", + "GPUIDs", "gpunode", "gpunodeclaim", "gpunodeclaims", @@ -86,8 +103,11 @@ "imageutils", "indexallocator", "influxdata", + "Infof", "internalcache", "internalqueue", + "intstr", + "IVSHMEM", "jsonpatch", "karpenter", "karpv", @@ -99,9 +119,14 @@ "kubescheduler", "kubeschedulerconfig", "kustomization", + "LDFLAGS", + "libaccelerator", "libcuda", "libnvidia", "lineprotocol", + "lipgloss", + "LOCALBIN", + "logr", "mapstructure", "metav", "metricsserver", @@ -113,26 +138,36 @@ "nindent", "nodeclaim", "nodeclassref", + "nodelist", "noderesources", "nolint", "NUMA", + "nvdp", "Nvlink", "NVML", "objs", "omitempty", "onsi", + "pids", + "pluginapi", + "podname", "portallocator", "Postable", + "posthog", + "pprof", "printcolumn", "prometheusagents", "prometheuses", "prometheusrules", + "Ptrs", "queuesort", + "Radeon", "RDNA", "readyz", "replicaset", "replicasets", "rolebinding", + "RTXA", "runbook", "runpod", "samber", @@ -145,12 +180,19 @@ "schedv", "serviceaccount", "shirou", + "shmem", "shortuuid", + "sqlmock", "statefulset", "statefulsets", + "stdbool", + "stddef", + "stdint", + "stdlib", "strategicpatch", "strategicpatches", "stretchr", + "strncpy", "subresource", "Tabler", "tensorfusion", @@ -165,6 +207,8 @@ "testutil", "tflops", "timberio", + "Timeslicing", + "tmpfs", "Tmpl", "tokenreviews", "Tolerations", @@ -173,9 +217,16 @@ "utilerrors", "utilruntime", "vgpu", + "Warningf", "webhookcorev", + "workerstate", "workloadprofiles", "workqueue", - "Xlarge" - ] + "Xlarge", + "zapr" + ], + "files.associations": { + "__locale": "cpp", + "bitset": "cpp" + } } \ No newline at end of file diff --git a/Makefile b/Makefile index 87317a95..db0b7056 100644 --- a/Makefile +++ b/Makefile @@ -110,6 +110,26 @@ build: manifests generate fmt vet ## Build manager binary. run: manifests generate fmt vet ## Run a controller from your host. go run ./cmd/main.go +.PHONY: build-provider +build-provider: ## Build accelerator stub library. + $(MAKE) -C provider stub + +.PHONY: build-hypervisor +build-hypervisor: build-provider ## Build hypervisor binary with CGO enabled. + @PROVIDER_DIR=$$(pwd)/provider; \ + CGO_ENABLED=1 \ + CGO_CFLAGS="-I$$PROVIDER_DIR" \ + go build -o bin/hypervisor ./cmd/hypervisor + +.PHONY: build-hypervisor-tui +build-hypervisor-tui: + go build -o bin/hypervisor-tui ./cmd/hypervisor-tui + + +.PHONY: clean-cache +clean-cache: ## Clean Go build cache. + go clean -cache -testcache + # If you wish to build the manager image targeting other platforms you can use the --platform flag. # (i.e. docker build --platform linux/arm64). However, you must enable docker buildKit for it. 
 # More info: https://docs.docker.com/develop/develop-images/build_enhancements/
diff --git a/README.md b/README.md
index b327fa0e..346eea2a 100644
--- a/README.md
+++ b/README.md
@@ -57,30 +57,34 @@ Tensor Fusion is a state-of-the-art **GPU virtualization and pooling solution**
 
 - [x] Fractional GPU and flexible oversubscription
 - [x] Remote GPU sharing with SOTA GPU-over-IP technology, less than 4% performance loss
-- [x] GPU VRAM expansion and hot/warm/cold tiering
-- [ ] None NVIDIA GPU/NPU vendor support
+- [x] GPU VRAM expansion and hot/cold tiering
+- [x] Non-NVIDIA GPU/NPU vendor support
 
 ### Pooling & Scheduling & Management
 
 - [x] GPU/NPU pool management in Kubernetes
-- [x] GPU-first scheduling and allocation, with single TFlops/MB precision
-- [x] GPU node auto provisioning/termination
+- [x] GPU-first scheduling and allocation, with 1 TFLOPs, 1% Computing, 1 MB precision
+- [x] GPU node auto provisioning/termination, Karpenter integration
 - [x] GPU compaction/bin-packing
+- [x] Take full control of GPU allocation with precision targeting by vendor, model, device index, and more
 - [x] Seamless onboarding experience for Pytorch, TensorFlow, llama.cpp, vLLM, Tensor-RT, SGlang and all popular AI training/serving frameworks
+- [x] Seamless migration from existing NVIDIA operator and device-plugin stack
 - [x] Centralized Dashboard & Control Plane
 - [x] GPU-first autoscaling policies, auto set requests/limits/replicas
 - [x] Request multiple vGPUs with group scheduling for large models
 - [x] Support different QoS levels
+- [x] Hardware partitioned mode isolation like NVIDIA Dynamic MIG
+- [x] Support Kubernetes dynamic resource allocation (DRA) API
 
 ### Enterprise Features
 
 - [x] GPU live-migration, snapshot and restore GPU context cross cluster
 - [ ] AI model registry and preloading, build your own private MaaS(Model-as-a-Service)
-- [ ] Advanced auto-scaling policies, scale to zero, rebalance of hot GPUs
+- [x] Advanced auto-scaling policies, scale to zero, rebalance of hot GPUs
 - [ ] Advanced observability features, detailed metrics & tracing/profiling of CUDA calls
-- [ ] Monetize your GPU cluster by multi-tenancy usage measurement & billing report
-- [ ] Enterprise level high availability and resilience, support topology aware scheduling, GPU node auto failover etc.
-- [ ] Enterprise level security, complete on-premise deployment support
+- [x] Monetize your GPU cluster by multi-tenancy usage measurement & billing report
+- [x] Enterprise level high availability and resilience, support topology aware scheduling, GPU node auto failover etc.
+- [x] Enterprise level security, complete on-premise deployment support - [ ] Enterprise level compliance, SSO/SAML support, advanced audit, ReBAC control, SOC2 and other compliance reports available ### 🗳️ Platform Support diff --git a/api/v1/gpu_types.go b/api/v1/gpu_types.go index d59b747c..cea56e35 100644 --- a/api/v1/gpu_types.go +++ b/api/v1/gpu_types.go @@ -30,14 +30,15 @@ type GPUStatus struct { // +kubebuilder:default="NVIDIA" Vendor string `json:"vendor"` - // +optional - Model string `json:"model,omitempty"` - Capacity *Resource `json:"capacity"` Available *Resource `json:"available"` UUID string `json:"uuid"` + // +optional + // +kubebuilder:default=soft + IsolationMode IsolationModeType `json:"isolationMode,omitempty"` + // +optional Index *int32 `json:"index,omitempty"` @@ -61,15 +62,23 @@ type GPUStatus struct { // +optional RunningApps []*RunningAppDetail `json:"runningApps,omitempty"` + + // +optional + // PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + // Reported from discovery, each template has fixed resource allocation + PartitionTemplates []PartitionTemplate `json:"partitionTemplates,omitempty"` + + // +optional + // AllocatedPartitions tracks allocated partitions on this GPU + // Key is partitionUUID, value contains template info and allocated resources + AllocatedPartitions map[string]AllocatedPartition `json:"allocatedPartitions,omitempty"` } -// +kubebuilder:validation:Enum=tensor-fusion;nvidia-device-plugin // +default="tensor-fusion" type UsedBySystem string -const ( - UsedByTensorFusion UsedBySystem = "tensor-fusion" - UsedByNvidiaDevicePlugin UsedBySystem = "nvidia-device-plugin" +var ( + UsedByTensorFusion UsedBySystem = UsedBySystem(constants.Domain) ) type RunningAppDetail struct { @@ -94,6 +103,44 @@ type PodGPUInfo struct { QoS QoSLevel `json:"qos,omitempty"` } +// PartitionTemplate represents a hardware partition template (e.g., MIG profile) +// Only stores template ID and name in GPU status. Detailed resource information +// is stored in public GPU info config. 
+type PartitionTemplate struct {
+	// TemplateID is the unique identifier for this partition template (e.g., "1g.24gb", "4g.94gb")
+	TemplateID string `json:"templateId"`
+
+	// Name is a human-readable name for this template
+	Name string `json:"name"`
+}
+
+// AllocatedPartition represents an allocated partition on a GPU
+// Key in AllocatedPartitions map is podUID
+type AllocatedPartition struct {
+	// TemplateID is the template used to create this partition
+	TemplateID string `json:"templateId"`
+
+	// PodUID is the UID of the pod using this partition (used as map key)
+	PodUID string `json:"podUid"`
+
+	// PodName is the name of the pod using this partition
+	PodName string `json:"podName"`
+
+	// Namespace is the namespace of the pod using this partition
+	Namespace string `json:"namespace"`
+
+	// AllocatedAt is when this partition was allocated
+	AllocatedAt metav1.Time `json:"allocatedAt"`
+
+	// AllocatedSlotStart is the starting slot position where this partition is allocated
+	// This is the actual hardware slot position (0-based index)
+	AllocatedSlotStart *uint32 `json:"allocatedSlotStart,omitempty"`
+
+	// AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated
+	// The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd)
+	AllocatedSlotEnd *uint32 `json:"allocatedSlotEnd,omitempty"`
+}
+
 // +kubebuilder:validation:Enum=Pending;Provisioning;Running;Unknown;Destroying;Migrating
 type TensorFusionGPUPhase string
diff --git a/api/v1/gpupool_types.go b/api/v1/gpupool_types.go
index 78fe7e84..5d3cf8a2 100644
--- a/api/v1/gpupool_types.go
+++ b/api/v1/gpupool_types.go
@@ -33,6 +33,10 @@ type GPUPoolSpec struct {
 	// +optional
 	DefaultUsingLocalGPU *bool `json:"defaultUsingLocalGPU,omitempty"`
 
+	// +optional
+	// +kubebuilder:default=NVIDIA
+	Vendor string `json:"vendor,omitempty"`
+
 	CapacityConfig *CapacityConfig `json:"capacityConfig,omitempty"`
 
 	NodeManagerConfig *NodeManagerConfig `json:"nodeManagerConfig,omitempty"`
@@ -88,12 +92,23 @@ type NodeManagerConfig struct {
 	// +kubebuilder:default="AutoSelect"
 	ProvisioningMode ProvisioningMode `json:"provisioningMode,omitempty"`
 
+	// +optional
+	// +kubebuilder:default=NVIDIA
+	// In single AI accelerator hardware vendor mode, when default vendor set
+	// All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label
+	DefaultVendor string `json:"defaultVendor,omitempty"`
+
 	// +optional
 	NodeProvisioner *NodeProvisioner `json:"nodeProvisioner,omitempty"`
 
 	// +optional
 	NodeSelector *corev1.NodeSelector `json:"nodeSelector,omitempty"`
 
+	// +optional
+	// When this field set, the GPU pool will be in multi AI accelerator vendor mode
+	// each GPU node's vendor name is set to map key, e.g.
{ AMD: { nodeSelectorTerms }} + MultiVendorNodeSelector map[string]*corev1.NodeSelector `json:"multiVendorNodeSelector,omitempty"` + // +optional NodeCompaction *NodeCompaction `json:"nodeCompaction,omitempty"` diff --git a/api/v1/gpuresourcequota_types.go b/api/v1/gpuresourcequota_types.go index e5ba09b8..322bc9c5 100644 --- a/api/v1/gpuresourcequota_types.go +++ b/api/v1/gpuresourcequota_types.go @@ -194,6 +194,12 @@ type AllocRequest struct { PodMeta metav1.ObjectMeta QoS QoSLevel + + Isolation IsolationModeType + + // PartitionTemplateID is the template ID used for partitioned mode allocation + // This is set by the scheduler when a partition is matched, or read from pod annotation + PartitionTemplateID string } func (p *AllocRequest) Clone() fwk.StateData { diff --git a/api/v1/schedulingconfigtemplate_types.go b/api/v1/schedulingconfigtemplate_types.go index b057ef5d..a4cf8775 100644 --- a/api/v1/schedulingconfigtemplate_types.go +++ b/api/v1/schedulingconfigtemplate_types.go @@ -126,22 +126,22 @@ type AutoSetResources struct { TargetResource string `json:"targetResource,omitempty"` // Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9 - TargetTflopsPercentile string `json:"targettflopspercentile,omitempty"` + TargetTflopsPercentile string `json:"targetTFlopsPercentile,omitempty"` // Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5 - LowerBoundTflopsPercentile string `json:"lowerboundtflopspercentile,omitempty"` + LowerBoundTflopsPercentile string `json:"lowerBoundTflopsPercentile,omitempty"` // Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95 - UpperBoundTflopsPercentile string `json:"upperboundtflopspercentile,omitempty"` + UpperBoundTflopsPercentile string `json:"upperBoundTflopsPercentile,omitempty"` // Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9 - TargetVramPercentile string `json:"targetvrampercentile,omitempty"` + TargetVramPercentile string `json:"targetVramPercentile,omitempty"` // Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5 - LowerBoundVramPercentile string `json:"lowerboundvrampercentile,omitempty"` + LowerBoundVramPercentile string `json:"lowerBoundVramPercentile,omitempty"` // Vram usage percentile that will be used for the upper bound on vram recommendation. Default: 0.95 - UpperBoundVramPercentile string `json:"upperboundvrampercentile,omitempty"` + UpperBoundVramPercentile string `json:"upperBoundVramPercentile,omitempty"` // Fraction of usage added as the safety margin to the recommended request. 
Default: 0.15 RequestMarginFraction string `json:"requestMarginFraction,omitempty"` diff --git a/api/v1/workloadprofile_types.go b/api/v1/workloadprofile_types.go index 5bd70f0c..57b7dec7 100644 --- a/api/v1/workloadprofile_types.go +++ b/api/v1/workloadprofile_types.go @@ -63,6 +63,11 @@ type WorkloadProfileSpec struct { // How to isolate resources, could be `shared` or `soft` or `hard` or `partitioned` Isolation IsolationModeType `json:"isolation,omitempty"` + // +optional + // PartitionTemplateID specifies the partition template ID for partitioned isolation mode + // This is read from pod annotation tensor-fusion.ai/partition if specified + PartitionTemplateID string `json:"partitionTemplateId,omitempty"` + // +optional // GPUModel specifies the required GPU model (e.g., "A100", "H100") GPUModel string `json:"gpuModel,omitempty"` diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index 110155a2..44089a1e 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -77,6 +77,32 @@ func (in *AllocRequest) DeepCopy() *AllocRequest { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *AllocatedPartition) DeepCopyInto(out *AllocatedPartition) { + *out = *in + in.AllocatedAt.DeepCopyInto(&out.AllocatedAt) + if in.AllocatedSlotStart != nil { + in, out := &in.AllocatedSlotStart, &out.AllocatedSlotStart + *out = new(uint32) + **out = **in + } + if in.AllocatedSlotEnd != nil { + in, out := &in.AllocatedSlotEnd, &out.AllocatedSlotEnd + *out = new(uint32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AllocatedPartition. +func (in *AllocatedPartition) DeepCopy() *AllocatedPartition { + if in == nil { + return nil + } + out := new(AllocatedPartition) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *AutoFreeze) DeepCopyInto(out *AutoFreeze) { *out = *in @@ -1350,6 +1376,18 @@ func (in *GPUStatus) DeepCopyInto(out *GPUStatus) { } } } + if in.PartitionTemplates != nil { + in, out := &in.PartitionTemplates, &out.PartitionTemplates + *out = make([]PartitionTemplate, len(*in)) + copy(*out, *in) + } + if in.AllocatedPartitions != nil { + in, out := &in.AllocatedPartitions, &out.AllocatedPartitions + *out = make(map[string]AllocatedPartition, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUStatus. @@ -1602,6 +1640,22 @@ func (in *NodeManagerConfig) DeepCopyInto(out *NodeManagerConfig) { *out = new(corev1.NodeSelector) (*in).DeepCopyInto(*out) } + if in.MultiVendorNodeSelector != nil { + in, out := &in.MultiVendorNodeSelector, &out.MultiVendorNodeSelector + *out = make(map[string]*corev1.NodeSelector, len(*in)) + for key, val := range *in { + var outVal *corev1.NodeSelector + if val == nil { + (*out)[key] = nil + } else { + inVal := (*in)[key] + in, out := &inVal, &outVal + *out = new(corev1.NodeSelector) + (*in).DeepCopyInto(*out) + } + (*out)[key] = outVal + } + } if in.NodeCompaction != nil { in, out := &in.NodeCompaction, &out.NodeCompaction *out = new(NodeCompaction) @@ -1725,6 +1779,21 @@ func (in *Oversubscription) DeepCopy() *Oversubscription { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. 
in must be non-nil. +func (in *PartitionTemplate) DeepCopyInto(out *PartitionTemplate) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PartitionTemplate. +func (in *PartitionTemplate) DeepCopy() *PartitionTemplate { + if in == nil { + return nil + } + out := new(PartitionTemplate) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *PeriodicalBudget) DeepCopyInto(out *PeriodicalBudget) { *out = *in diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml index a8c2b5a0..afe2df8b 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_gpupools.yaml @@ -249,6 +249,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The + terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. 
If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -608,6 +710,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object status: description: GPUPoolStatus defines the observed state of GPUPool. diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml index 50c76bce..c96258f6 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_gpus.yaml @@ -69,6 +69,54 @@ spec: GPUStatus defines the observed state of GPU. NOTE: When new fields added, remember to update syncGPUMetadataAndStatusFromCluster properties: + allocatedPartitions: + additionalProperties: + description: |- + AllocatedPartition represents an allocated partition on a GPU + Key in AllocatedPartitions map is podUID + properties: + allocatedAt: + description: AllocatedAt is when this partition was allocated + format: date-time + type: string + allocatedSlotEnd: + description: |- + AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated + The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd) + format: int32 + type: integer + allocatedSlotStart: + description: |- + AllocatedSlotStart is the starting slot position where this partition is allocated + This is the actual hardware slot position (0-based index) + format: int32 + type: integer + namespace: + description: Namespace is the namespace of the pod using this + partition + type: string + podName: + description: PodName is the name of the pod using this partition + type: string + podUid: + description: PodUID is the UID of the pod using this partition + (used as map key) + type: string + templateId: + description: TemplateID is the template used to create this + partition + type: string + required: + - allocatedAt + - namespace + - podName + - podUid + - templateId + type: object + description: |- + AllocatedPartitions tracks allocated partitions on this GPU + Key is partitionUUID, value contains template info and allocated resources + type: object available: properties: compute: @@ -124,9 +172,15 @@ spec: index: format: int32 type: integer - message: + isolationMode: + default: soft + enum: + - shared + - soft + - hard + - partitioned type: string - model: + message: type: string nodeSelector: additionalProperties: @@ -138,6 +192,28 @@ spec: NUMA node format: int32 type: integer + partitionTemplates: + description: |- + PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + Reported from discovery, each template has fixed resource allocation + items: + description: |- + PartitionTemplate represents a hardware partition template (e.g., MIG profile) + Only stores template ID and name in GPU status. 
Detailed resource information + is stored in public GPU info config. + properties: + name: + description: Name is a human-readable name for this template + type: string + templateId: + description: TemplateID is the unique identifier for this partition + template (e.g., "1g.24gb", "4g.94gb") + type: string + required: + - name + - templateId + type: object + type: array phase: default: Pending enum: @@ -241,9 +317,6 @@ spec: Hypervisor will watch kubelet device plugin to report all GPUs already used by nvidia-device-plugin GPUs will be grouped by usedBy to be used by different Pods, tensor-fusion annotation or nvidia-device-plugin resource block - enum: - - tensor-fusion - - nvidia-device-plugin type: string uuid: type: string diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_schedulingconfigtemplates.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_schedulingconfigtemplates.yaml index c9e97ebf..245f455e 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_schedulingconfigtemplates.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_schedulingconfigtemplates.yaml @@ -92,11 +92,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -108,19 +108,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. Default: 0.95' type: string diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml index d80f589b..c43bb82b 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionclusters.yaml @@ -315,6 +315,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector + terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. 
The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -675,6 +777,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object required: - specTemplate diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml index 6fe04c9a..450b825f 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_tensorfusionworkloads.yaml @@ -113,11 +113,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. 
Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -129,19 +129,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. Default: 0.95' type: string @@ -466,6 +466,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml b/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml index f7fd3820..ada997ea 100644 --- a/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml +++ b/charts/tensor-fusion/crds/tensor-fusion.ai_workloadprofiles.yaml @@ -100,11 +100,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -116,19 +116,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. 
Default: 0.95' type: string @@ -453,6 +453,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/charts/tensor-fusion/templates/controller-deployment.yaml b/charts/tensor-fusion/templates/controller-deployment.yaml index c16c4aab..ef409a1d 100644 --- a/charts/tensor-fusion/templates/controller-deployment.yaml +++ b/charts/tensor-fusion/templates/controller-deployment.yaml @@ -57,7 +57,7 @@ spec: fieldPath: metadata.namespace # when deploy with AutoSelect mode, GPU node is managed by Kubernetes rather than TensorFusion, thus, need to specify the label selector to generate the GPUNode custom resource - name: INITIAL_GPU_NODE_LABEL_SELECTOR - value: "{{ default "nvidia.com/gpu.present=true" .Values.initialGpuNodeLabelSelector }}" + value: "{{ .Values.initialGpuNodeLabelSelector }}" - name: TSDB_MYSQL_HOST value: "{{ .Values.greptime.host }}" - name: TSDB_MYSQL_PORT diff --git a/charts/tensor-fusion/templates/node-overlay.yaml b/charts/tensor-fusion/templates/node-overlay.yaml new file mode 100644 index 00000000..ce1b7b8a --- /dev/null +++ b/charts/tensor-fusion/templates/node-overlay.yaml @@ -0,0 +1,25 @@ +{{- if lookup "apiextensions.k8s.io/v1" "CustomResourceDefinition" "karpenter.sh" "NodeOverlay" -}} +apiVersion: karpenter.sh/v1alpha1 +kind: NodeOverlay +metadata: + name: tensor-fusion-overlay +spec: + requirements: [] + capacity: + tensor-fusion.ai/index_0: 36 + tensor-fusion.ai/index_1: 36 + tensor-fusion.ai/index_2: 36 + tensor-fusion.ai/index_3: 36 + tensor-fusion.ai/index_4: 36 + tensor-fusion.ai/index_5: 36 + tensor-fusion.ai/index_6: 36 + tensor-fusion.ai/index_7: 36 + tensor-fusion.ai/index_8: 36 + tensor-fusion.ai/index_9: 36 + tensor-fusion.ai/index_a: 36 + tensor-fusion.ai/index_b: 36 + tensor-fusion.ai/index_c: 36 + tensor-fusion.ai/index_d: 36 + tensor-fusion.ai/index_e: 36 + tensor-fusion.ai/index_f: 36 +{{- end }} \ No newline at end of file diff --git a/charts/tensor-fusion/templates/rbac.yaml b/charts/tensor-fusion/templates/rbac.yaml index 89a4b1e8..ab043c64 100644 --- a/charts/tensor-fusion/templates/rbac.yaml +++ b/charts/tensor-fusion/templates/rbac.yaml @@ -177,22 +177,9 @@ rules: - apiGroups: - karpenter.sh resources: - - nodeclaims - verbs: - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - karpenter.* - resources: - '*' verbs: - - get - - list - - watch + - '*' - apiGroups: - authentication.k8s.io resources: diff --git a/charts/tensor-fusion/values-multi-vendor.yaml b/charts/tensor-fusion/values-multi-vendor.yaml new file mode 100644 index 00000000..66233244 --- /dev/null +++ b/charts/tensor-fusion/values-multi-vendor.yaml @@ -0,0 +1 @@ +initialGpuNodeLabelSelector: "" diff --git a/cmd/hypervisor-tui/main.go b/cmd/hypervisor-tui/main.go new file mode 100644 index 00000000..e0e1294a --- /dev/null +++ b/cmd/hypervisor-tui/main.go @@ -0,0 +1,54 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "flag" + "os" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/tui" + tea "github.com/charmbracelet/bubbletea" + "k8s.io/klog/v2" +) + +var ( + host = flag.String("host", "localhost", "Hypervisor server host") + port = flag.Int("port", 8001, "Hypervisor server port") +) + +func main() { + flag.Parse() + klog.InitFlags(nil) + defer klog.Flush() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create HTTP client + client := tui.NewClient(*host, *port) + + // Create TUI model + model := tui.NewModel(ctx, client) + + // Start TUI + p := tea.NewProgram(model, tea.WithAltScreen()) + if _, err := p.Run(); err != nil { + klog.Fatalf("Error running TUI: %v", err) + os.Exit(1) + } +} diff --git a/cmd/hypervisor/main.go b/cmd/hypervisor/main.go new file mode 100644 index 00000000..4c515631 --- /dev/null +++ b/cmd/hypervisor/main.go @@ -0,0 +1,164 @@ +package main + +import ( + "context" + "flag" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/cmd/hypervisor/shm_init" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/kubernetes" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/single_node" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/metrics" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/server" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "github.com/NexusGPU/tensor-fusion/internal/version" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" +) + +var ( + acceleratorVendor = flag.String("vendor", "NVIDIA", "Accelerator vendor: NVIDIA, AMD, Intel, etc.") + acceleratorLibPath = flag.String("accelerator-lib", + "./provider/build/libaccelerator_stub.so", "Path to accelerator library") + isolationMode = flag.String("isolation-mode", "shared", + "Isolation mode: shared, soft, hard, partitioned") + backendType = flag.String("backend-type", "kubernetes", "Backend type: kubernetes, simple") + discoveryInterval = flag.Duration("discovery-interval", + 12*time.Hour, "Device discovery interval") + metricsPath = flag.String("metrics-output-path", "metrics.log", "Path to metrics output file") + + httpPort = flag.Int("port", int(constants.HypervisorDefaultPortNumber), "HTTP port for hypervisor API") +) + +func main() { + // Check for subcommands (used inside init container for initializing shared memory of limiter of soft isolation) + if len(os.Args) > 1 && os.Args[1] == constants.MountShmSubcommand { + shm_init.RunMountShm() + return + } + + flag.Parse() + klog.InitFlags(nil) + defer klog.Flush() + + ctx, cancel := context.WithCancel(context.Background()) + klog.Info("tensor fusion hypervisor starting. 
", version.VersionInfo()) + + utils.NormalizeKubeConfigEnv() + + // Determine accelerator library path from env var or flag + libPath := *acceleratorLibPath + if envLibPath := os.Getenv(constants.TFAcceleratorLibPathEnv); envLibPath != "" { + libPath = envLibPath + klog.Infof("Using accelerator library path from env: %s", libPath) + } + if vendor := os.Getenv(constants.TFHardwareVendorEnv); vendor != "" { + acceleratorVendor = ptr.To(vendor) + klog.Infof("Hardware vendor from env: %s", vendor) + } + + // Create and start device controller + deviceController, err := device.NewController(ctx, libPath, *acceleratorVendor, *discoveryInterval, *isolationMode) + if err != nil { + klog.Fatalf("Failed to create device controller: %v", err) + } + if err := deviceController.Start(); err != nil { + klog.Fatalf("Failed to start device manager: %v", err) + } + klog.Info("Device manager started") + + mode := tfv1.IsolationModeType(*isolationMode) + + // initialize data backend and worker controller + var backend framework.Backend + var workerController framework.WorkerController + + switch *backendType { + case "kubernetes": + // Get Kubernetes rest config + var restConfig *rest.Config + kubeconfig := os.Getenv("KUBECONFIG") + if kubeconfig != "" { + restConfig, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + } else { + restConfig, err = rest.InClusterConfig() + } + if err != nil { + klog.Fatalf("Failed to get Kubernetes config: %v", err) + } + + backend, err = kubernetes.NewKubeletBackend(ctx, deviceController, workerController, restConfig) + if err != nil { + klog.Fatalf("Failed to create Kubernetes backend: %v", err) + } + workerController = worker.NewWorkerController(deviceController, mode, backend) + case "simple": + backend = single_node.NewSingleNodeBackend(ctx, deviceController) + workerController = worker.NewWorkerController(deviceController, mode, backend) + default: + klog.Fatalf("Invalid backend type: %s", *backendType) + } + deviceController.RegisterDeviceUpdateHandler(backend.GetDeviceChangeHandler()) + klog.Info("Device change handler registered from backend", "backend", *backendType) + + err = workerController.Start() + if err != nil { + klog.Fatalf("Failed to start worker controller: %v", err) + } + defer func() { + _ = workerController.Stop() + }() + + klog.Info("Worker controller started") + + // initialize metrics recorder + metricsRecorder := metrics.NewHypervisorMetricsRecorder(ctx, *metricsPath, deviceController, workerController) + metricsRecorder.Start() + klog.Info("Metrics recorder started") + + // initialize and start HTTP server + httpPortNum := *httpPort + if httpPortEnv := os.Getenv(constants.HypervisorPortEnv); httpPortEnv != "" { + httpPortNum, err = strconv.Atoi(httpPortEnv) + if err != nil { + klog.Fatalf("Failed to convert HTTP port from env: %v", err) + } + } + httpServer := server.NewServer(ctx, deviceController, workerController, metricsRecorder, backend, httpPortNum) + go func() { + if err := httpServer.Start(); err != nil && err != http.ErrServerClosed { + klog.Fatalf("Failed to start HTTP server: %v", err) + } + }() + klog.Info("HTTP server started") + + // Wait for interrupt signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + klog.Info("Hypervisor running") + <-sigCh + klog.Info("Stopping hypervisor...") + + // Shutdown HTTP server + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := httpServer.Stop(shutdownCtx); err != nil { + 
klog.Errorf("Error shutting down HTTP server: %v", err) + } + + cancel() + klog.Info("Hypervisor stopped") +} diff --git a/cmd/hypervisor/shm_init/mount_shm.go b/cmd/hypervisor/shm_init/mount_shm.go new file mode 100644 index 00000000..cd6eea08 --- /dev/null +++ b/cmd/hypervisor/shm_init/mount_shm.go @@ -0,0 +1,91 @@ +package shm_init + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "k8s.io/klog/v2" +) + +// runMountShm handles the "mount-shm" subcommand +func RunMountShm() { + // Create a new flag set for mount-shm subcommand + mountShmFlags := flag.NewFlagSet("mount-shm", flag.ExitOnError) + mountPoint := mountShmFlags.String("mount-point", "", "Mount point directory path (required)") + sizeMB := mountShmFlags.Int("size", 0, "Size in MB (required)") + + klog.InitFlags(nil) + if err := mountShmFlags.Parse(os.Args[2:]); err != nil { + klog.Fatalf("Failed to parse flags: %v", err) + } + + if *mountPoint == "" { + klog.Fatalf("mount-point is required") + } + if *sizeMB <= 0 { + klog.Fatalf("size must be greater than 0") + } + + klog.Infof("mount point: %s", *mountPoint) + klog.Infof("size: %d MB", *sizeMB) + + // Create mount point directory if it doesn't exist + if _, err := os.Stat(*mountPoint); os.IsNotExist(err) { + klog.Infof("create mount point directory: %s", *mountPoint) + if err := os.MkdirAll(*mountPoint, 0755); err != nil { + klog.Fatalf("create mount point directory failed: %v", err) + } + } + + // Check if tmpfs is already mounted + mountCmd := exec.Command("mount") + mountOutput, err := mountCmd.Output() + if err != nil { + klog.Fatalf("execute mount command failed: %v", err) + } + + mountInfo := string(mountOutput) + mountPointAbs, err := filepath.Abs(*mountPoint) + if err != nil { + klog.Fatalf("get absolute path failed: %v", err) + } + + expectedMountStr := fmt.Sprintf("on %s type tmpfs", mountPointAbs) + if strings.Contains(mountInfo, expectedMountStr) { + klog.Infof("tmpfs is already mounted on %s", *mountPoint) + } else { + // Mount tmpfs + klog.Infof("mount tmpfs on %s", *mountPoint) + sizeArg := fmt.Sprintf("size=%dM", *sizeMB) + + mountTmpfsCmd := exec.Command("mount", + "-t", "tmpfs", + "-o", fmt.Sprintf("rw,nosuid,nodev,%s", sizeArg), + "tmpfs", + mountPointAbs, + ) + + if err := mountTmpfsCmd.Run(); err != nil { + klog.Fatalf("mount tmpfs failed: %v", err) + } + + klog.Info("mount tmpfs successfully") + } + + // Set directory permissions to 0777 + // Save old umask + oldUmask := syscall.Umask(0) + defer syscall.Umask(oldUmask) + + // Set permissions + if err := os.Chmod(*mountPoint, 0777); err != nil { + klog.Fatalf("set permissions failed: %v", err) + } + + klog.Info("mount-shm completed successfully") +} diff --git a/cmd/main.go b/cmd/main.go index c55a219c..fbd70b0c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -20,9 +20,7 @@ import ( "context" "crypto/tls" "flag" - "fmt" "os" - "strings" "time" // Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.) 
@@ -47,6 +45,8 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/utils" "github.com/NexusGPU/tensor-fusion/internal/version" webhookcorev1 "github.com/NexusGPU/tensor-fusion/internal/webhook/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -56,11 +56,13 @@ import ( clientgoscheme "k8s.io/client-go/kubernetes/scheme" _ "k8s.io/client-go/plugin/pkg/client/auth" "k8s.io/client-go/rest" + "k8s.io/client-go/util/retry" "k8s.io/klog/v2" "k8s.io/kubernetes/cmd/kube-scheduler/app" "k8s.io/kubernetes/pkg/scheduler" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/healthz" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" @@ -189,7 +191,7 @@ func main() { metricsServerOptions.FilterProvider = filters.WithAuthenticationAndAuthorization } - normalizeKubeConfigEnv() + utils.NormalizeKubeConfigEnv() kc := ctrl.GetConfigOrDie() mgr, err := ctrl.NewManager(kc, ctrl.Options{ Scheme: scheme, @@ -233,9 +235,6 @@ func main() { metricsRecorder := startMetricsRecorder(enableLeaderElection, mgr, gpuPricingMap) - // Initialize GPU allocator and set up watches - allocator, portAllocator := startTensorFusionAllocators(ctx, mgr) - // Initialize Index allocator for Device Plugin communication indexAllocator, err := indexallocator.NewIndexAllocator(ctx, mgr.GetClient()) if err != nil { @@ -244,15 +243,20 @@ func main() { } _ = indexAllocator.SetupWithManager(ctx, mgr) + // Initialize GPU allocator and set up watches + allocator, portAllocator := startTensorFusionAllocators(ctx, mgr, indexAllocator) + + ensureLeaderInfoConfigMap(mgr) + startAutoScaler(mgr, allocator) // Create pricing provider for webhook pricingProvider := pricing.NewStaticPricingProvider() startWebhook(mgr, portAllocator, indexAllocator, pricingProvider) - scheduler, nodeExpander := startScheduler(ctx, allocator, mgr, k8sVersion) + scheduler, nodeExpander := startScheduler(ctx, allocator, indexAllocator, mgr, k8sVersion) - startCustomResourceController(ctx, mgr, metricsRecorder, allocator, portAllocator, nodeExpander) + startCustomResourceController(ctx, mgr, metricsRecorder, allocator, portAllocator, indexAllocator, nodeExpander) startHttpServerForTFClient(ctx, kc, portAllocator, indexAllocator, allocator, scheduler, nodeExpander, mgr.Elected()) @@ -282,8 +286,9 @@ func addHealthCheckAPI(mgr manager.Manager) { func startTensorFusionAllocators( ctx context.Context, mgr manager.Manager, + indexAllocator *indexallocator.IndexAllocator, ) (*gpuallocator.GpuAllocator, *portallocator.PortAllocator) { - allocator := gpuallocator.NewGpuAllocator(ctx, mgr.GetClient(), 10*time.Second) + allocator := gpuallocator.NewGpuAllocator(ctx, indexAllocator, mgr.GetClient(), 10*time.Second) if err := allocator.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "unable to set up GPU allocator watches") os.Exit(1) @@ -358,6 +363,7 @@ func startCustomResourceController( metricsRecorder metrics.MetricsRecorder, allocator *gpuallocator.GpuAllocator, portAllocator *portallocator.PortAllocator, + indexAllocator *indexallocator.IndexAllocator, nodeExpander *expander.NodeExpander, ) { if os.Getenv(constants.EnableCustomResourceControllerEnv) == constants.FalseStringValue { @@ -437,11 +443,12 @@ func 
startCustomResourceController( os.Exit(1) } if err = (&controller.PodReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - Allocator: allocator, - PortAllocator: portAllocator, - Expander: nodeExpander, + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Allocator: allocator, + PortAllocator: portAllocator, + Expander: nodeExpander, + IndexAllocator: indexAllocator, }).SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "Pod") os.Exit(1) @@ -511,6 +518,7 @@ func startWebhook( func startScheduler( ctx context.Context, allocator *gpuallocator.GpuAllocator, + indexAllocator *indexallocator.IndexAllocator, mgr manager.Manager, k8sVersion *k8sVer.Version, ) (*scheduler.Scheduler, *expander.NodeExpander) { @@ -524,7 +532,7 @@ func startScheduler( gpuResourceFitOpt := app.WithPlugin( gpuResourceFitPlugin.Name, - gpuResourceFitPlugin.NewWithDeps(allocator, mgr.GetClient()), + gpuResourceFitPlugin.NewWithDeps(allocator, indexAllocator, mgr.GetClient()), ) gpuTopoOpt := app.WithPlugin( gpuTopoPlugin.Name, @@ -688,19 +696,6 @@ func startWatchGPUInfoChanges(ctx context.Context, gpuInfos *[]config.GpuInfo, g }() } -// only for local development, won't set KUBECONFIG env var in none local environments -func normalizeKubeConfigEnv() { - cfgPath := os.Getenv("KUBECONFIG") - if cfgPath != "" && strings.HasPrefix(cfgPath, "~") { - home, err := os.UserHomeDir() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - _ = os.Setenv("KUBECONFIG", strings.Replace(cfgPath, "~", home, 1)) - } -} - // Setup GreptimeDB connection func setupTimeSeriesDB() *metrics.TimeSeriesDB { timeSeriesDB := &metrics.TimeSeriesDB{} @@ -735,3 +730,32 @@ func addStopHandlers(mgr manager.Manager, allocator *gpuallocator.GpuAllocator) os.Exit(1) } } + +func ensureLeaderInfoConfigMap(mgr manager.Manager) { + err := mgr.Add(manager.RunnableFunc(func(ctx context.Context) error { + <-mgr.Elected() + leaderInfo := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: constants.LeaderInfoConfigMapName, + Namespace: utils.CurrentNamespace(), + }, + } + err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { + _, err := controllerutil.CreateOrUpdate(ctx, mgr.GetClient(), leaderInfo, func() error { + leaderInfo.Data = map[string]string{ + constants.LeaderInfoConfigMapLeaderIPKey: utils.CurrentIP(), + } + return nil + }) + return err + }) + if err != nil { + setupLog.Error(err, "Failed to update leader IP info in ConfigMap") + } + return nil + })) + if err != nil { + setupLog.Error(err, "unable to add leader info config map to manager") + os.Exit(1) + } +} diff --git a/config/crd/bases/tensor-fusion.ai_gpupools.yaml b/config/crd/bases/tensor-fusion.ai_gpupools.yaml index a8c2b5a0..afe2df8b 100644 --- a/config/crd/bases/tensor-fusion.ai_gpupools.yaml +++ b/config/crd/bases/tensor-fusion.ai_gpupools.yaml @@ -249,6 +249,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. 
The + terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector + applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -608,6 +710,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object status: description: GPUPoolStatus defines the observed state of GPUPool. diff --git a/config/crd/bases/tensor-fusion.ai_gpus.yaml b/config/crd/bases/tensor-fusion.ai_gpus.yaml index 50c76bce..c96258f6 100644 --- a/config/crd/bases/tensor-fusion.ai_gpus.yaml +++ b/config/crd/bases/tensor-fusion.ai_gpus.yaml @@ -69,6 +69,54 @@ spec: GPUStatus defines the observed state of GPU. 
NOTE: When new fields added, remember to update syncGPUMetadataAndStatusFromCluster properties: + allocatedPartitions: + additionalProperties: + description: |- + AllocatedPartition represents an allocated partition on a GPU + Key in AllocatedPartitions map is podUID + properties: + allocatedAt: + description: AllocatedAt is when this partition was allocated + format: date-time + type: string + allocatedSlotEnd: + description: |- + AllocatedSlotEnd is the ending slot position (exclusive) where this partition is allocated + The partition occupies slots [AllocatedSlotStart, AllocatedSlotEnd) + format: int32 + type: integer + allocatedSlotStart: + description: |- + AllocatedSlotStart is the starting slot position where this partition is allocated + This is the actual hardware slot position (0-based index) + format: int32 + type: integer + namespace: + description: Namespace is the namespace of the pod using this + partition + type: string + podName: + description: PodName is the name of the pod using this partition + type: string + podUid: + description: PodUID is the UID of the pod using this partition + (used as map key) + type: string + templateId: + description: TemplateID is the template used to create this + partition + type: string + required: + - allocatedAt + - namespace + - podName + - podUid + - templateId + type: object + description: |- + AllocatedPartitions tracks allocated partitions on this GPU + Key is partitionUUID, value contains template info and allocated resources + type: object available: properties: compute: @@ -124,9 +172,15 @@ spec: index: format: int32 type: integer - message: + isolationMode: + default: soft + enum: + - shared + - soft + - hard + - partitioned type: string - model: + message: type: string nodeSelector: additionalProperties: @@ -138,6 +192,28 @@ spec: NUMA node format: int32 type: integer + partitionTemplates: + description: |- + PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + Reported from discovery, each template has fixed resource allocation + items: + description: |- + PartitionTemplate represents a hardware partition template (e.g., MIG profile) + Only stores template ID and name in GPU status. Detailed resource information + is stored in public GPU info config. + properties: + name: + description: Name is a human-readable name for this template + type: string + templateId: + description: TemplateID is the unique identifier for this partition + template (e.g., "1g.24gb", "4g.94gb") + type: string + required: + - name + - templateId + type: object + type: array phase: default: Pending enum: @@ -241,9 +317,6 @@ spec: Hypervisor will watch kubelet device plugin to report all GPUs already used by nvidia-device-plugin GPUs will be grouped by usedBy to be used by different Pods, tensor-fusion annotation or nvidia-device-plugin resource block - enum: - - tensor-fusion - - nvidia-device-plugin type: string uuid: type: string diff --git a/config/crd/bases/tensor-fusion.ai_schedulingconfigtemplates.yaml b/config/crd/bases/tensor-fusion.ai_schedulingconfigtemplates.yaml index c9e97ebf..245f455e 100644 --- a/config/crd/bases/tensor-fusion.ai_schedulingconfigtemplates.yaml +++ b/config/crd/bases/tensor-fusion.ai_schedulingconfigtemplates.yaml @@ -92,11 +92,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. 
Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -108,19 +108,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. Default: 0.95' type: string diff --git a/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml b/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml index d80f589b..c43bb82b 100644 --- a/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml +++ b/config/crd/bases/tensor-fusion.ai_tensorfusionclusters.yaml @@ -315,6 +315,108 @@ spec: type: boolean nodeManagerConfig: properties: + defaultVendor: + default: NVIDIA + description: |- + In single AI accelerator hardware vendor mode, when default vendor set + All nodes provisioned by NodeProvisioner or selected by NodeSelector will be set with vendor label + type: string + multiVendorNodeSelector: + additionalProperties: + description: |- + A node selector represents the union of the results of one or more label queries + over a set of nodes; that is, it represents the OR of the selectors represented + by the node selector terms. + properties: + nodeSelectorTerms: + description: Required. A list of node selector + terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements + by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. 
+ items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements + by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the + selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + description: |- + When this field set, the GPU pool will be in multi AI accelerator vendor mode + each GPU node's vendor name is set to map key, e.g. { AMD: { nodeSelectorTerms }} + type: object nodeCompaction: properties: period: @@ -675,6 +777,9 @@ spec: type: object schedulingConfigTemplate: type: string + vendor: + default: NVIDIA + type: string type: object required: - specTemplate diff --git a/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml b/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml index 6fe04c9a..450b825f 100644 --- a/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml +++ b/config/crd/bases/tensor-fusion.ai_tensorfusionworkloads.yaml @@ -113,11 +113,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -129,19 +129,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. 
Default: 0.95' type: string @@ -466,6 +466,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml b/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml index f7fd3820..ada997ea 100644 --- a/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml +++ b/config/crd/bases/tensor-fusion.ai_workloadprofiles.yaml @@ -100,11 +100,11 @@ spec: description: 'Resolution at which TSDB is queried for historical metrics. Default: 1m' type: string - lowerboundtflopspercentile: + lowerBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the lower bound on tflops recommendation. Default: 0.5' type: string - lowerboundvrampercentile: + lowerBoundVramPercentile: description: 'Vram usage percentile that will be used for the lower bound on vram recommendation. Default: 0.5' type: string @@ -116,19 +116,19 @@ spec: description: Target resource to scale, such as "tflops", "vram", or "all" by default type: string - targettflopspercentile: + targetTFlopsPercentile: description: 'Tflops usage percentile that will be used as a base for tflops target recommendation. Default: 0.9' type: string - targetvrampercentile: + targetVramPercentile: description: 'Vram usage percentile that will be used as a base for vram target recommendation. Default: 0.9' type: string - upperboundtflopspercentile: + upperBoundTflopsPercentile: description: 'Tflops usage percentile that will be used for the upper bound on tflops recommendation. Default: 0.95' type: string - upperboundvrampercentile: + upperBoundVramPercentile: description: 'Vram usage percentile that will be used for the upper bound on vram recommendation. Default: 0.95' type: string @@ -453,6 +453,11 @@ spec: type: object x-kubernetes-map-type: atomic type: object + partitionTemplateId: + description: |- + PartitionTemplateID specifies the partition template ID for partitioned isolation mode + This is read from pod annotation tensor-fusion.ai/partition if specified + type: string poolName: type: string qos: diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index a9cd2546..8aff3d82 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -125,6 +125,12 @@ rules: - patch - update - watch +- apiGroups: + - karpenter.sh + resources: + - '*' + verbs: + - '*' - apiGroups: - tensor-fusion.ai resources: diff --git a/dockerfile/node-discovery.Dockerfile b/dockerfile/hypervisor.Dockerfile similarity index 83% rename from dockerfile/node-discovery.Dockerfile rename to dockerfile/hypervisor.Dockerfile index 09ac6741..e2eae468 100644 --- a/dockerfile/node-discovery.Dockerfile +++ b/dockerfile/hypervisor.Dockerfile @@ -15,6 +15,7 @@ RUN go mod download COPY cmd/ cmd/ COPY api/ api/ COPY internal/ internal/ +COPY provider/ provider/ # Build @@ -22,13 +23,13 @@ COPY internal/ internal/ # was called. For example, if we call make docker-build in a local env which has the Apple Silicon M1 SO # the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore, # by leaving it empty we can ensure that the container and binary shipped on it will have the same platform. 
-RUN CGO_ENABLED=1 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o nodediscovery cmd/nodediscovery/main.go +RUN CGO_ENABLED=1 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o hypervisor cmd/hypervisor/main.go -# Use distroless as minimal base image to package the nodediscovery binary +# Use distroless as minimal base image to package the hypervisor binary # Refer to https://github.com/GoogleContainerTools/distroless for more details FROM ubuntu:24.04 WORKDIR / -COPY --from=builder /workspace/nodediscovery . +COPY --from=builder /workspace/hypervisor . USER 65532:65532 -ENTRYPOINT ["/nodediscovery"] +ENTRYPOINT ["/hypervisor"] diff --git a/go.mod b/go.mod index 9df8bb74..1d474d9b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,10 @@ require ( github.com/aws/aws-sdk-go-v2/service/ec2 v1.275.0 github.com/aws/smithy-go v1.23.2 github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a + github.com/charmbracelet/bubbles v0.21.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-contrib/gzip v1.2.5 github.com/gin-gonic/gin v1.11.0 github.com/go-sql-driver/mysql v1.9.3 @@ -27,6 +31,7 @@ require ( go.uber.org/zap v1.27.1 golang.org/x/time v0.14.0 gomodules.xyz/jsonpatch/v2 v2.5.0 + google.golang.org/grpc v1.77.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.6.0 gorm.io/gorm v1.31.1 @@ -39,6 +44,7 @@ require ( k8s.io/component-helpers v0.34.2 k8s.io/klog/v2 v2.130.1 k8s.io/kube-scheduler v0.34.2 + k8s.io/kubelet v0.34.2 k8s.io/kubernetes v1.34.2 k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 sigs.k8s.io/controller-runtime v0.22.4 @@ -53,10 +59,12 @@ require ( github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.14 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.14 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.14 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect @@ -64,15 +72,19 @@ require ( github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/ansi v0.10.1 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect 
github.com/gin-contrib/sse v1.1.0 // indirect @@ -110,6 +122,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -119,23 +132,32 @@ require ( github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/moby/term v0.5.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/posthog/posthog-go v1.6.13 // indirect github.com/prometheus/client_golang v1.23.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.17.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.55.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect github.com/spf13/cobra v1.10.1 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stoewer/go-strcase v1.3.1 // indirect @@ -143,11 +165,12 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect github.com/x448/float16 v0.8.4 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.etcd.io/etcd/api/v3 v3.6.4 // indirect go.etcd.io/etcd/client/pkg/v3 v3.6.4 // indirect go.etcd.io/etcd/client/v3 v3.6.4 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect @@ -164,15 +187,14 @@ require ( golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.47.0 // indirect - golang.org/x/oauth2 v0.31.0 // indirect + golang.org/x/oauth2 v0.32.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/tools v0.38.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 // indirect - 
google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 // indirect - google.golang.org/grpc v1.75.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect @@ -186,7 +208,6 @@ require ( k8s.io/dynamic-resource-allocation v0.34.0 // indirect k8s.io/kms v0.34.2 // indirect k8s.io/kube-openapi v0.0.0-20250905212525-66792eed8611 // indirect - k8s.io/kubelet v0.34.0 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.33.0 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect diff --git a/go.sum b/go.sum index dab34718..3ad4bad9 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/aliyun/alibaba-cloud-sdk-go v1.63.107 h1:qagvUyrgOnBIlVRQWOyCZGVKUIYb github.com/aliyun/alibaba-cloud-sdk-go v1.63.107/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= github.com/aws/aws-sdk-go-v2 v1.40.0 h1:/WMUA0kjhZExjOQN2z3oLALDREea1A7TobfuiBrKlwc= @@ -40,6 +42,10 @@ github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a h1:qstXCawuAwrgFLoaU1IIYGGFeVKVBkJMVSSSKJXBD14= github.com/awslabs/operatorpkg v0.0.0-20251024191238-14554b75b88a/go.mod h1:D4OLvXkR+2pp9RKo8Ovjc1Mqnd0qPRW0gz3cjxGSCkA= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= +github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -54,6 +60,22 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= +github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod 
h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= +github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= @@ -74,6 +96,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U= github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -194,6 +218,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92Bcuy github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -242,12 +268,18 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod 
h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE= github.com/mfridman/tparse v0.18.0/go.mod h1:gEvqZTuCgEhPbYk/2lS3Kcxg1GmTxxU7kTC8DvP0i/A= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= @@ -262,6 +294,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -282,6 +320,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posthog/posthog-go v1.6.13 h1:4t9j0VOIJBgITm4v5rLsLy3IKUkU9dn2VusMNzZXScw= +github.com/posthog/posthog-go v1.6.13/go.mod h1:LcC1Nu4AgvV22EndTtrMXTy+7RGVC0MhChSw7Qk5XkY= 
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= @@ -294,11 +334,16 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= +github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= @@ -348,6 +393,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 h1:S2dVYn90KE98chqDkyE9Z4N61UnQd+KOfgp5Iu53llk= github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= @@ -366,8 +413,8 @@ go.etcd.io/etcd/server/v3 v3.6.4 h1:LsCA7CzjVt+8WGrdsnh6RhC0XqCsLkBly3ve5rTxMAU= go.etcd.io/etcd/server/v3 v3.6.4/go.mod h1:aYCL/h43yiONOv0QIR82kH/2xZ7m+IWYjzRmyQfnCAg= go.etcd.io/raft/v3 v3.6.0 h1:5NtvbDVYpnfZWcIHgGRk9DyzkBIXOi8j+DDp1IcnUWQ= go.etcd.io/raft/v3 v3.6.0/go.mod h1:nLvLevg6+xrVtHUmVaTcTz603gQPHfh7kUAwV6YpfGo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= 
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= @@ -432,8 +479,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= -golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -445,6 +492,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= @@ -478,12 +526,12 @@ gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= -google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 h1:APHvLLYBhtZvsbnpkfknDZ7NyH4z5+ub/I0u8L3Oz6g= -google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1/go.mod h1:xUjFWUnWDpZ/C0Gu0qloASKFb6f8/QXiiXhSPFsD668= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1 h1:pmJpJEvT846VzausCQ5d7KreSROcDqmO388w5YbnltA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250826171959-ef028d996bc1/go.mod h1:GmFNa4BdJZ2a8G+wCe9Bg3wwThLrJun751XstdJt5Og= -google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= -google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= +google.golang.org/genproto/googleapis/rpc 
v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -542,8 +590,8 @@ k8s.io/kube-openapi v0.0.0-20250905212525-66792eed8611 h1:o4oKOsvSymDkZRsMAPZU7b k8s.io/kube-openapi v0.0.0-20250905212525-66792eed8611/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/kube-scheduler v0.34.2 h1:TtLcaXeIpkqgzMr2ch7Ap8Cluq4M182XUDRlnOPDdoc= k8s.io/kube-scheduler v0.34.2/go.mod h1:PTn4QYiSet8/00VQ2qGO/HWdo5iNJlVRCXz/7R3Ut5I= -k8s.io/kubelet v0.34.0 h1:1nZt1Q6Kfx7xCaTS9vnqR9sjZDxf3cRSQkAFCczULmc= -k8s.io/kubelet v0.34.0/go.mod h1:NqbF8ViVettlZbf9hw9DJhubaWn7rGvDDTcLMDm6tQ0= +k8s.io/kubelet v0.34.2 h1:Dl+1uh7xwJr70r+SHKyIpvu6XvzuoPu0uDIC4cqgJUs= +k8s.io/kubelet v0.34.2/go.mod h1:RfwR03iuKeVV7Z1qD9XKH98c3tlPImJpQ3qHIW40htM= k8s.io/kubernetes v1.34.2 h1:WQdDvYJazkmkwSncgNwGvVtaCt4TYXIU3wSMRgvp3MI= k8s.io/kubernetes v1.34.2/go.mod h1:m6pZk6a179pRo2wsTiCPORJ86iOEQmfIzUvtyEF8BwA= k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck= diff --git a/internal/alert/evaluator.go b/internal/alert/evaluator.go index 3f9a6384..5b3177c5 100644 --- a/internal/alert/evaluator.go +++ b/internal/alert/evaluator.go @@ -108,7 +108,7 @@ func renderQueryTemplate(rule *config.AlertRule) (string, error) { } var buf bytes.Buffer - data := map[string]interface{}{ + data := map[string]any{ "Threshold": rule.Threshold, "Conditions": fmt.Sprintf("ts >= now() - '%s'::INTERVAL", rule.EvaluationInterval), "Severity": rule.Severity, @@ -169,8 +169,8 @@ func (e *AlertEvaluator) processQueryResults(rows *sql.Rows, rule *config.AlertR return nil, fmt.Errorf("failed to get columns: %w", err) } - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) for i := range values { valuePtrs[i] = &values[i] } @@ -178,7 +178,7 @@ func (e *AlertEvaluator) processQueryResults(rows *sql.Rows, rule *config.AlertR return nil, fmt.Errorf("failed to scan row: %w", err) } - rowData := make(map[string]interface{}) + rowData := make(map[string]any) for i, col := range columns { rowData[col] = values[i] } diff --git a/internal/autoscaler/autoscaler_suite_test.go b/internal/autoscaler/autoscaler_suite_test.go index 0595acce..098ba11a 100644 --- a/internal/autoscaler/autoscaler_suite_test.go +++ b/internal/autoscaler/autoscaler_suite_test.go @@ -155,7 +155,7 @@ var _ = BeforeSuite(func() { WorkerUnitPriceMap: make(map[string]map[string]metrics.RawBillingPricing), } - allocator = gpuallocator.NewGpuAllocator(ctx, mgr.GetClient(), 150*time.Millisecond) + allocator = gpuallocator.NewGpuAllocator(ctx, nil, mgr.GetClient(), 150*time.Millisecond) err = allocator.SetupWithManager(ctx, mgr) Expect(err).ToNot(HaveOccurred()) @@ -273,7 +273,9 @@ var _ = BeforeSuite(func() { var _ = AfterSuite(func() { By("tearing down the test environment") - allocator.Stop() + if allocator != nil { + allocator.Stop() + } 
cancel() err := testEnv.Stop() Expect(err).NotTo(HaveOccurred()) diff --git a/internal/autoscaler/autoscaler_test.go b/internal/autoscaler/autoscaler_test.go index 2eba22fb..1055f98e 100644 --- a/internal/autoscaler/autoscaler_test.go +++ b/internal/autoscaler/autoscaler_test.go @@ -91,11 +91,11 @@ var _ = Describe("Autoscaler", func() { // create two workloads pool := tfEnv.GetGPUPool(0) - // with two replias + // with two replicas workload0 := createWorkload(pool, 0, 2) workload0Workers := getWorkers(workload0) key0 := WorkloadID{workload0.Namespace, workload0.Name} - // with one replia + // with one replica workload1 := createWorkload(pool, 1, 1) workload1Workers := getWorkers(workload1) key1 := WorkloadID{workload1.Namespace, workload1.Name} @@ -539,8 +539,8 @@ func (f *FakeRecommender) Name() string { return "fake" } -func (f *FakeRecommender) Recommend(ctx context.Context, workoad *workload.State) (*recommender.RecResult, error) { - meta.SetStatusCondition(&workoad.Status.Conditions, metav1.Condition{ +func (f *FakeRecommender) Recommend(ctx context.Context, workload *workload.State) (*recommender.RecResult, error) { + meta.SetStatusCondition(&workload.Status.Conditions, metav1.Condition{ Type: constants.ConditionStatusTypeRecommendationProvided, Status: metav1.ConditionTrue, LastTransitionTime: metav1.Now(), @@ -667,7 +667,7 @@ func mockSchedulerLoop(ctx context.Context, cfg *rest.Config) { func scheduleAndStartPod(pod *corev1.Pod, clientset *kubernetes.Clientset) { // simulate scheduling cycle Filter and Reserve - allocRequest, _, err := allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(ctx, pod) Expect(err).To(Succeed()) gpus, err := allocator.Alloc(allocRequest) if err != nil { diff --git a/internal/cloudprovider/pricing/pricing.go b/internal/cloudprovider/pricing/pricing.go index 45dd09bb..65dfccbd 100644 --- a/internal/cloudprovider/pricing/pricing.go +++ b/internal/cloudprovider/pricing/pricing.go @@ -31,6 +31,7 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/types" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" "k8s.io/apimachinery/pkg/api/resource" "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -104,6 +105,9 @@ func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.G tflopsMap[gpuInfo.Model] = completeInfo } + // Load partition templates from config + gpuallocator.LoadPartitionTemplatesFromConfig(*gpuInfos) + initOnce.Do(func() { globalAWSGPUInstanceData = make(map[string]GPUNodeInstanceInfoAndPrice) globalAzureGPUInstanceData = make(map[string]GPUNodeInstanceInfoAndPrice) diff --git a/internal/component/client.go b/internal/component/client.go index 4b2549a9..7aa196e9 100644 --- a/internal/component/client.go +++ b/internal/component/client.go @@ -13,7 +13,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -const ( +var ( ClientUpdateInProgressAnnotation = constants.Domain + "/client-update-in-progress" ClientBatchUpdateLastTimeAnnotation = constants.Domain + "/client-batch-update-last-time" ) @@ -23,7 +23,7 @@ type Client struct { } func (c *Client) GetName() string { - return "client" + return constants.ComponentClient } func (c *Client) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) { diff --git a/internal/component/component.go b/internal/component/component.go index e3940a15..1d429d88 100644 --- 
a/internal/component/component.go +++ b/internal/component/component.go @@ -151,11 +151,11 @@ func isAutoUpdateEnable(component Interface, pool *tfv1.GPUPool) bool { if pool.Spec.NodeManagerConfig != nil { updatePolicy := pool.Spec.NodeManagerConfig.NodePoolRollingUpdatePolicy switch component.GetName() { - case "hypervisor": + case constants.ComponentHypervisor: return updatePolicy.AutoUpdateHypervisor - case "worker": + case constants.ComponentWorker: return updatePolicy.AutoUpdateWorker - case "client": + case constants.ComponentClient: return updatePolicy.AutoUpdateClient } } @@ -170,7 +170,7 @@ func calculateDesiredUpdatedDelta(total int, updatedSize int, batchPercentage in currentBatchIndex = newUpdateProgress / batchPercentage desiredSize = min((currentBatchIndex+1)*int32(batchSize), int32(total)) delta = desiredSize - int32(updatedSize) - // if rolling udpate policy changed or new nodes were added during update, we need to update progress + // if rolling update policy changed or new nodes were added during update, we need to update progress if delta < 0 { newUpdateProgress = min(newUpdateProgress+batchPercentage, 100) } else { diff --git a/internal/component/hypervisor.go b/internal/component/hypervisor.go index b33d03c8..3eb763b1 100644 --- a/internal/component/hypervisor.go +++ b/internal/component/hypervisor.go @@ -14,7 +14,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -const ( +var ( HypervisorUpdateInProgressAnnotation = constants.Domain + "/hypervisor-update-in-progress" HypervisorBatchUpdateLastTimeAnnotation = constants.Domain + "/hypervisor-batch-update-last-time" ) @@ -24,7 +24,7 @@ type Hypervisor struct { } func (h *Hypervisor) GetName() string { - return "hypervisor" + return constants.ComponentHypervisor } func (h *Hypervisor) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) { @@ -88,7 +88,7 @@ func (h *Hypervisor) GetResourcesInfo(r client.Client, ctx context.Context, pool } key := client.ObjectKey{ Namespace: utils.CurrentNamespace(), - Name: fmt.Sprintf("hypervisor-%s", node.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", node.Name), } pod := &corev1.Pod{} err := r.Get(ctx, key, pod) diff --git a/internal/component/worker.go b/internal/component/worker.go index 4ed80086..02f98759 100644 --- a/internal/component/worker.go +++ b/internal/component/worker.go @@ -13,7 +13,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -const ( +var ( WorkerUpdateInProgressAnnotation = constants.Domain + "/worker-update-in-progress" WorkerBatchUpdateLastTimeAnnotation = constants.Domain + "/worker-batch-update-last-time" ) @@ -23,7 +23,7 @@ type Worker struct { } func (w *Worker) GetName() string { - return "worker" + return constants.ComponentWorker } func (w *Worker) DetectConfigChange(pool *tfv1.GPUPool, status *tfv1.PoolComponentStatus) (bool, string, string) { diff --git a/internal/config/gpu_info.go b/internal/config/gpu_info.go index f05bace1..830548b8 100644 --- a/internal/config/gpu_info.go +++ b/internal/config/gpu_info.go @@ -10,6 +10,49 @@ type GpuInfo struct { CostPerHour float64 `json:"costPerHour"` Fp16TFlops resource.Quantity `json:"fp16TFlops"` FullModelName string `json:"fullModelName"` + + // PartitionTemplates contains available partition templates for this GPU (e.g., MIG profiles) + // Only applicable for GPUs that support hardware partitioning + PartitionTemplates []PartitionTemplateInfo `json:"partitionTemplates,omitempty"` + + // MaxPartitions is the maximum number of partitions this GPU can 
support (e.g., 7 for MIG) + MaxPartitions uint32 `json:"maxPartitions,omitempty"` + + // MaxPlacementSlots is the maximum number of placement slots this GPU can support (e.g., 8 for NVIDIA MIG) + MaxPlacementSlots uint32 `json:"maxPlacementSlots,omitempty"` +} + +// PartitionTemplateInfo contains detailed resource information for a partition template +type PartitionTemplateInfo struct { + // TemplateID is the unique identifier for this partition template Profile `19` for 1g.10gb in A100 + TemplateID string `json:"templateId"` + + // TemplateID is the unique identifier (e.g., "1g.24gb", "4g.94gb") + Name string `json:"name"` + + // MemoryGigabytes is the memory allocated to this partition in gigabytes + MemoryGigabytes uint64 `json:"memoryGigabytes"` + + // ComputePercent is the percent of sliced GPU (0-100) + ComputePercent float64 `json:"computePercent"` + + // Description provides additional information about this template + Description string `json:"description,omitempty"` + + // MaxPartition for this single template, eg. 1g.10gb+me can only be allocate once + MaxPartition uint32 `json:"maxPartition"` + + // The placement limit for this template, use a bitmask to represent the placement limit + // e.g. sudo nvidia-smi mig -i 0 -lgipp + // GPU 0 Profile ID 19 Placements: {0,1,2,3,4,5,6}:1 + // GPU 0 Profile ID 20 Placements: {0,1,2,3,4,5,6}:1 + // GPU 0 Profile ID 15 Placements: {0,2,4,6}:2 + // GPU 0 Profile ID 14 Placements: {0,2,4}:2 + // GPU 0 Profile ID 9 Placements: {0,4}:4 + // GPU 0 Profile ID 5 Placement : {0}:4 + // GPU 0 Profile ID 0 Placement : {0}:8 + PlacementLimit []uint32 `json:"placementLimit"` + PlacementOffSet uint32 `json:"placementOffSet"` } func MockGpuInfo() *[]GpuInfo { diff --git a/internal/config/rules.go b/internal/config/rules.go index 8bbfb556..486b4bcd 100644 --- a/internal/config/rules.go +++ b/internal/config/rules.go @@ -60,7 +60,7 @@ func (r *AlertRule) String() string { r.Name, r.Query, r.Threshold, r.EvaluationInterval, r.ConsecutiveCount, r.Severity) } -func (r *AlertRule) AddFiringAlertAndCheckResolved(alertQueryResult map[string]interface{}) (*PostableAlert, bool, string) { +func (r *AlertRule) AddFiringAlertAndCheckResolved(alertQueryResult map[string]any) (*PostableAlert, bool, string) { if r.FiringAlerts == nil { r.FiringAlerts = make(map[string]*FiringAlertCache) } @@ -122,7 +122,7 @@ func (r *AlertRule) IsTestMode() bool { return r.TestMode } -func (r *AlertRule) toPostableAlert(alertQueryResult map[string]interface{}, startsAt time.Time, isResolved bool) PostableAlert { +func (r *AlertRule) toPostableAlert(alertQueryResult map[string]any, startsAt time.Time, isResolved bool) PostableAlert { summary, description, instance, err := r.renderAlertContentTemplate(alertQueryResult) if err != nil { @@ -147,7 +147,7 @@ func (r *AlertRule) toPostableAlert(alertQueryResult map[string]interface{}, sta return alert } -func (rule *AlertRule) renderAlertContentTemplate(data interface{}) (string, string, string, error) { +func (rule *AlertRule) renderAlertContentTemplate(data any) (string, string, string, error) { if rule.summaryTmplParsed == nil { summaryTmplParsed, err := template.New("summary").Parse(rule.Summary) rule.summaryTmplParsed = summaryTmplParsed diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 557fdabd..ecebe5f2 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -1,6 +1,7 @@ package constants import ( + "os" "time" "k8s.io/utils/ptr" @@ -17,15 +18,32 @@ var ( 
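For illustration, here is a minimal sketch of how a scheduler could pick a placement for one of the partition templates defined above. It assumes that `PlacementLimit` lists the allowed start slots for a template and that `PlacementOffSet` is the number of placement slots one instance spans; both readings are inferred from the `nvidia-smi mig -lgipp` output quoted in the comments and are not confirmed elsewhere in the patch. The helper name `firstFreePlacement` is hypothetical and not part of the patch.

```go
package main

import "fmt"

// Minimal local mirrors of the new config fields; the real types live in
// internal/config/gpu_info.go. Field semantics here are assumptions.
type partitionTemplate struct {
	TemplateID      string
	PlacementLimit  []uint32 // assumed: allowed start slots, per `nvidia-smi mig -lgipp`
	PlacementOffSet uint32   // assumed: slots spanned by one placement
}

type allocatedPartition struct {
	SlotStart, SlotEnd uint32 // occupies [SlotStart, SlotEnd), as in AllocatedPartition
}

// firstFreePlacement returns the first allowed start slot where the template
// does not overlap an already allocated partition, or -1 if nothing fits.
func firstFreePlacement(tpl partitionTemplate, used []allocatedPartition, maxSlots uint32) int32 {
	for _, start := range tpl.PlacementLimit {
		end := start + tpl.PlacementOffSet
		if end > maxSlots {
			continue
		}
		overlaps := false
		for _, p := range used {
			if start < p.SlotEnd && p.SlotStart < end {
				overlaps = true
				break
			}
		}
		if !overlaps {
			return int32(start)
		}
	}
	return -1
}

func main() {
	// Numbers follow the quoted nvidia-smi example: profile 14 may start at
	// slots {0,2,4} and spans 2 slots on a GPU with 8 placement slots.
	tpl := partitionTemplate{TemplateID: "14", PlacementLimit: []uint32{0, 2, 4}, PlacementOffSet: 2}
	used := []allocatedPartition{{SlotStart: 0, SlotEnd: 2}} // slots 0-1 already taken
	fmt.Println(firstFreePlacement(tpl, used, 8))            // -> 2
}
```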
UnschedQueueBufferDuration = 10 * time.Second ) -const ( +var ( // Domain is the domain prefix used for all tensor-fusion.ai related annotations and finalizers - Domain = "tensor-fusion.ai" + // Change env var for enterprise's custom domain + DomainPrefix = func() string { + domainPrefix := os.Getenv("TENSOR_FUSION_CUSTOM_DOMAIN_PREFIX") + if domainPrefix == "" { + return "tensor-fusion" + } + return domainPrefix + }() + + DomainSuffix = func() string { + domainSuffix := os.Getenv("TENSOR_FUSION_CUSTOM_DOMAIN_SUFFIX") + if domainSuffix == "" { + return "ai" + } + return domainSuffix + }() + + Domain = DomainPrefix + "." + DomainSuffix // Finalizer constants FinalizerSuffix = "finalizer" Finalizer = Domain + "/" + FinalizerSuffix - SchedulerName = "tensor-fusion-scheduler" + SchedulerName = DomainPrefix + "-scheduler" LabelKeyOwner = Domain + "/managed-by" LabelKeyClusterOwner = Domain + "/cluster" @@ -83,7 +101,13 @@ const ( // GPUModelAnnotation specifies the required GPU model (e.g., "A100", "H100") GPUModelAnnotation = Domain + "/gpu-model" // GPU ID list is assigned by scheduler, should not specified by user - GPUDeviceIDsAnnotation = Domain + "/gpu-ids" + GPUDeviceIDsAnnotation = Domain + "/gpu-ids" + // User can specify the partition name to designate the partition template to use, e.g. 1g.20gb+me + // TODO: parse and pre-set in scheduler plugin to avoid find matched partition. + PartitionNameAnnotation = Domain + "/partition" + // PartitionTemplateIDAnnotation is the partition UUID assigned to a pod in partitioned mode + // This is read by accelerator.c to mock slice GPU like MIG does + PartitionTemplateIDAnnotation = Domain + "/partition-id" DedicatedGPUAnnotation = Domain + "/dedicated-gpu" SetPendingOwnedWorkloadAnnotation = Domain + "/pending-owned-workload" PricingAnnotation = Domain + "/hourly-pricing" @@ -92,8 +116,10 @@ const ( // Additional worker pod template is set by user with /worker-pod-template annotation WorkerPodTemplateAnnotation = Domain + "/worker-pod-template" - // Pod index annotation for Device Plugin communication (1-512) + // Pod index annotation for Device Plugin communication (1-128) + // When it's in annotation, use this string, when it's in resource limits, use it as prefix PodIndexAnnotation = Domain + "/index" + PodIndexDelimiter = "_" WorkloadModeAnnotation = Domain + "/workload-mode" WorkloadModeDynamic = "dynamic" @@ -138,7 +164,9 @@ const ( HypervisorServiceAccountName = "tensor-fusion-hypervisor-sa" TSDBVersionConfigMap = "tensor-fusion-tsdb-version" +) +const ( QoSLevelLow = "low" QoSLevelMedium = "medium" QoSLevelHigh = "high" @@ -178,7 +206,7 @@ const ( PhaseFailed = "Failed" ) -const ( +var ( // No disrupt label, similar to Karpenter, avoid TFConnection/Worker/GPUNode to be moved to another node or destroying node. 
// Refer: https://karpenter.sh/docs/concepts/disruption/ SchedulingDoNotDisruptLabel = Domain + "/do-not-disrupt" @@ -191,27 +219,28 @@ const ( ) // To match GPUNode with K8S node, when creating from cloud vendor, must set a label from cloud-init userdata -const ( +var ( ProvisionerLabelKey = Domain + "/node-provisioner" ProvisionerMissingLabel = Domain + "/orphan" ProvisionerNamePlaceholder = "__GPU_NODE_RESOURCE_NAME__" ) +var ( + TFDataPath = "/run/tensor-fusion" + TFDataPathWorkerExpr = "shm/$(POD_NAMESPACE)/$(POD_NAME)" + DataVolumeName = "tf-data" + TransportShmVolumeName = "tf-transport-shm" + TransportShmPath = "/dev/shm" + TensorFusionPoolManualCompaction = Domain + "/manual-compaction" + TensorFusionSystemName = DomainPrefix -const TFDataPath = "/run/tensor-fusion" -const TFDataPathWorkerExpr = "shm/$(POD_NAMESPACE)/$(POD_NAME)" -const DataVolumeName = "tf-data" -const TransportShmVolumeName = "tf-transport-shm" -const TransportShmPath = "/dev/shm" -const TensorFusionPoolManualCompaction = Domain + "/manual-compaction" -const TensorFusionSystemName = "tensor-fusion" - -const ( LeaderInfoConfigMapName = "tensor-fusion-operator-leader-info" LeaderInfoConfigMapLeaderIPKey = "leader-ip" + AcceleratorLabelVendor = Domain + "/hardware-vendor" ) const ShortUUIDAlphabet = "123456789abcdefghijkmnopqrstuvwxy" const SpotInstanceAssumedDiscountRatio = 0.3 +const MountShmSubcommand = "mount-shm" const ( LowFrequencyObjFailureInitialDelay = 300 * time.Millisecond @@ -221,6 +250,13 @@ const ( LowFrequencyObjFailureConcurrentReconcile = 5 ) +const ( + TelemetryEndpointEnvVar = "TELEMETRY_ENDPOINT" + TelemetryPublicKeyEnvVar = "TELEMETRY_PUBLIC_KEY" + DefaultTelemetryEndpoint = "https://us.i.posthog.com" + DefaultTelemetryPublicKey = "phc_qd1mhrtK35PpXx0bYQAYcscTJNnno73mC9qMwioTCi7" +) + const GiBToBytes = 1024 * 1024 * 1024 const AuthorizationHeader = "Authorization" @@ -233,3 +269,10 @@ const DefaultEvictionProtectionPriceRatio = 1.2 const NodeCriticalPriorityClassName = "system-node-critical" const KarpenterNodeClaimKind = "NodeClaim" const KarpenterNodePoolKind = "NodePool" + +const ( + // 16x8 dummy index device at max + // tensor-fusion.ai/index_0: 1 to tensor-fusion.ai/index_f: 8 + IndexKeyLength = 16 + IndexModLength = 8 +) diff --git a/internal/constants/env.go b/internal/constants/env.go index 52801324..c5521e68 100644 --- a/internal/constants/env.go +++ b/internal/constants/env.go @@ -136,21 +136,22 @@ const ( // TensorFusion hypervisor related envs const ( - HypervisorPoolNameEnv = "TENSOR_FUSION_POOL_NAME" - PodNameEnv = "POD_NAME" - VectorPodNodeNameEnv = "NODE_NAME" - HypervisorGPUNodeNameEnv = "GPU_NODE_NAME" - HypervisorSchedulingConfigEnv = "TF_HYPERVISOR_SCHEDULING_CONFIG" - HypervisorListenAddrEnv = "API_LISTEN_ADDR" - HypervisorMetricsFormatEnv = "TF_HYPERVISOR_METRICS_FORMAT" - HypervisorMetricsExtraLabelsEnv = "TF_HYPERVISOR_METRICS_EXTRA_LABELS" - HypervisorDetectUsedGPUEnv = "DETECT_IN_USED_GPUS" - HypervisorDevicePluginPathEnv = "DEVICE_PLUGIN_PATH" + HypervisorPoolNameEnv = "TENSOR_FUSION_POOL_NAME" + PodNameEnv = "POD_NAME" + VectorPodNodeNameEnv = "NODE_NAME" + HypervisorGPUNodeNameEnv = "GPU_NODE_NAME" + HypervisorSchedulingConfigEnv = "TF_HYPERVISOR_SCHEDULING_CONFIG" + HypervisorListenAddrEnv = "API_LISTEN_ADDR" + HypervisorMetricsFormatEnv = "TF_HYPERVISOR_METRICS_FORMAT" + HypervisorMetricsExtraLabelsEnv = "TF_HYPERVISOR_METRICS_EXTRA_LABELS" + HypervisorDetectUsedGPUEnv = "DETECT_IN_USED_GPUS" + HypervisorDevicePluginPathEnv = "DEVICE_PLUGIN_PATH" + 
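As a hedged illustration of the 16x8 index scheme noted above (IndexKeyLength 16, IndexModLength 8, keys `index_0` through `index_f` carrying values 1-8): a pod index in 1..128 plausibly maps to one hex key digit plus a 1-8 value as sketched below. The helper name is hypothetical, `fmt` is assumed to be imported, and the authoritative mapping lives in the index allocator, not here.

// encodePodIndex maps a pod index (1..128) onto one of 16 dummy device keys and a count 1..8,
// e.g. 1 -> tensor-fusion.ai/index_0: 1 and 128 -> tensor-fusion.ai/index_f: 8.
func encodePodIndex(index int) (string, int) {
	key := fmt.Sprintf("%s%s%x", PodIndexAnnotation, PodIndexDelimiter, (index-1)/IndexModLength)
	return key, (index-1)%IndexModLength + 1
}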
HypervisorKubeletCheckpointPathEnv = "KUBELET_CHECKPOINT_PATH" // Add ptrace capability to hypervisor container, to trace all host PID using GPU SystemPtraceCapability = "SYS_PTRACE" - HypervisorDefaultPortNumber int32 = 8000 + HypervisorDefaultPortNumber int32 = 8001 HypervisorPortName string = "http" // For security enhancement, there are 2 types of endpoints to protect @@ -161,6 +162,10 @@ const ( // but k3s and some K8S distribution may not support, need to find some way to get SA token JWT pub key HypervisorVerifyServiceAccountEnabledEnvVar = "SA_TOKEN_VERIFY_ENABLED" HypervisorVerifyServiceAccountPublicKeyEnvVar = "SA_TOKEN_VERIFY_PUBLIC_KEY" + + // Hardware vendor and accelerator library path for multi-vendor support + TFHardwareVendorEnv = "TF_HARDWARE_VENDOR" + TFAcceleratorLibPathEnv = "TF_ACCELERATOR_LIB_PATH" ) // Node discovery related envs diff --git a/internal/constants/vendors.go b/internal/constants/vendors.go index f72c4636..ba3fc16e 100644 --- a/internal/constants/vendors.go +++ b/internal/constants/vendors.go @@ -70,3 +70,19 @@ var L3VirtualizationSupportedVendors = []map[string]bool{ AcceleratorVendorHuaweiAscendNPU: false, }, } + +// GetAcceleratorLibPath returns the accelerator library path based on vendor +// Vendor string should match constants from internal/constants/vendors.go +func GetAcceleratorLibPath(vendor string) string { + switch vendor { + case AcceleratorVendorNvidia: + return "libaccelerator_nvidia.so" + case AcceleratorVendorAMD: + return "libaccelerator_amd.so" + case AcceleratorVendorHuaweiAscendNPU: + return "libaccelerator_ascend.so" + default: + // Default to stub library for unknown vendors + return "libaccelerator_stub.so" + } +} diff --git a/internal/controller/gpunode_controller.go b/internal/controller/gpunode_controller.go index 4a6c235f..7eb5f45d 100644 --- a/internal/controller/gpunode_controller.go +++ b/internal/controller/gpunode_controller.go @@ -30,7 +30,6 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/metrics" "github.com/NexusGPU/tensor-fusion/internal/scheduler/expander" "github.com/NexusGPU/tensor-fusion/internal/utils" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -59,6 +58,7 @@ type GPUNodeReconciler struct { // +kubebuilder:rbac:groups=tensor-fusion.ai,resources=gpunodes/status,verbs=get;update;patch // +kubebuilder:rbac:groups=tensor-fusion.ai,resources=gpunodes/finalizers,verbs=update // +kubebuilder:rbac:groups=coordination.k8s.io,resources=leases,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=karpenter.sh,resources=*,verbs=* // Reconcile GPU nodes func (r *GPUNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { @@ -103,7 +103,7 @@ func (r *GPUNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct poolObj := &tfv1.GPUPool{} err = r.Get(ctx, client.ObjectKey{Name: poolName}, poolObj) if err != nil { - return ctrl.Result{}, fmt.Errorf("failed to get tensor-fusion pool, can not create node discovery job, pool: %s", poolName) + return ctrl.Result{}, fmt.Errorf("failed to get tensor-fusion pool, pool: %s", poolName) } // Check if the Kubernetes node exists; if not, the GPUNode should delete itself. 
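The TF_HARDWARE_VENDOR / TF_ACCELERATOR_LIB_PATH env vars and the GetAcceleratorLibPath helper added above suggest the hypervisor resolves its accelerator backend roughly as follows. This is a sketch only, assuming an explicit library path overrides the vendor default; the actual wiring in cmd/hypervisor is not part of this diff.

// resolveAcceleratorLib picks the shared library the hypervisor should load:
// an explicit path from TF_ACCELERATOR_LIB_PATH wins, otherwise the vendor value
// from TF_HARDWARE_VENDOR is mapped via GetAcceleratorLibPath (stub library by default).
func resolveAcceleratorLib() string {
	if p := os.Getenv(constants.TFAcceleratorLibPathEnv); p != "" {
		return p
	}
	return constants.GetAcceleratorLibPath(os.Getenv(constants.TFHardwareVendorEnv))
}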
@@ -135,15 +135,6 @@ func (r *GPUNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct } } - if err := r.reconcileNodeDiscoveryJob(ctx, node, poolObj); err != nil { - return ctrl.Result{}, err - } - - if node.Status.TotalGPUs == 0 { - log.Info("GPU on this node has not been discovered, wait next loop", "node", node.Name) - return ctrl.Result{}, nil - } - hypervisorName, err := r.reconcileHypervisorPod(ctx, node, poolObj, coreNode) if err != nil { return ctrl.Result{}, err @@ -259,77 +250,6 @@ func (r *GPUNodeReconciler) fetchAllOwnedGPUDevices(ctx context.Context, node *t return gpuList.Items, nil } -func (r *GPUNodeReconciler) reconcileNodeDiscoveryJob( - ctx context.Context, - gpunode *tfv1.GPUNode, - pool *tfv1.GPUPool, -) error { - log := log.FromContext(ctx) - log.Info("starting node discovery job") - - if pool.Spec.ComponentConfig == nil || pool.Spec.ComponentConfig.NodeDiscovery.PodTemplate == nil { - return fmt.Errorf(`missing node discovery pod template in pool spec`) - } - podTmpl := &corev1.PodTemplate{} - err := json.Unmarshal(pool.Spec.ComponentConfig.NodeDiscovery.PodTemplate.Raw, podTmpl) - if err != nil { - return fmt.Errorf("unmarshal pod template: %w", err) - } - tmpl := podTmpl.Template - if tmpl.Labels == nil { - tmpl.Labels = map[string]string{} - } - tmpl.Labels[constants.LabelComponent] = constants.ComponentNodeDiscovery - tmpl.Spec.NodeName = gpunode.Name - // allow job to run at any taint Nodes that marked as NoSchedule - tmpl.Spec.Tolerations = append(tmpl.Spec.Tolerations, corev1.Toleration{ - Key: string(corev1.TaintEffectNoSchedule), - Operator: corev1.TolerationOpExists, - }) - tmpl.Spec.EnableServiceLinks = ptr.To(false) - - utils.AddTFNodeDiscoveryConfAfterTemplate(ctx, &tmpl, pool, gpunode.Name, r.CompatibleWithNvidiaContainerToolkit) - - // create node-discovery job - job := &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: getDiscoveryJobName(gpunode.Name), - Namespace: utils.CurrentNamespace(), - Labels: tmpl.Labels, - Annotations: tmpl.Annotations, - }, - Spec: batchv1.JobSpec{ - TTLSecondsAfterFinished: ptr.To[int32](3600 * 10), - Template: tmpl, - }, - } - - if err := r.Get(ctx, client.ObjectKeyFromObject(job), job); err != nil { - if errors.IsNotFound(err) { - if err := ctrl.SetControllerReference(gpunode, job, r.Scheme); err != nil { - return fmt.Errorf("set owner reference %w", err) - } - if err := r.Create(ctx, job); err != nil { - return fmt.Errorf("create node discovery job %w", err) - } - } else { - return fmt.Errorf("create node discovery job %w", err) - } - } - - if job.Status.Failed > 0 { - log.Info("node discovery job failed, update GPU node status to failed", "node", gpunode.Name) - // Update phase to failed, require manual address why it failed and restart of node discovery job - gpunode.Status.Phase = tfv1.TensorFusionGPUNodePhaseFailed - if err := r.Status().Update(ctx, gpunode); err != nil { - return fmt.Errorf("failed to update GPU node status to failed: %w", err) - } - metrics.SetNodeMetrics(gpunode, pool, nil) - } - - return nil -} - func (r *GPUNodeReconciler) reconcileHypervisorPod( ctx context.Context, node *tfv1.GPUNode, @@ -344,7 +264,7 @@ func (r *GPUNodeReconciler) reconcileHypervisorPod( key := client.ObjectKey{ Namespace: utils.CurrentNamespace(), - Name: fmt.Sprintf("hypervisor-%s", node.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", node.Name), } currentPod := &corev1.Pod{} @@ -414,7 +334,21 @@ func (r *GPUNodeReconciler) createHypervisorPod( // add must-have tensor-fusion hypervisor manifest 
log.Info("adding must-have tensor-fusion hypervisor manifest", "node", node.Name) - utils.AddTFHypervisorConfAfterTemplate(ctx, &spec, pool) + utils.AddTFHypervisorConfAfterTemplate(ctx, &spec, pool, r.CompatibleWithNvidiaContainerToolkit) + + // add vendor-specific env vars for multi-vendor support + if node.Labels != nil && node.Labels[constants.AcceleratorLabelVendor] != "" { + vendor := node.Labels[constants.AcceleratorLabelVendor] + acceleratorLibPath := constants.GetAcceleratorLibPath(vendor) + spec.Containers[0].Env = utils.AppendEnvVarsIfNotExists(spec.Containers[0].Env, corev1.EnvVar{ + Name: constants.TFHardwareVendorEnv, + Value: vendor, + }, corev1.EnvVar{ + Name: constants.TFAcceleratorLibPathEnv, + Value: acceleratorLibPath, + }) + log.Info("added vendor env vars to hypervisor pod", "node", node.Name, "vendor", vendor, "libPath", acceleratorLibPath) + } // add scheduling config for hypervisor if pool.Spec.SchedulingConfigTemplate != nil { @@ -495,12 +429,7 @@ func (r *GPUNodeReconciler) SetupWithManager(mgr ctrl.Manager) error { {NamespacedName: client.ObjectKey{Name: obj.GetName()}}, } })). - Owns(&batchv1.Job{}). Owns(&corev1.Pod{}). Owns(&tfv1.GPU{}). Complete(r) } - -func getDiscoveryJobName(gpunodeName string) string { - return fmt.Sprintf("node-discovery-%s", gpunodeName) -} diff --git a/internal/controller/gpunode_controller_test.go b/internal/controller/gpunode_controller_test.go index 42ea9d7b..29ea919c 100644 --- a/internal/controller/gpunode_controller_test.go +++ b/internal/controller/gpunode_controller_test.go @@ -23,37 +23,24 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/utils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" - "k8s.io/utils/ptr" ) var _ = Describe("GPUNode Controller", func() { Context("When reconciling gpunodes", func() { - It("should create the node discovery job and the hypervisor pod", func() { + It("should create the hypervisor pod", func() { tfEnv := NewTensorFusionEnvBuilder(). AddPoolWithNodeCount(1). SetGpuCountPerNode(1). 
Build() gpuNode := tfEnv.GetGPUNode(0, 0) - By("checking that the node discovery job is created") - Eventually(func(g Gomega) { - job := &batchv1.Job{} - g.Expect(k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("node-discovery-%s", gpuNode.Name), - Namespace: utils.CurrentNamespace(), - }, job)).Should(Succeed()) - - g.Expect(job.Spec.TTLSecondsAfterFinished).Should(Equal(ptr.To[int32](3600 * 10))) - }).Should(Succeed()) - By("checking that the hypervisor pod is created") pod := &corev1.Pod{} Eventually(func(g Gomega) { err := k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod) g.Expect(err).ShouldNot(HaveOccurred()) @@ -72,7 +59,7 @@ var _ = Describe("GPUNode Controller", func() { Eventually(func(g Gomega) { newPod := &corev1.Pod{} err := k8sClient.Get(ctx, types.NamespacedName{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, newPod) g.Expect(err).ShouldNot(HaveOccurred()) diff --git a/internal/controller/gpupool_controller.go b/internal/controller/gpupool_controller.go index a823ba9f..2d0c2ed7 100644 --- a/internal/controller/gpupool_controller.go +++ b/internal/controller/gpupool_controller.go @@ -408,16 +408,73 @@ func (r *GPUPoolReconciler) reconcilePoolComponents(ctx context.Context, pool *t } func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, pool *tfv1.GPUPool) error { - if pool.Spec.NodeManagerConfig != nil && pool.Spec.NodeManagerConfig.NodeSelector != nil { - hash := utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + nodeManagerConfig := pool.Spec.NodeManagerConfig + if nodeManagerConfig == nil { + return nil + } + + // Handle MultiVendorNodeSelector mode + if len(nodeManagerConfig.MultiVendorNodeSelector) > 0 { + hash := utils.GetObjectHash(nodeManagerConfig.MultiVendorNodeSelector) + if poolSelectorChangeMap[pool.Name] == hash { + return nil + } + + // hash has changed, or first reconcile, should check all k8s nodes + nodes := &corev1.NodeList{} + if err := r.List(ctx, nodes); err != nil { + return err + } + for _, node := range nodes.Items { + // skip no label or deleting nodes + if node.Labels == nil || !node.DeletionTimestamp.IsZero() { + continue + } + // Loop through vendor keys, when any key matched, set vendor label and break + vendorMatched := false + for vendor, nodeSelector := range nodeManagerConfig.MultiVendorNodeSelector { + if nodeSelector == nil { + continue + } + matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, nodeSelector) + if err != nil { + return err + } + if matches { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, vendor); err != nil { + return err + } + vendorMatched = true + break + } + } + // If no vendor matched but node was previously matched, remove vendor label + if !vendorMatched && node.Labels[constants.AcceleratorLabelVendor] != "" { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, ""); err != nil { + return err + } + } + } + poolSelectorChangeMap[pool.Name] = hash + return nil + } + + // Handle default NodeSelector mode + if nodeManagerConfig.NodeSelector != nil { + hash := utils.GetObjectHash(nodeManagerConfig.NodeSelector) if poolSelectorChangeMap[pool.Name] == hash { return nil } + // Determine default vendor: use defaultVendor if set, otherwise NVIDIA + defaultVendor := 
constants.AcceleratorVendorNvidia + if nodeManagerConfig.DefaultVendor != "" { + defaultVendor = nodeManagerConfig.DefaultVendor + } + // hash has changed, or first reconcile, should check all k8s nodes nodes := &corev1.NodeList{} - selectors := utils.GetInitialGPUNodeSelector() - if err := r.List(ctx, nodes, client.MatchingLabels{selectors[0]: selectors[1]}); err != nil { + if err := r.List(ctx, nodes); err != nil { return err } for _, node := range nodes.Items { @@ -425,12 +482,12 @@ func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, poo if node.Labels == nil || !node.DeletionTimestamp.IsZero() { continue } - matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, pool.Spec.NodeManagerConfig.NodeSelector) + matches, err := schedulingcorev1.MatchNodeSelectorTerms(&node, nodeManagerConfig.NodeSelector) if err != nil { return err } if matches { - if err := UpdateK8SNodeSelectorHash(ctx, r.Client, &node, hash); err != nil { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, &node, hash, defaultVendor); err != nil { return err } } @@ -441,9 +498,9 @@ func (r *GPUPoolReconciler) reconcilePoolSelectorChange(ctx context.Context, poo return nil } -func UpdateK8SNodeSelectorHash(ctx context.Context, k8sClient client.Client, node *corev1.Node, hash string) error { - // skip nodes that already injected the hash - if node.Labels[constants.LabelNodeSelectorHash] == hash { +func UpdateK8SNodeSelectorHashAndVendor(ctx context.Context, k8sClient client.Client, node *corev1.Node, hash string, vendor string) error { + // skip nodes that already have the same hash and vendor + if node.Labels[constants.LabelNodeSelectorHash] == hash && node.Labels[constants.AcceleratorLabelVendor] == vendor { return nil } // update label to trigger the GPUNode reconcile @@ -452,7 +509,15 @@ func UpdateK8SNodeSelectorHash(ctx context.Context, k8sClient client.Client, nod if err := k8sClient.Get(ctx, client.ObjectKey{Name: node.Name}, latest); err != nil { return err } + if latest.Labels == nil { + latest.Labels = make(map[string]string) + } latest.Labels[constants.LabelNodeSelectorHash] = hash + if vendor != "" { + latest.Labels[constants.AcceleratorLabelVendor] = vendor + } else { + delete(latest.Labels, constants.AcceleratorLabelVendor) + } return k8sClient.Update(ctx, latest) }); err != nil { return err diff --git a/internal/controller/gpupool_controller_test.go b/internal/controller/gpupool_controller_test.go index 422a140c..caf85f6f 100644 --- a/internal/controller/gpupool_controller_test.go +++ b/internal/controller/gpupool_controller_test.go @@ -429,7 +429,7 @@ func verifyHypervisorPodHash(gpuNode *tfv1.GPUNode, hash string) { Eventually(func(g Gomega) { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -463,7 +463,7 @@ func verifyHypervisorPodHashConsistently(gpuNode *tfv1.GPUNode, hash string) { Consistently(func(g Gomega) { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -486,7 +486,7 @@ func verifyAllHypervisorPodHash(tfEnv *TensorFusionEnv, hash 
string) { for _, gpuNode := range nodeList.Items { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) @@ -552,7 +552,7 @@ func verifyAllHypervisorPodHashConsistently(tfEnv *TensorFusionEnv, hash string) for _, gpuNode := range nodeList.Items { pod := &corev1.Pod{} g.Expect(k8sClient.Get(ctx, client.ObjectKey{ - Name: fmt.Sprintf("hypervisor-%s", gpuNode.Name), + Name: fmt.Sprintf("tf-hypervisor-%s", gpuNode.Name), Namespace: utils.CurrentNamespace(), }, pod)).Should(Succeed()) g.Expect(pod.Labels[constants.LabelKeyPodTemplateHash]).Should(Equal(hash)) diff --git a/internal/controller/node_controller.go b/internal/controller/node_controller.go index d8908847..67723625 100644 --- a/internal/controller/node_controller.go +++ b/internal/controller/node_controller.go @@ -32,11 +32,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - "sigs.k8s.io/controller-runtime/pkg/event" - "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" - "sigs.k8s.io/controller-runtime/pkg/reconcile" schedulingcorev1 "k8s.io/component-helpers/scheduling/corev1" ) @@ -115,6 +112,14 @@ func (r *NodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl. } } + // If node changed to other AI accelerator hardware vendor, update gpuNode label vendor and trigger hypervisor update + if gpuNode.Labels[constants.AcceleratorLabelVendor] != node.Labels[constants.AcceleratorLabelVendor] { + gpuNode.Labels[constants.AcceleratorLabelVendor] = node.Labels[constants.AcceleratorLabelVendor] + if err := r.Update(ctx, gpuNode); err != nil { + return ctrl.Result{}, fmt.Errorf("failed to update GPU node vendor: %w", err) + } + } + if !node.DeletionTimestamp.IsZero() { log.Info("GPU node is being deleted, mark related GPUNode resource as destroying", "node", node.Name) gpuNode.Status.Phase = tfv1.TensorFusionGPUNodePhaseDestroying @@ -125,9 +130,14 @@ func (r *NodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl. 
} // update k8s node hash - hash := utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + hash := "" + if len(pool.Spec.NodeManagerConfig.MultiVendorNodeSelector) > 0 { + hash = utils.GetObjectHash(pool.Spec.NodeManagerConfig.MultiVendorNodeSelector) + } else { + hash = utils.GetObjectHash(pool.Spec.NodeManagerConfig.NodeSelector) + } if node.Labels[constants.LabelNodeSelectorHash] != hash { - if err := UpdateK8SNodeSelectorHash(ctx, r.Client, node, hash); err != nil { + if err := UpdateK8SNodeSelectorHashAndVendor(ctx, r.Client, node, hash, node.Labels[constants.AcceleratorLabelVendor]); err != nil { return ctrl.Result{}, fmt.Errorf("failed to update k8s node hash: %w", err) } } @@ -203,51 +213,35 @@ func (r *NodeReconciler) generateGPUNode(node *corev1.Node, pool *tfv1.GPUPool, if provisioner != "" { gpuNode.Labels[constants.ProvisionerLabelKey] = provisioner } + // Copy vendor label from k8s node to GPUNode + if node.Labels != nil && node.Labels[constants.AcceleratorLabelVendor] != "" { + gpuNode.Labels[constants.AcceleratorLabelVendor] = node.Labels[constants.AcceleratorLabelVendor] + } _ = controllerutil.SetControllerReference(pool, gpuNode, r.Scheme) return gpuNode } // SetupWithManager sets up the controller with the Manager. func (r *NodeReconciler) SetupWithManager(mgr ctrl.Manager) error { - // must choose an initial label selector to avoid performance impact in large Kubernetes clusters + ctr := ctrl.NewControllerManagedBy(mgr) + // Prefer to choose an initial label selector to avoid performance impact in large Kubernetes clusters that has lots of CPU nodes selectors := utils.GetInitialGPUNodeSelector() - p, err := predicate.LabelSelectorPredicate(metav1.LabelSelector{ - MatchLabels: map[string]string{ - selectors[0]: selectors[1], - }, - }) - if err != nil { - return fmt.Errorf("unable to create predicate: %w", err) + if len(selectors) == 2 { + p, err := predicate.LabelSelectorPredicate(metav1.LabelSelector{ + MatchLabels: map[string]string{ + selectors[0]: selectors[1], + }, + }) + if err != nil { + return fmt.Errorf("unable to create predicate: %w", err) + } + ctr.For(&corev1.Node{}, builder.WithPredicates(p)) + } else { + ctr.For(&corev1.Node{}) } - return ctrl.NewControllerManagedBy(mgr). - For(&corev1.Node{}, builder.WithPredicates(p)). + return ctr. Named("node"). - Watches(&tfv1.GPUPool{}, handler.EnqueueRequestsFromMapFunc(func(ctx context.Context, obj client.Object) []reconcile.Request { - nodelist := &tfv1.GPUNodeList{} - if err := mgr.GetClient().List(ctx, nodelist, client.MatchingLabels{ - selectors[0]: selectors[1], - }); err != nil { - log.FromContext(ctx).Error(err, "failed to list GPUNode") - return []reconcile.Request{} - } - var requests []reconcile.Request - for _, n := range nodelist.Items { - requests = append(requests, reconcile.Request{NamespacedName: client.ObjectKey{Name: n.Name}}) - } - return requests - }), builder.WithPredicates(predicate.Funcs{ - UpdateFunc: func(e event.UpdateEvent) bool { - oldObj, ok1 := e.ObjectOld.(*tfv1.GPUPool) - newObj, ok2 := e.ObjectNew.(*tfv1.GPUPool) - if !ok1 || !ok2 { - return false - } - oldNodeSelector := oldObj.Spec.NodeManagerConfig.NodeSelector - newNodeSelector := newObj.Spec.NodeManagerConfig.NodeSelector - return utils.GetObjectHash(oldNodeSelector) != utils.GetObjectHash(newNodeSelector) - }, - })). 
Complete(r) } diff --git a/internal/controller/pod_controller.go b/internal/controller/pod_controller.go index fb6d0c1e..09191eb8 100644 --- a/internal/controller/pod_controller.go +++ b/internal/controller/pod_controller.go @@ -25,6 +25,8 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/indexallocator" "github.com/NexusGPU/tensor-fusion/internal/metrics" "github.com/NexusGPU/tensor-fusion/internal/portallocator" "github.com/NexusGPU/tensor-fusion/internal/scheduler/expander" @@ -47,10 +49,11 @@ import ( // PodReconciler reconciles a Pod object type PodReconciler struct { client.Client - Scheme *runtime.Scheme - Allocator *gpuallocator.GpuAllocator - PortAllocator *portallocator.PortAllocator - Expander *expander.NodeExpander + Scheme *runtime.Scheme + Allocator *gpuallocator.GpuAllocator + PortAllocator *portallocator.PortAllocator + Expander *expander.NodeExpander + IndexAllocator *indexallocator.IndexAllocator } // +kubebuilder:rbac:groups=core,resources=*,verbs=get;list;watch @@ -71,6 +74,7 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R _ = r.Expander.RemovePreSchedulePod(req.Name, true) r.Allocator.DeallocByPodIdentifier(ctx, req.NamespacedName) metrics.RemoveWorkerMetrics(req.Name, time.Now()) + r.IndexAllocator.RemoveNodeIndexQueueForPod(req.NamespacedName) log.Info("Released GPU resources when pod deleted", "pod", req.NamespacedName) return ctrl.Result{}, nil } @@ -110,7 +114,12 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R } } + if utils.IsPodStopped(pod) { + r.Allocator.DeallocByPodIdentifier(ctx, req.NamespacedName) + } + if pod.Labels[constants.LabelComponent] == constants.ComponentWorker { + r.IndexAllocator.ReconcileLockState(pod) if pod.DeletionTimestamp.IsZero() { metrics.SetWorkerMetricsByWorkload(pod) } @@ -232,6 +241,10 @@ func (r *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } +func (r *PodReconciler) RegisterBackendWorkerChangeHandler(handler framework.WorkerChangeHandler) { + +} + // findConnectionNameNamespace extracts the connection name and namespace from the container's environment variables func findConnectionNameNamespace(pod *corev1.Pod) client.ObjectKey { connectionNameNamespace := client.ObjectKey{} diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go index 2f61b9f2..4ae8ce82 100644 --- a/internal/controller/suite_test.go +++ b/internal/controller/suite_test.go @@ -156,7 +156,7 @@ var _ = BeforeSuite(func() { WorkerUnitPriceMap: make(map[string]map[string]metrics.RawBillingPricing), } - allocator = gpuallocator.NewGpuAllocator(ctx, mgr.GetClient(), 150*time.Millisecond) + allocator = gpuallocator.NewGpuAllocator(ctx, nil, mgr.GetClient(), 150*time.Millisecond) err = allocator.SetupWithManager(ctx, mgr) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/controller/tensorfusionworkload_controller_test.go b/internal/controller/tensorfusionworkload_controller_test.go index 9c2a9cd3..f11fe3d5 100644 --- a/internal/controller/tensorfusionworkload_controller_test.go +++ b/internal/controller/tensorfusionworkload_controller_test.go @@ -37,6 +37,7 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" + 
"github.com/NexusGPU/tensor-fusion/internal/utils" ) var _ = Describe("TensorFusionWorkload Controller", func() { @@ -402,7 +403,7 @@ func mockSchedulerLoop(ctx context.Context, cfg *rest.Config) { func scheduleAndStartPod(pod *corev1.Pod, clientset *kubernetes.Clientset) { // simulate scheduling cycle Filter and Reserve - allocRequest, _, err := allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(ctx, pod) Expect(err).To(Succeed()) gpus, err := allocator.Alloc(allocRequest) if err != nil { diff --git a/internal/gpuallocator/filter/filter_test.go b/internal/gpuallocator/filter/filter_test.go index c47ab594..5c6e2e5a 100644 --- a/internal/gpuallocator/filter/filter_test.go +++ b/internal/gpuallocator/filter/filter_test.go @@ -111,7 +111,7 @@ func TestFilters(t *testing.T) { filter := NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil) + }) result, err := filter.Filter(ctx, testPodKey, gpus) assert.NoError(t, err) assert.Len(t, result, 2) @@ -126,7 +126,7 @@ func TestFilters(t *testing.T) { With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil)) + })) // Apply filters result, _, err := registry.Apply(ctx, testPodKey, gpus, false) @@ -137,10 +137,11 @@ func TestFilters(t *testing.T) { t.Run("FilterRegistry with gpu indices filtering", func(t *testing.T) { registry := NewFilterRegistry(). + With(NewGPUIndexFilter([]int32{2, 3})). With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("1"), Vram: resource.MustParse("1Gi"), - }, []int32{2, 3})) + })) // Apply filters result, _, err := registry.Apply(ctx, testPodKey, gpus, false) @@ -160,7 +161,7 @@ func TestFilters(t *testing.T) { With(NewResourceFilter(tfv1.Resource{ Tflops: resource.MustParse("8"), Vram: resource.MustParse("30Gi"), - }, nil)) + })) // Apply base registry filters baseResult, _, err := baseRegistry.Apply(ctx, testPodKey, gpus, false) diff --git a/internal/gpuallocator/filter/gpu_index_filter.go b/internal/gpuallocator/filter/gpu_index_filter.go new file mode 100644 index 00000000..285f59bf --- /dev/null +++ b/internal/gpuallocator/filter/gpu_index_filter.go @@ -0,0 +1,57 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package filter + +import ( + "context" + "slices" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/samber/lo" +) + +// GPUIndexFilter filters GPUs based on required GPU indices +type GPUIndexFilter struct { + requiredIndices []int32 +} + +// NewGPUIndexFilter creates a new GPUIndexFilter with the specified indices +func NewGPUIndexFilter(requiredIndices []int32) *GPUIndexFilter { + return &GPUIndexFilter{ + requiredIndices: requiredIndices, + } +} + +// Filter implements GPUFilter.Filter +func (f *GPUIndexFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + // If no indices specified, pass all GPUs + if len(f.requiredIndices) == 0 { + return gpus, nil + } + + return lo.Filter(gpus, func(gpu *tfv1.GPU, _ int) bool { + // Check GPU index + if gpu.Status.Index != nil && slices.Contains(f.requiredIndices, *gpu.Status.Index) { + return true + } + return false + }), nil +} + +func (f *GPUIndexFilter) Name() string { + return "GPUIndexFilter" +} diff --git a/internal/gpuallocator/filter/gpu_isolation_mode_filter.go b/internal/gpuallocator/filter/gpu_isolation_mode_filter.go new file mode 100644 index 00000000..4d094e04 --- /dev/null +++ b/internal/gpuallocator/filter/gpu_isolation_mode_filter.go @@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUIsolationModeFilter filters GPUs based on their isolation mode +type GPUIsolationModeFilter struct { + requiredIsolationMode tfv1.IsolationModeType +} + +// NewGPUIsolationModeFilter creates a new filter that matches GPUs with the specified isolation mode +func NewGPUIsolationModeFilter(isolationMode tfv1.IsolationModeType) *GPUIsolationModeFilter { + return &GPUIsolationModeFilter{ + requiredIsolationMode: isolationMode, + } +} + +// Filter implements GPUFilter interface +func (f *GPUIsolationModeFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredIsolationMode == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.IsolationMode == "" || gpu.Status.IsolationMode == f.requiredIsolationMode { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUIsolationModeFilter) Name() string { + return "GPUIsolationModeFilter" +} diff --git a/internal/gpuallocator/filter/gpu_model_filter.go b/internal/gpuallocator/filter/gpu_model_filter.go new file mode 100644 index 00000000..f3d927e3 --- /dev/null +++ b/internal/gpuallocator/filter/gpu_model_filter.go @@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUModelFilter filters GPUs based on their model (e.g., A100, H100) +type GPUModelFilter struct { + requiredModel string +} + +// NewGPUModelFilter creates a new filter that matches GPUs with the specified model +func NewGPUModelFilter(model string) *GPUModelFilter { + return &GPUModelFilter{ + requiredModel: model, + } +} + +// Filter implements GPUFilter interface +func (f *GPUModelFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredModel == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.GPUModel == f.requiredModel { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUModelFilter) Name() string { + return 
"GPUModelFilter" +} diff --git a/internal/gpuallocator/filter/gpu_model_vendor_filter.go b/internal/gpuallocator/filter/gpu_model_vendor_filter.go deleted file mode 100644 index f095d76a..00000000 --- a/internal/gpuallocator/filter/gpu_model_vendor_filter.go +++ /dev/null @@ -1,50 +0,0 @@ -package filter - -import ( - "context" - - tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" -) - -// GPUModelAndVendorFilter filters GPUs based on their model (e.g., A100, H100) -type GPUModelAndVendorFilter struct { - requiredModel string - requiredVendor string -} - -// NewGPUModelAndVendorFilter creates a new filter that matches GPUs with the specified model -func NewGPUModelAndVendorFilter(model string, vendor string) *GPUModelAndVendorFilter { - return &GPUModelAndVendorFilter{ - requiredModel: model, - requiredVendor: vendor, - } -} - -// Filter implements GPUFilter interface -func (f *GPUModelAndVendorFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { - if f.requiredModel == "" && f.requiredVendor == "" { - return gpus, nil - } - - filtered := make([]*tfv1.GPU, 0, len(gpus)) - - if f.requiredModel != "" { - for _, gpu := range gpus { - if gpu.Status.GPUModel == f.requiredModel { - filtered = append(filtered, gpu) - } - } - } - if f.requiredVendor != "" { - for _, gpu := range gpus { - if gpu.Status.Vendor == f.requiredVendor { - filtered = append(filtered, gpu) - } - } - } - return filtered, nil -} - -func (f *GPUModelAndVendorFilter) Name() string { - return "GPUModelAndVendorFilter" -} diff --git a/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go b/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go index 0f57173b..e25de11e 100644 --- a/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go +++ b/internal/gpuallocator/filter/gpu_model_vendor_filter_test.go @@ -85,7 +85,7 @@ func TestGPUModelFilter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - filter := NewGPUModelAndVendorFilter(tt.requiredModel, "") + filter := NewGPUModelFilter(tt.requiredModel) got, err := filter.Filter(context.Background(), testPodKey, tt.gpus) if tt.wantErr { assert.Error(t, err) diff --git a/internal/gpuallocator/filter/gpu_vendor_filter.go b/internal/gpuallocator/filter/gpu_vendor_filter.go new file mode 100644 index 00000000..0f3ef5cf --- /dev/null +++ b/internal/gpuallocator/filter/gpu_vendor_filter.go @@ -0,0 +1,38 @@ +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// GPUVendorFilter filters GPUs based on their vendor +type GPUVendorFilter struct { + requiredVendor string +} + +// NewGPUVendorFilter creates a new filter that matches GPUs with the specified vendor +func NewGPUVendorFilter(vendor string) *GPUVendorFilter { + return &GPUVendorFilter{ + requiredVendor: vendor, + } +} + +// Filter implements GPUFilter interface +func (f *GPUVendorFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + if f.requiredVendor == "" { + return gpus, nil + } + + filtered := make([]*tfv1.GPU, 0, len(gpus)) + for _, gpu := range gpus { + if gpu.Status.Vendor == f.requiredVendor { + filtered = append(filtered, gpu) + } + } + return filtered, nil +} + +func (f *GPUVendorFilter) Name() string { + return "GPUVendorFilter" +} diff --git a/internal/gpuallocator/filter/partition_template_filter.go b/internal/gpuallocator/filter/partition_template_filter.go new file mode 100644 index 00000000..e6991764 --- /dev/null 
+++ b/internal/gpuallocator/filter/partition_template_filter.go @@ -0,0 +1,101 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filter + +import ( + "context" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/samber/lo" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// PartitionTemplateFilter filters GPUs based on partition template availability +// Only applies when isolation mode is partitioned +type PartitionTemplateFilter struct { + isolationMode tfv1.IsolationModeType + requiredTemplateID string + maxPartitionsMap map[string]uint32 // GPU model -> max partitions +} + +// NewPartitionTemplateFilter creates a new PartitionTemplateFilter +func NewPartitionTemplateFilter(isolationMode tfv1.IsolationModeType, requiredTemplateID string, maxPartitionsMap map[string]uint32) *PartitionTemplateFilter { + return &PartitionTemplateFilter{ + isolationMode: isolationMode, + requiredTemplateID: requiredTemplateID, + maxPartitionsMap: maxPartitionsMap, + } +} + +// Filter implements GPUFilter.Filter +func (f *PartitionTemplateFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNamespace, gpus []*tfv1.GPU) ([]*tfv1.GPU, error) { + // Only apply filter for partitioned isolation mode + if f.isolationMode != tfv1.IsolationModePartitioned { + return gpus, nil + } + + logger := log.FromContext(ctx) + + return lo.Filter(gpus, func(gpu *tfv1.GPU, _ int) bool { + // Check if GPU has partition templates + if len(gpu.Status.PartitionTemplates) == 0 { + logger.V(5).Info("GPU has no partition templates", "gpu", gpu.Name) + return false + } + + // If a specific template ID is required, check if GPU has it + if f.requiredTemplateID != "" { + hasTemplate := false + for _, template := range gpu.Status.PartitionTemplates { + if template.TemplateID == f.requiredTemplateID { + hasTemplate = true + break + } + } + if !hasTemplate { + logger.V(5).Info("GPU does not have required partition template", + "gpu", gpu.Name, "template", f.requiredTemplateID) + return false + } + } + + // Check partition count limit + allocatedCount := 0 + if gpu.Status.AllocatedPartitions != nil { + allocatedCount = len(gpu.Status.AllocatedPartitions) + } + + // Get max partitions from config + maxPartitions := f.maxPartitionsMap[gpu.Status.GPUModel] + if maxPartitions == 0 { + // Default to 7 for MIG if not configured + maxPartitions = 7 + } + + if maxPartitions > 0 && uint32(allocatedCount) >= maxPartitions { + logger.V(5).Info("GPU has reached maximum partition count", + "gpu", gpu.Name, "allocated", allocatedCount, "max", maxPartitions) + return false + } + + return true + }), nil +} + +func (f *PartitionTemplateFilter) Name() string { + return "PartitionTemplateFilter" +} diff --git a/internal/gpuallocator/filter/partition_template_filter_test.go b/internal/gpuallocator/filter/partition_template_filter_test.go new file mode 100644 index 00000000..a6eaf1e2 --- /dev/null +++ b/internal/gpuallocator/filter/partition_template_filter_test.go @@ -0,0 +1,175 @@ +/* +Copyright 2024. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filter + +import ( + "context" + "testing" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestPartitionTemplateFilter(t *testing.T) { + testPodKey := tfv1.NameNamespace{ + Name: "test-pod", + Namespace: "test-namespace", + } + + tests := []struct { + name string + isolationMode tfv1.IsolationModeType + requiredTemplate string + maxPartitionsMap map[string]uint32 + gpus []*tfv1.GPU + expectedCount int + expectedGPUNames []string + }{ + { + name: "non-partitioned mode should pass all GPUs", + isolationMode: tfv1.IsolationModeSoft, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-1"}, + }, + { + name: "partitioned mode - GPU without templates filtered out", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + { + name: "partitioned mode - specific template required", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "1g.24gb", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + { + name: "partitioned mode - max partitions reached", + isolationMode: tfv1.IsolationModePartitioned, + requiredTemplate: "", + maxPartitionsMap: map[string]uint32{"A100": 7}, + gpus: []*tfv1.GPU{ + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: "1g.24gb", PodUID: "pod-1"}, + "pod-2": {TemplateID: "1g.24gb", PodUID: "pod-2"}, + "pod-3": {TemplateID: "1g.24gb", PodUID: "pod-3"}, + "pod-4": {TemplateID: "1g.24gb", PodUID: "pod-4"}, + "pod-5": {TemplateID: 
"1g.24gb", PodUID: "pod-5"}, + "pod-6": {TemplateID: "1g.24gb", PodUID: "pod-6"}, + "pod-7": {TemplateID: "1g.24gb", PodUID: "pod-7"}, + }, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "gpu-2"}, + Status: tfv1.GPUStatus{ + GPUModel: "A100", + PartitionTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: "1g.24gb", PodUID: "pod-1"}, + }, + }, + }, + }, + expectedCount: 1, + expectedGPUNames: []string{"gpu-2"}, + }, + } + + ctx := context.Background() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filter := NewPartitionTemplateFilter(tt.isolationMode, tt.requiredTemplate, tt.maxPartitionsMap) + result, err := filter.Filter(ctx, testPodKey, tt.gpus) + + assert.NoError(t, err) + assert.Len(t, result, tt.expectedCount) + if len(tt.expectedGPUNames) > 0 { + resultNames := make([]string, len(result)) + for i, gpu := range result { + resultNames[i] = gpu.Name + } + assert.ElementsMatch(t, tt.expectedGPUNames, resultNames) + } + }) + } +} diff --git a/internal/gpuallocator/filter/resource_filter.go b/internal/gpuallocator/filter/resource_filter.go index fa8ca805..9f0a76ef 100644 --- a/internal/gpuallocator/filter/resource_filter.go +++ b/internal/gpuallocator/filter/resource_filter.go @@ -2,7 +2,6 @@ package filter import ( "context" - "slices" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/utils" @@ -12,14 +11,12 @@ import ( // ResourceFilter filters GPUs based on available resources type ResourceFilter struct { requiredResource tfv1.Resource - requiredIndices []int32 } // NewResourceFilter creates a new ResourceFilter with the specified resource requirements -func NewResourceFilter(required tfv1.Resource, requiredIndices []int32) *ResourceFilter { +func NewResourceFilter(required tfv1.Resource) *ResourceFilter { return &ResourceFilter{ requiredResource: required, - requiredIndices: requiredIndices, } } @@ -31,13 +28,6 @@ func (f *ResourceFilter) Filter(ctx context.Context, workerPodKey tfv1.NameNames return false } - // Check GPU indices range - if len(f.requiredIndices) > 0 { - if gpu.Status.Index != nil && !slices.Contains(f.requiredIndices, *gpu.Status.Index) { - return false - } - } - // Check TFlops availability hasTflops := gpu.Status.Available.Tflops.Cmp(f.requiredResource.Tflops) >= 0 diff --git a/internal/gpuallocator/gpuallocator.go b/internal/gpuallocator/gpuallocator.go index a32156da..e20298e1 100644 --- a/internal/gpuallocator/gpuallocator.go +++ b/internal/gpuallocator/gpuallocator.go @@ -5,9 +5,7 @@ import ( "context" "fmt" "math" - "slices" "sort" - "strconv" "strings" "sync" "time" @@ -18,6 +16,7 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator/filter" + "github.com/NexusGPU/tensor-fusion/internal/indexallocator" "github.com/NexusGPU/tensor-fusion/internal/metrics" "github.com/NexusGPU/tensor-fusion/internal/quota" "github.com/NexusGPU/tensor-fusion/internal/utils" @@ -38,12 +37,54 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" ) -const MaxGPUCounterPerAllocation = 128 const CleanUpCheckInterval = 3 * time.Minute var mu sync.Mutex var GPUCapacityMap = map[string]tfv1.Resource{} +// PartitionTemplateMap stores partition template info by GPU model +// Key: GPU model (e.g., "A100_SXM_80G"), Value: map of templateID -> template info +var PartitionTemplateMap = 
map[string]map[string]config.PartitionTemplateInfo{} + +// MaxPartitionsMap stores max partitions by GPU model +// Key: GPU model, Value: max partitions (e.g., 7 for MIG) +var MaxPartitionsMap = map[string]uint32{} + +// MaxPlacementSlotsMap stores max placement slots by GPU model +// Key: GPU model, Value: max placement slots (e.g., 8 for MIG) +var MaxPlacementSlotsMap = map[string]uint32{} + +// LoadPartitionTemplatesFromConfig loads partition templates and max partitions from GPU info config +// This should be called when GPU info config is loaded/updated +func LoadPartitionTemplatesFromConfig(gpuInfos []config.GpuInfo) { + mu.Lock() + defer mu.Unlock() + + for _, gpuInfo := range gpuInfos { + // Store max partitions + if gpuInfo.MaxPartitions > 0 { + MaxPartitionsMap[gpuInfo.Model] = gpuInfo.MaxPartitions + MaxPartitionsMap[gpuInfo.FullModelName] = gpuInfo.MaxPartitions + } + + // Store max placement slots + if gpuInfo.MaxPlacementSlots > 0 { + MaxPlacementSlotsMap[gpuInfo.Model] = gpuInfo.MaxPlacementSlots + MaxPlacementSlotsMap[gpuInfo.FullModelName] = gpuInfo.MaxPlacementSlots + } + + // Store partition templates + if len(gpuInfo.PartitionTemplates) > 0 { + templateMap := make(map[string]config.PartitionTemplateInfo, len(gpuInfo.PartitionTemplates)) + for _, template := range gpuInfo.PartitionTemplates { + templateMap[template.TemplateID] = template + } + PartitionTemplateMap[gpuInfo.Model] = templateMap + PartitionTemplateMap[gpuInfo.FullModelName] = templateMap + } + } +} + type Strategy interface { // When isForNode = true, indicates each GPU's node level score // otherwise it's single GPU score inside one node @@ -83,11 +124,12 @@ type GpuAllocator struct { nodeGpuStore map[string]map[string]*tfv1.GPU poolGpuStore map[string]map[string]*tfv1.GPU nodeWorkerStore map[string]map[types.NamespacedName]struct{} - storeMutex sync.RWMutex - allocateMutex sync.Mutex - syncInterval time.Duration - cancel context.CancelFunc - ctx context.Context + + storeMutex sync.RWMutex + allocateMutex sync.Mutex + syncInterval time.Duration + cancel context.CancelFunc + ctx context.Context // Queue for tracking modified GPUs that need to be synced dirtyQueue map[types.NamespacedName]struct{} @@ -104,10 +146,16 @@ type GpuAllocator struct { reconcileWorkerOnce sync.Once initializedCh chan struct{} - bindHandlers []func(req *tfv1.AllocRequest) + bindHandlers []func(req *tfv1.AllocRequest) + indexAllocator *indexallocator.IndexAllocator } -func NewGpuAllocator(ctx context.Context, client client.Client, syncInterval time.Duration) *GpuAllocator { +func NewGpuAllocator( + ctx context.Context, + indexAllocator *indexallocator.IndexAllocator, + client client.Client, + syncInterval time.Duration, +) *GpuAllocator { log := log.FromContext(ctx) if client == nil { @@ -123,6 +171,15 @@ func NewGpuAllocator(ctx context.Context, client client.Client, syncInterval tim // Create quota store quotaStore := quota.NewQuotaStore(client, ctx) + if indexAllocator == nil { + newIndexAllocator, err := indexallocator.NewIndexAllocator(ctx, client) + if err != nil { + log.Error(err, "Failed to create index allocator") + return nil + } + indexAllocator = newIndexAllocator + } + allocator := &GpuAllocator{ Client: client, filterRegistry: baseRegistry, @@ -135,6 +192,7 @@ func NewGpuAllocator(ctx context.Context, client client.Client, syncInterval tim dirtyQueue: make(map[types.NamespacedName]struct{}), ctx: ctx, + indexAllocator: indexAllocator, uniqueAllocation: make(map[string]*tfv1.AllocRequest, 512), uniqueDeallocation: 
make(map[string]struct{}, 512), podNamespaceNsToPodUID: make(map[string]string, 512), @@ -178,20 +236,43 @@ func (s *GpuAllocator) Filter( toFilterGPUs []*tfv1.GPU, isSimulateSchedule bool, ) ([]*tfv1.GPU, []filter.FilterDetail, error) { - // Add SameNodeFilter if count > 1 to ensure GPUs are from the same node - filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request, req.GPUIndices)) + // Filter order: index -> isolation -> partition -> resource -> (model, vendor, nodeAffinity) -> sameNode + filterRegistry := s.filterRegistry - // Add GPU model filter if specified + // 1. GPU index filter (extracted from resource filter) + if len(req.GPUIndices) > 0 { + filterRegistry = filterRegistry.With(filter.NewGPUIndexFilter(req.GPUIndices)) + } + + // 2. GPU isolation mode filter + if req.Isolation != "" { + filterRegistry = filterRegistry.With(filter.NewGPUIsolationModeFilter(req.Isolation)) + } + + // 3. Partition template filter (only for partitioned mode) + if req.Isolation == tfv1.IsolationModePartitioned { + filterRegistry = filterRegistry.With(filter.NewPartitionTemplateFilter(req.Isolation, req.PartitionTemplateID, MaxPartitionsMap)) + } + + // 4. Resource filter (moved after isolation/partition filters) + filterRegistry = filterRegistry.With(filter.NewResourceFilter(req.Request)) + + // 5. GPU model filter if specified if req.GPUModel != "" { - filterRegistry = filterRegistry.With(filter.NewGPUModelAndVendorFilter(req.GPUModel, req.GPUVendor)) + filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel)) } - // NOTE: deprecated, use Kubernetes native spec template affinity way + // 6. GPU vendor filter if specified + if req.GPUVendor != "" { + filterRegistry = filterRegistry.With(filter.NewGPUVendorFilter(req.GPUVendor)) + } + + // 7. NOTE: deprecated, use Kubernetes native spec template affinity way if req.NodeAffinity != nil { filterRegistry = filterRegistry.With(filter.NewNodeAffinityFilter(s.Client, req.NodeAffinity)) } - // Same node filter must be applied at final step + // 8. 
Same node filter must be applied at final step if req.Count > 1 { filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(req.Count)) } @@ -217,17 +298,59 @@ func (s *GpuAllocator) FilterWithPreempt( return nil, nil, fmt.Errorf("gpu %s not found", gpuName) } gpuCopy := gpu.DeepCopy() - gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) - gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + + // Handle partitioned mode: add back partition resources from config + if preemptAllocRequest.Isolation == tfv1.IsolationModePartitioned && preemptAllocRequest.PartitionTemplateID != "" { + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage( + gpuCopy.Status.Capacity.Tflops, gpuCopy.Status.GPUModel, preemptAllocRequest.PartitionTemplateID) + if err == nil { + gpuCopy.Status.Available.Tflops.Add(partitionTflops) + gpuCopy.Status.Available.Vram.Add(partitionVram) + } else { + // Fallback to request resources + gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) + gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + } + } else { + // Non-partitioned mode + gpuCopy.Status.Available.Tflops.Add(preemptAllocRequest.Request.Tflops) + gpuCopy.Status.Available.Vram.Add(preemptAllocRequest.Request.Vram) + } toFilterGPUs = append(toFilterGPUs, gpuCopy) } } - filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request, req.GPUIndices)) - // Add GPU model filter if specified + // Use same filter order as regular Filter + filterRegistry := s.filterRegistry + + // 1. GPU index filter + if len(req.GPUIndices) > 0 { + filterRegistry = filterRegistry.With(filter.NewGPUIndexFilter(req.GPUIndices)) + } + + // 2. GPU isolation mode filter + if req.Isolation != "" { + filterRegistry = filterRegistry.With(filter.NewGPUIsolationModeFilter(req.Isolation)) + } + + // 3. Partition template filter (only for partitioned mode) + if req.Isolation == tfv1.IsolationModePartitioned { + filterRegistry = filterRegistry.With(filter.NewPartitionTemplateFilter(req.Isolation, req.PartitionTemplateID, MaxPartitionsMap)) + } + + // 4. Resource filter + filterRegistry = filterRegistry.With(filter.NewResourceFilter(req.Request)) + + // 5. GPU model filter if specified if req.GPUModel != "" { - filterRegistry = filterRegistry.With(filter.NewGPUModelAndVendorFilter(req.GPUModel, req.GPUVendor)) + filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel)) + } + + // 6. GPU vendor filter if specified + if req.GPUVendor != "" { + filterRegistry = filterRegistry.With(filter.NewGPUVendorFilter(req.GPUVendor)) } + // No need to check count and other filters since it's always in the same node during each preempt trial filteredGPUs, filterDetails, err := filterRegistry.Apply(s.ctx, req.WorkloadNameNamespace, toFilterGPUs, false) if err != nil { @@ -266,6 +389,71 @@ func (s *GpuAllocator) Select(req *tfv1.AllocRequest, filteredGPUs []*tfv1.GPU) return result, nil } +// GetMatchedPartition finds the best matching partition template for a request in partitioned mode. +// Returns the GPU, matched partition template, and partition UUID if a match is found. +// In partitioned mode, GPUs must have partition templates available, and we select the smallest +// template that can satisfy the request to minimize resource waste. 
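The "smallest template that can satisfy the request" rule described in the GetMatchedPartition comment above can be read as a waste-minimizing score: reject templates that are too small, otherwise score by how much compute and memory would be left over and keep the lowest score. A minimal sketch under that assumption follows; the real MatchPartitionTemplate and its PartitionMatchResult.Score may weigh things differently.

// partitionWasteScore derives the template's effective TFLOPs from the GPU capacity and
// ComputePercent, compares it with the request, and returns a lower-is-better waste score.
func partitionWasteScore(capacityTflops float64, tpl config.PartitionTemplateInfo, reqTflops, reqVramGiB float64) (float64, bool) {
	tplTflops := capacityTflops * tpl.ComputePercent / 100.0
	tplVramGiB := float64(tpl.MemoryGigabytes)
	if tplTflops < reqTflops || tplVramGiB < reqVramGiB {
		return 0, false // template cannot satisfy the request
	}
	return (tplTflops - reqTflops) + (tplVramGiB - reqVramGiB), true
}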
+func (s *GpuAllocator) GetMatchedPartition( + req *tfv1.AllocRequest, + filteredGPUs []*tfv1.GPU, +) (*tfv1.GPU, *PartitionMatchResult, error) { + // Only process partitioned mode requests + if req.Isolation != tfv1.IsolationModePartitioned { + return nil, nil, fmt.Errorf("GetMatchedPartition only supports partitioned isolation mode") + } + + if len(filteredGPUs) == 0 { + return nil, nil, fmt.Errorf("no GPUs available for partition matching") + } + + var bestGPU *tfv1.GPU + var bestMatch *PartitionMatchResult + bestScore := math.MaxFloat64 + + s.storeMutex.RLock() + defer s.storeMutex.RUnlock() + + // Find the best GPU with the best matching partition template + for _, gpu := range filteredGPUs { + // Get partition templates from GPU status + if len(gpu.Status.PartitionTemplates) == 0 { + continue // Skip GPUs without partition templates + } + // Match partition template (gets template info from config) + match, err := MatchPartitionTemplate(gpu.Status, req) + if err != nil { + log.FromContext(s.ctx).V(5).Info("Failed to match partition template for GPU", + "gpu", gpu.Name, "error", err) + continue + } + + if !match.CanAllocate { + continue + } + + // Check if GPU has enough resources (gets template info from config) + if err := CheckPartitionAvailability(gpu, match.TemplateID); err != nil { + log.FromContext(s.ctx).V(5).Info("GPU does not have available resources for partition", + "gpu", gpu.Name, "error", err) + continue + } + + // Update best match if this is better (lower score = less waste) + if match.Score < bestScore { + bestGPU = gpu + bestMatch = match + bestScore = match.Score + } + } + + if bestGPU == nil || bestMatch == nil { + return nil, nil, fmt.Errorf("no suitable partition template found for request: TFLOPs=%s, VRAM=%s", + req.Request.Tflops.String(), req.Request.Vram.String()) + } + + return bestGPU, bestMatch, nil +} + // Bind allocates resources on the provided GPUs for the given request. // It updates the in-memory store and marks the GPUs as dirty for syncing. 
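Before the Bind implementation below, a minimal sketch of how the partition-aware pieces introduced in this change could be combined for one allocation. It assumes the gpuallocator package context (tfv1, GpuAllocator, PartitionMatchResult are in scope); bindFn is a hypothetical stand-in, since Bind's full signature is not visible in this hunk.

```go
// Sketch only, not part of the change: wiring GetMatchedPartition into a
// partitioned allocation. bindFn is a hypothetical stand-in for Bind.
func allocatePartitionedSketch(
	alloc *GpuAllocator,
	req *tfv1.AllocRequest,
	candidates []*tfv1.GPU,
	bindFn func(gpuName string, req *tfv1.AllocRequest) error,
) error {
	// Pick the GPU and the smallest partition template that still fits.
	gpu, match, err := alloc.GetMatchedPartition(req, candidates)
	if err != nil {
		return err
	}
	// Record the chosen template so binding takes the partitioned path and
	// charges template-sized resources rather than the raw request.
	req.PartitionTemplateID = match.TemplateID
	return bindFn(gpu.Name, req)
}
```

The key point is that the template chosen here is what the partitioned bind path later charges against the GPU, so the request's own TFLOPs/VRAM figures drive template selection but not the accounting.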
func (s *GpuAllocator) Bind( @@ -302,24 +490,32 @@ func (s *GpuAllocator) Bind( if gpu.Status.Available == nil { return nil, fmt.Errorf("GPU %s has nil available resources", selectedGPU) } - if gpu.Status.Available.Tflops.Cmp(req.Request.Tflops) < 0 { - return nil, fmt.Errorf("GPU %s insufficient TFLOPs: available %s, requested %s", - selectedGPU, gpu.Status.Available.Tflops.String(), req.Request.Tflops.String()) - } - if gpu.Status.Available.Vram.Cmp(req.Request.Vram) < 0 { - return nil, fmt.Errorf("GPU %s insufficient VRAM: available %s, requested %s", - selectedGPU, gpu.Status.Available.Vram.String(), req.Request.Vram.String()) - } - - // reduce available resource on the GPU status - if !req.Request.ComputePercent.IsZero() { - requiredTflops := utils.ComputePercentToTflops(gpu.Status.Capacity.Tflops, req.Request) - gpu.Status.Available.Tflops.Sub(*requiredTflops) + // Handle partitioned mode differently + if req.Isolation == tfv1.IsolationModePartitioned && req.PartitionTemplateID != "" { + if err := s.bindPartition(gpu, req, selectedGPU); err != nil { + return nil, err + } } else { - gpu.Status.Available.Tflops.Sub(req.Request.Tflops) + // Non-partitioned mode: subtract request resources + if gpu.Status.Available.Tflops.Cmp(req.Request.Tflops) < 0 { + return nil, fmt.Errorf("GPU %s insufficient TFLOPs: available %s, requested %s", + selectedGPU, gpu.Status.Available.Tflops.String(), req.Request.Tflops.String()) + } + if gpu.Status.Available.Vram.Cmp(req.Request.Vram) < 0 { + return nil, fmt.Errorf("GPU %s insufficient VRAM: available %s, requested %s", + selectedGPU, gpu.Status.Available.Vram.String(), req.Request.Vram.String()) + } + + // reduce available resource on the GPU status + if !req.Request.ComputePercent.IsZero() { + requiredTflops := utils.ComputePercentToTflops(gpu.Status.Capacity.Tflops, req.Request) + gpu.Status.Available.Tflops.Sub(*requiredTflops) + } else { + gpu.Status.Available.Tflops.Sub(req.Request.Tflops) + } + gpu.Status.Available.Vram.Sub(req.Request.Vram) } - gpu.Status.Available.Vram.Sub(req.Request.Vram) addRunningApp(s.ctx, gpu, req) @@ -460,18 +656,18 @@ func (s *GpuAllocator) Dealloc( ) { <-s.initializedCh podUID := string(podMeta.UID) - log := log.FromContext(s.ctx) + logger := log.FromContext(s.ctx) request, exists := s.uniqueAllocation[podUID] if !exists || request == nil { // should not block finalizer - log.Error(fmt.Errorf("pod has not allocated GPUs"), "pod", podUID) + logger.Error(fmt.Errorf("pod has not allocated GPUs"), "pod", podUID) return } if _, exists := s.uniqueDeallocation[podUID]; exists { // should not block finalizer - log.Error(fmt.Errorf("pod has already deallocated GPUs"), "pod", podUID) + logger.Error(fmt.Errorf("pod has already deallocated GPUs"), "pod", podUID) return } @@ -484,18 +680,23 @@ func (s *GpuAllocator) Dealloc( gpuNameNs := types.NamespacedName{Name: gpu} storeGPU, exists := s.gpuStore[gpuNameNs] if !exists { - log.Error(fmt.Errorf("GPU not found in store"), "Failed to deallocate GPU", "name", gpu) + logger.Error(fmt.Errorf("GPU not found in store"), "Failed to deallocate GPU", "name", gpu) continue } - // Add resources back to the GPU - if !request.Request.ComputePercent.IsZero() { - requiredTflops := utils.ComputePercentToTflops(storeGPU.Status.Capacity.Tflops, request.Request) - storeGPU.Status.Available.Tflops.Add(*requiredTflops) + // Handle partitioned mode deallocation + if request.Isolation == tfv1.IsolationModePartitioned && request.PartitionTemplateID != "" { + s.deallocPartition(storeGPU, request, gpu) } 
else { - storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + // Non-partitioned mode: add back request resources + if !request.Request.ComputePercent.IsZero() { + requiredTflops := utils.ComputePercentToTflops(storeGPU.Status.Capacity.Tflops, request.Request) + storeGPU.Status.Available.Tflops.Add(*requiredTflops) + } else { + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + } + storeGPU.Status.Available.Vram.Add(request.Request.Vram) } - storeGPU.Status.Available.Vram.Add(request.Request.Vram) if nodeName == "" { nodeName = storeGPU.Status.NodeSelector[constants.KubernetesHostNameLabel] @@ -515,7 +716,7 @@ func (s *GpuAllocator) Dealloc( // Deallocate quota resources in memory (atomic operation) s.quotaStore.DeallocateQuota(workloadNameNamespace.Namespace, request) - log.Info("GPU deallocation successful", + logger.Info("GPU deallocation successful", "namespace", workloadNameNamespace.Namespace, "workload", workloadNameNamespace.Name, "gpu_count", len(gpus), @@ -886,6 +1087,7 @@ func (s *GpuAllocator) SetupWithManager(ctx context.Context, mgr manager.Manager } func (s *GpuAllocator) SetAllocatorReady() { + s.indexAllocator.SetReady() close(s.initializedCh) } @@ -1071,6 +1273,9 @@ func syncGPUMetadataAndStatusFromCluster(old *tfv1.GPU, gpu *tfv1.GPU) { old.Status.Vendor = gpu.Status.Vendor old.Status.NUMANode = gpu.Status.NUMANode old.Status.Index = gpu.Status.Index + // Sync partition templates from cluster (discovered by node discovery) + // Don't overwrite AllocatedPartitions as that's managed by the allocator + old.Status.PartitionTemplates = gpu.Status.PartitionTemplates } func (s *GpuAllocator) handleGPUUpdateCapacityDiff(old, gpu *tfv1.GPU) { @@ -1151,6 +1356,7 @@ func (s *GpuAllocator) SyncGPUsToK8s() { // Apply our status updates to the latest version latest.Status.Available = gpu.Status.Available latest.Status.RunningApps = gpu.Status.RunningApps + latest.Status.AllocatedPartitions = gpu.Status.AllocatedPartitions // Attempt to update with the latest version return s.Status().Update(s.ctx, latest) @@ -1316,7 +1522,7 @@ func (s *GpuAllocator) reconcileAllocationState() { !controllerutil.ContainsFinalizer(&worker, constants.Finalizer) if scheduled { - allocRequest, msg, err := s.ComposeAllocationRequest(&worker) + allocRequest, msg, err := utils.ComposeAllocationRequest(ctx, &worker) if err != nil { logger.Error(err, "Failed to compose allocation request for existing worker Pod, annotation may not be valid", "pod", worker.Name, "msg", msg) return false @@ -1324,6 +1530,10 @@ func (s *GpuAllocator) reconcileAllocationState() { s.uniqueAllocation[string(worker.UID)] = allocRequest s.podNamespaceNsToPodUID[worker.Namespace+"/"+worker.Name] = string(worker.UID) s.addAllocationMap(worker.Spec.NodeName, worker.ObjectMeta) + + if utils.IsPodPending(&worker) { + s.indexAllocator.ReconcileLockState(&worker) + } } return scheduled && !deletedAndDeAllocated }) @@ -1340,6 +1550,8 @@ func (s *GpuAllocator) reconcileAllocationState() { actualRunningAppsMap[gpuKey] = gpu.Status.RunningApps gpu.Status.RunningApps = []*tfv1.RunningAppDetail{} + // Clear AllocatedPartitions - will be rebuilt from workers + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) } // This is important for progressive migration mode @@ -1357,12 +1569,55 @@ func (s *GpuAllocator) reconcileAllocationState() { for gpuId := range gpuIdsList { gpuKey := types.NamespacedName{Name: gpuId} + gpu := s.gpuStore[gpuKey] + if gpu == nil { + continue + } + gpuAvailableRes, ok := 
actualAvailableMap[gpuKey] if ok { - gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) - gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + // Handle partitioned mode differently + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID != "" { + // Calculate partition resource usage from config + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, allocRequest.PartitionTemplateID) + if err == nil { + gpuAvailableRes.Tflops.Sub(partitionTflops) + gpuAvailableRes.Vram.Sub(partitionVram) + + // Rebuild AllocatedPartitions using podUID as key + if gpu.Status.AllocatedPartitions == nil { + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) + } + podUID := string(worker.UID) + // During reconciliation, preserve existing slot assignments if available + existingPartition, exists := gpu.Status.AllocatedPartitions[podUID] + allocatedPartition := tfv1.AllocatedPartition{ + TemplateID: allocRequest.PartitionTemplateID, + PodUID: podUID, + PodName: worker.Name, + Namespace: worker.Namespace, + AllocatedAt: metav1.Now(), // Use current time for reconciliation + } + // Preserve existing slot assignments if they exist + if exists { + allocatedPartition.AllocatedSlotStart = existingPartition.AllocatedSlotStart + allocatedPartition.AllocatedSlotEnd = existingPartition.AllocatedSlotEnd + } + gpu.Status.AllocatedPartitions[podUID] = allocatedPartition + } else { + // Fallback to request resources if template not found + logger.Info("Partition template not found in config during reconciliation, using request resources", + "gpu", gpuId, "template", allocRequest.PartitionTemplateID, "error", err) + gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) + gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + } + } else { + // Non-partitioned mode + gpuAvailableRes.Tflops.Sub(allocRequest.Request.Tflops) + gpuAvailableRes.Vram.Sub(allocRequest.Request.Vram) + } } - addRunningApp(ctx, s.gpuStore[gpuKey], allocRequest) + addRunningApp(ctx, gpu, allocRequest) } } @@ -1384,6 +1639,12 @@ func (s *GpuAllocator) reconcileAllocationState() { s.markGPUDirtyLocked(gpuKey) log.FromContext(ctx).Info("Correcting gpu running apps", "gpu", gpuKey.Name, "runningApps", len(gpu.Status.RunningApps)) } + + // Mark GPU dirty if AllocatedPartitions need to be synced + // (they are already updated in the loop above, just need to sync to K8s) + if len(gpu.Status.AllocatedPartitions) > 0 { + s.markGPUDirtyLocked(gpuKey) + } } // reconcile quota store state @@ -1482,65 +1743,124 @@ func removeRunningApp(ctx context.Context, gpu *tfv1.GPU, allocRequest *tfv1.All } } -func (s *GpuAllocator) ComposeAllocationRequest(pod *v1.Pod) (*tfv1.AllocRequest, string, error) { - // allow Pods with no requests/limits to use TensorFusion, Pod webhook will ensure at least one request/limit is set - gpuRequestResource, err := utils.GetGPUResource(pod, true) - if err != nil { - log.FromContext(s.ctx).Error(err, "Invalid gpu request annotation", "pod", pod.Name, "namespace", pod.Namespace) +// bindPartition handles partition allocation for a single GPU in partitioned mode +func (s *GpuAllocator) bindPartition(gpu *tfv1.GPU, req *tfv1.AllocRequest, selectedGPU string) error { + // Verify template exists in GPU status + templateExists := false + for _, template := range gpu.Status.PartitionTemplates { + if template.TemplateID == req.PartitionTemplateID { + templateExists = true + break + } + } + if !templateExists { + 
return fmt.Errorf("partition template %s not found on GPU %s", req.PartitionTemplateID, selectedGPU) } - gpuLimitResource, err := utils.GetGPUResource(pod, false) + + // Calculate partition resource usage from config (no overhead) + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, req.PartitionTemplateID) if err != nil { - log.FromContext(s.ctx).Error(err, "Invalid gpu limit annotation", "pod", pod.Name, "namespace", pod.Namespace) + return fmt.Errorf("failed to get partition template info for GPU %s template %s: %w", selectedGPU, req.PartitionTemplateID, err) } - count := 1 - if gpuCountStr, exists := pod.Annotations[constants.GpuCountAnnotation]; exists { - count, err = strconv.Atoi(gpuCountStr) - if err != nil { - return &tfv1.AllocRequest{}, "invalid gpu count annotation", err - } + // Check availability for partition resources + if gpu.Status.Available.Tflops.Cmp(partitionTflops) < 0 { + return fmt.Errorf("GPU %s insufficient TFLOPs for partition: available %s, required %s", + selectedGPU, gpu.Status.Available.Tflops.String(), partitionTflops.String()) } - if count > MaxGPUCounterPerAllocation { - return &tfv1.AllocRequest{}, "gpu count annotation is too large", nil + if gpu.Status.Available.Vram.Cmp(partitionVram) < 0 { + return fmt.Errorf("GPU %s insufficient VRAM for partition: available %s, required %s", + selectedGPU, gpu.Status.Available.Vram.String(), partitionVram.String()) } - qosLevel := tfv1.QoSLevel(pod.Annotations[constants.QoSLevelAnnotation]) - if qosLevel == "" { - qosLevel = tfv1.QoSMedium - } + // Subtract partition resources (no overhead) + gpu.Status.Available.Tflops.Sub(partitionTflops) + gpu.Status.Available.Vram.Sub(partitionVram) - gpuVendor := pod.Annotations[constants.GpuVendorAnnotation] + // Initialize AllocatedPartitions map if needed + if gpu.Status.AllocatedPartitions == nil { + gpu.Status.AllocatedPartitions = make(map[string]tfv1.AllocatedPartition) + } - gpuIndices, hasError := utils.ParseIndicesAnnotation(pod.Annotations[constants.GpuIndicesAnnotation]) - if hasError { - return &tfv1.AllocRequest{}, "invalid gpu-indices annotation", - fmt.Errorf("can not parse gpu indices annotation") + // Find and assign slot position + var slotStart, slotEnd *uint32 + templateConfigs, exists := PartitionTemplateMap[gpu.Status.GPUModel] + if exists { + if templateInfo, found := templateConfigs[req.PartitionTemplateID]; found { + if len(templateInfo.PlacementLimit) > 0 && templateInfo.PlacementOffSet > 0 { + // Build slot occupancy map from existing partitions + occupiedSlots := buildSlotOccupancyMap(gpu, templateConfigs) + // Find available slot position + if startPos, found := findAvailableSlotPosition(templateInfo, occupiedSlots); found { + slotStart = &startPos + endPos := startPos + templateInfo.PlacementOffSet + slotEnd = &endPos + } + } + } } - allocRequest := tfv1.AllocRequest{ - PoolName: pod.Annotations[constants.GpuPoolKey], - Request: gpuRequestResource, - Limit: gpuLimitResource, + // Store partition allocation info using podUID as key + podUID := string(req.PodMeta.UID) + gpu.Status.AllocatedPartitions[podUID] = tfv1.AllocatedPartition{ + TemplateID: req.PartitionTemplateID, + PodUID: podUID, + PodName: req.PodMeta.Name, + Namespace: req.PodMeta.Namespace, + AllocatedAt: metav1.Now(), + AllocatedSlotStart: slotStart, + AllocatedSlotEnd: slotEnd, + } + + log.FromContext(s.ctx).Info("Allocated partition on GPU", + "gpu", selectedGPU, + "template", req.PartitionTemplateID, + "podUID", 
podUID, + "slotStart", slotStart, + "slotEnd", slotEnd) + return nil +} - Count: uint(count), - GPUModel: pod.Annotations[constants.GPUModelAnnotation], - GPUIndices: gpuIndices, - GPUVendor: gpuVendor, - WorkloadNameNamespace: tfv1.NameNamespace{ - Name: pod.Labels[constants.WorkloadKey], - Namespace: pod.Namespace, - }, - PodMeta: pod.ObjectMeta, - QoS: qosLevel, - } +// deallocPartition handles partition deallocation for a single GPU in partitioned mode +func (s *GpuAllocator) deallocPartition(storeGPU *tfv1.GPU, request *tfv1.AllocRequest, gpu string) { + logger := log.FromContext(s.ctx) + // Find and remove the allocated partition using podUID as key + podUID := string(request.PodMeta.UID) + if storeGPU.Status.AllocatedPartitions != nil { + allocatedPartition, exists := storeGPU.Status.AllocatedPartitions[podUID] + if exists { + // Calculate partition resource usage from config (no overhead) + partitionTflops, partitionVram, err := CalculatePartitionResourceUsage(storeGPU.Status.Capacity.Tflops, storeGPU.Status.GPUModel, allocatedPartition.TemplateID) + if err != nil { + // Fallback: add back request resources if template not found in config + logger.Info("Partition template not found in config during deallocation, using request resources", + "gpu", gpu, "template", allocatedPartition.TemplateID, "error", err) + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) + } else { + // Add back partition resources (no overhead) + storeGPU.Status.Available.Tflops.Add(partitionTflops) + storeGPU.Status.Available.Vram.Add(partitionVram) + } - // for already allocated workers, set the GPU device IDs for further scaling and retrieval - if gpuIdStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { - gpuIds := strings.SplitSeq(gpuIdStr, ",") - allocRequest.GPUNames = slices.Collect(gpuIds) + // Remove partition from allocated partitions map using podUID + delete(storeGPU.Status.AllocatedPartitions, podUID) + logger.Info("Removed partition allocation", + "gpu", gpu, + "podUID", podUID, + "template", allocatedPartition.TemplateID) + } else { + logger.Info("Partition not found in allocated partitions during deallocation", + "gpu", gpu, "podUID", podUID) + // Fallback: add back request resources + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) + } + } else { + // No allocated partitions map, fallback to request resources + storeGPU.Status.Available.Tflops.Add(request.Request.Tflops) + storeGPU.Status.Available.Vram.Add(request.Request.Vram) } - - return &allocRequest, "", nil } func (s *GpuAllocator) addAllocationMap(gpuNodeName string, podMeta metav1.ObjectMeta) { diff --git a/internal/gpuallocator/gpuallocator_test.go b/internal/gpuallocator/gpuallocator_test.go index 496818d3..40042576 100644 --- a/internal/gpuallocator/gpuallocator_test.go +++ b/internal/gpuallocator/gpuallocator_test.go @@ -66,7 +66,7 @@ var _ = Describe("GPU Allocator", func() { } BeforeEach(func() { - allocator = NewGpuAllocator(ctx, k8sClient, 150*time.Millisecond) + allocator = NewGpuAllocator(ctx, nil, k8sClient, 150*time.Millisecond) err := allocator.SetupWithManager(ctx, mgr) Expect(err).NotTo(HaveOccurred()) <-allocator.initializedCh diff --git a/internal/gpuallocator/partitioned_scheduling.go b/internal/gpuallocator/partitioned_scheduling.go new file mode 100644 index 00000000..09bf650a --- /dev/null +++ b/internal/gpuallocator/partitioned_scheduling.go @@ -0,0 
+1,342 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gpuallocator + +import ( + "fmt" + "math" + "sort" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/config" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "k8s.io/apimachinery/pkg/api/resource" +) + +const DefaultMaxPartitionNum = 32 +const PartitionMatchingComputingWeight = 0.6 +const PartitionMatchingVRAMWeight = 0.4 + +// PartitionMatchResult represents the result of matching a partition template to a request +type PartitionMatchResult struct { + Template *config.PartitionTemplateInfo // Template info from config + TemplateID string // Template ID + Score float64 // Lower score means better match (less waste) + CanAllocate bool + Reason string +} + +// MatchPartitionTemplate matches a partition template to an allocation request. +// Gets template info from config (PartitionTemplateMap) based on GPU model. +// In partitioned mode, we find the smallest template that can satisfy the request. +func MatchPartitionTemplate(gpuStatus tfv1.GPUStatus, req *tfv1.AllocRequest) (*PartitionMatchResult, error) { + gpuModel := gpuStatus.GPUModel + gpuTemplates := gpuStatus.PartitionTemplates + if len(gpuTemplates) == 0 { + return nil, fmt.Errorf("no partition templates available for GPU model %s", gpuModel) + } + + // Get template configs from global map + templateConfigs, exists := PartitionTemplateMap[gpuModel] + if !exists || len(templateConfigs) == 0 { + return nil, fmt.Errorf("no partition template configs found for GPU model %s", gpuModel) + } + + // Convert request to comparable values + // Handle ComputePercent: convert to TFLOPs if specified + var requestTflops float64 + if !req.Request.ComputePercent.IsZero() { + // Get GPU capacity from global map to convert ComputePercent to TFLOPs + mu.Lock() + gpuCapacity, exists := GPUCapacityMap[gpuModel] + mu.Unlock() + if !exists { + return nil, fmt.Errorf("GPU capacity not found for model %s, cannot convert ComputePercent to TFLOPs", gpuModel) + } + requiredTflops := utils.ComputePercentToTflops(gpuCapacity.Tflops, req.Request) + requestTflops = requiredTflops.AsApproximateFloat64() + } else { + requestTflops = req.Request.Tflops.AsApproximateFloat64() + } + requestVramBytes := req.Request.Vram.Value() + + // Get max partitions from config + maxPartitions := MaxPartitionsMap[gpuModel] + if maxPartitions <= 0 { + maxPartitions = DefaultMaxPartitionNum + } + + // Find the best matching template + var bestMatch *PartitionMatchResult + bestScore := math.MaxFloat64 // Lower is better (we want smallest that fits) + + for _, gpuTemplate := range gpuTemplates { + // Get detailed template info from config + templateInfo, exists := templateConfigs[gpuTemplate.TemplateID] + if !exists { + continue // Skip if template not found in config + } + + // If a specific template is required, only consider that one + if req.PartitionTemplateID != "" && gpuTemplate.TemplateID != req.PartitionTemplateID { + continue + } + + result := 
&PartitionMatchResult{ + Template: &templateInfo, + TemplateID: gpuTemplate.TemplateID, + CanAllocate: false, + } + + // Check if template resources can satisfy the request + templateTflops := templateInfo.ComputePercent * gpuStatus.Capacity.Tflops.AsApproximateFloat64() + templateVramBytes := int64(templateInfo.MemoryGigabytes * 1024 * 1024 * 1024) + + // Check if template has enough resources + if templateTflops < requestTflops { + result.Reason = fmt.Sprintf("template %s has insufficient TFLOPs: %.2f < %.2f", + gpuTemplate.TemplateID, templateTflops, requestTflops) + continue + } + + if templateVramBytes < requestVramBytes { + result.Reason = fmt.Sprintf("template %s has insufficient VRAM: %d < %d", + gpuTemplate.TemplateID, templateVramBytes, requestVramBytes) + continue + } + + // Check if we can allocate more partitions (MIG constraint) + currentPartitionCount := len(gpuStatus.AllocatedPartitions) + if maxPartitions > 0 && uint32(currentPartitionCount) >= maxPartitions { + result.Reason = fmt.Sprintf("GPU has reached maximum partition count: %d/%d", + currentPartitionCount, maxPartitions) + continue + } + + // Calculate score: prefer templates that are just large enough (minimize waste) + tflopsWaste := (templateTflops - requestTflops) / math.Max(requestTflops, 1.0) + vramWaste := float64(templateVramBytes-requestVramBytes) / math.Max(float64(requestVramBytes), 1.0) + score := tflopsWaste*PartitionMatchingComputingWeight + vramWaste*PartitionMatchingVRAMWeight + + result.Score = score + result.CanAllocate = true + result.Reason = "template can satisfy request" + + // Update best match if this is better + if bestMatch == nil || score < bestScore { + bestMatch = result + bestScore = score + } + } + + if bestMatch == nil { + return nil, fmt.Errorf("no partition template can satisfy request: TFLOPs=%.2f, VRAM=%d", + requestTflops, requestVramBytes) + } + + return bestMatch, nil +} + +// CalculatePartitionResourceUsage calculates the resource usage for a partition template. +// Gets template info from config. +func CalculatePartitionResourceUsage(capacityTflops resource.Quantity, gpuModel, templateID string) (tflops resource.Quantity, vram resource.Quantity, err error) { + templateConfigs, exists := PartitionTemplateMap[gpuModel] + if !exists { + return resource.Quantity{}, resource.Quantity{}, fmt.Errorf("no partition template configs for GPU model %s", gpuModel) + } + + templateInfo, exists := templateConfigs[templateID] + if !exists { + return resource.Quantity{}, resource.Quantity{}, fmt.Errorf("partition template %s not found for GPU model %s", templateID, gpuModel) + } + + tflops = resource.MustParse(fmt.Sprintf("%.2f", templateInfo.ComputePercent*capacityTflops.AsApproximateFloat64()/100.0)) + vram = resource.MustParse(fmt.Sprintf("%dGi", templateInfo.MemoryGigabytes)) + + return tflops, vram, nil +} + +// areSlotsFree checks if slots starting from startPos for offset slots are all free. +func areSlotsFree(occupiedSlots map[uint32]bool, startPos, offset uint32) bool { + for i := range offset { + if occupiedSlots[startPos+i] { + return false + } + } + return true +} + +// buildSlotOccupancyMap builds a map of occupied slots from existing partitions. +// Uses AllocatedSlotStart/End if available, otherwise falls back to greedy assignment. 
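As a quick illustration of the waste score used in MatchPartitionTemplate above (and of the per-template resource math in CalculatePartitionResourceUsage), here is a small self-contained program. The 312-TFLOPs capacity and the 1/7 and 4/7 slices are assumed example values for illustration, not read from any real config.

```go
package main

import (
	"fmt"
	"math"
)

// Mirrors the weighting used in MatchPartitionTemplate: lower score = less waste.
const (
	computingWeight = 0.6
	vramWeight      = 0.4
)

func wasteScore(templateTflops, requestTflops float64, templateVramBytes, requestVramBytes int64) float64 {
	tflopsWaste := (templateTflops - requestTflops) / math.Max(requestTflops, 1.0)
	vramWaste := float64(templateVramBytes-requestVramBytes) / math.Max(float64(requestVramBytes), 1.0)
	return tflopsWaste*computingWeight + vramWaste*vramWeight
}

func main() {
	const gib = int64(1) << 30
	// Example request: ~30 TFLOPs and 20 GiB (illustrative numbers only).
	reqTflops, reqVram := 30.0, 20*gib

	// Two hypothetical templates on a 312-TFLOPs card: a 1/7 slice and a 4/7 slice.
	small := wasteScore(312.0/7.0, reqTflops, 24*gib, reqVram)    // ~44.6 TFLOPs, 24 GiB
	large := wasteScore(312.0*4.0/7.0, reqTflops, 94*gib, reqVram) // ~178.3 TFLOPs, 94 GiB

	fmt.Printf("small slice score: %.3f\n", small) // lower score is preferred
	fmt.Printf("large slice score: %.3f\n", large)
}
```

With these numbers the 1/7 slice scores roughly 0.37 and the 4/7 slice roughly 4.45, so the smaller template wins, matching the "smallest template that fits" behaviour described in the comments.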
+func buildSlotOccupancyMap( + gpu *tfv1.GPU, + templateConfigs map[string]config.PartitionTemplateInfo, +) map[uint32]bool { + occupiedSlots := make(map[uint32]bool) + + // First, use explicit slot assignments if available + for _, partition := range gpu.Status.AllocatedPartitions { + if partition.AllocatedSlotStart != nil && partition.AllocatedSlotEnd != nil { + start := *partition.AllocatedSlotStart + end := *partition.AllocatedSlotEnd + for slot := start; slot < end; slot++ { + occupiedSlots[slot] = true + } + } + } + + // For partitions without explicit slot assignments, use greedy approach + // Convert map to slice and sort by AllocatedAt timestamp (ASC) + partitions := make([]tfv1.AllocatedPartition, 0, len(gpu.Status.AllocatedPartitions)) + for _, partition := range gpu.Status.AllocatedPartitions { + // Skip if already has explicit slot assignment + if partition.AllocatedSlotStart != nil && partition.AllocatedSlotEnd != nil { + continue + } + partitions = append(partitions, partition) + } + + if len(partitions) > 0 { + sort.Slice(partitions, func(i, j int) bool { + // If both have valid timestamps, compare by time + if !partitions[i].AllocatedAt.IsZero() && !partitions[j].AllocatedAt.IsZero() { + if !partitions[i].AllocatedAt.Equal(&partitions[j].AllocatedAt) { + return partitions[i].AllocatedAt.Before(&partitions[j].AllocatedAt) + } + } + // Fallback to PodUID for stable ordering when timestamps are zero or equal + return partitions[i].PodUID < partitions[j].PodUID + }) + + // Process each partition without explicit slots in allocation order + for _, partition := range partitions { + templateInfo, exists := templateConfigs[partition.TemplateID] + if !exists || len(templateInfo.PlacementLimit) == 0 || templateInfo.PlacementOffSet == 0 { + continue + } + + // Find first available starting position for this partition + for _, startPos := range templateInfo.PlacementLimit { + if areSlotsFree(occupiedSlots, startPos, templateInfo.PlacementOffSet) { + // Assign this partition to this position + for i := uint32(0); i < templateInfo.PlacementOffSet; i++ { + occupiedSlots[startPos+i] = true + } + break + } + } + } + } + + return occupiedSlots +} + +// findAvailableSlotPosition finds the first available slot position for a template. +// Returns the starting position and true if found, 0 and false otherwise. +func findAvailableSlotPosition( + templateInfo config.PartitionTemplateInfo, + occupiedSlots map[uint32]bool, +) (uint32, bool) { + if len(templateInfo.PlacementLimit) == 0 || templateInfo.PlacementOffSet == 0 { + return 0, false + } + + for _, startPos := range templateInfo.PlacementLimit { + if areSlotsFree(occupiedSlots, startPos, templateInfo.PlacementOffSet) { + return startPos, true + } + } + + return 0, false +} + +// CheckPartitionAvailability checks if a GPU has enough resources to allocate a partition. +// Gets template info from config. 
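To make the slot bookkeeping concrete before CheckPartitionAvailability, here is a standalone example of the placement model that buildSlotOccupancyMap and findAvailableSlotPosition implement. The profiles (a 4-slot template restricted to start positions {0,4} and a 1-slot template allowed at {0..6}) mirror the A100-style values used in the tests, but are hard-coded here purely for illustration.

```go
package main

import "fmt"

// Same idea as areSlotsFree above: a placement needs `offset` consecutive
// free slots starting at startPos.
func slotsFree(occupied map[uint32]bool, startPos, offset uint32) bool {
	for i := uint32(0); i < offset; i++ {
		if occupied[startPos+i] {
			return false
		}
	}
	return true
}

// firstFit mirrors findAvailableSlotPosition: try each allowed start position.
func firstFit(placements []uint32, offset uint32, occupied map[uint32]bool) (uint32, bool) {
	for _, start := range placements {
		if slotsFree(occupied, start, offset) {
			return start, true
		}
	}
	return 0, false
}

func occupy(occupied map[uint32]bool, start, offset uint32) {
	for i := uint32(0); i < offset; i++ {
		occupied[start+i] = true
	}
}

func main() {
	occupied := map[uint32]bool{}

	// One 4-slot partition (allowed starts {0,4}) lands at 0 and takes slots 0-3.
	if start, ok := firstFit([]uint32{0, 4}, 4, occupied); ok {
		occupy(occupied, start, 4)
		fmt.Println("4-slot partition at", start) // 0
	}

	// A 1-slot partition (allowed starts {0..6}) then lands at 4, the first free slot.
	if start, ok := firstFit([]uint32{0, 1, 2, 3, 4, 5, 6}, 1, occupied); ok {
		occupy(occupied, start, 1)
		fmt.Println("1-slot partition at", start) // 4
	}

	// A second 4-slot partition no longer fits: start 0 and start 4 both collide.
	_, ok := firstFit([]uint32{0, 4}, 4, occupied)
	fmt.Println("second 4-slot partition fits:", ok) // false
}
```

This same first-fit walk over allowed start positions is what makes the "Profile 9 at 0 + Profile 19 * 4" test case below run out of placement slots.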
+func CheckPartitionAvailability( + gpu *tfv1.GPU, + templateID string, +) error { + // Get template info from config first to check template-specific constraints + templateConfigs, exists := PartitionTemplateMap[gpu.Status.GPUModel] + if !exists { + return fmt.Errorf("no partition template configs for GPU model %s", gpu.Status.GPUModel) + } + + templateInfo, exists := templateConfigs[templateID] + if !exists { + return fmt.Errorf("partition template %s not found for GPU model %s", templateID, gpu.Status.GPUModel) + } + + currentCount := len(gpu.Status.AllocatedPartitions) + + // Check general partition count limit first (cheaper check) + maxPartitions := MaxPartitionsMap[gpu.Status.GPUModel] + if maxPartitions == 0 { + maxPartitions = 7 // Default MIG limit + } + if maxPartitions > 0 && uint32(currentCount) >= maxPartitions { + return fmt.Errorf("GPU %s has reached maximum partition count: %d/%d", + gpu.Name, currentCount, maxPartitions) + } + + // Count how many partitions of this template are already allocated + templateCount := uint32(0) + for _, partition := range gpu.Status.AllocatedPartitions { + if partition.TemplateID == templateID { + templateCount++ + } + } + + // Check MaxPartition limit for this specific template + if templateInfo.MaxPartition > 0 && templateCount >= templateInfo.MaxPartition { + return fmt.Errorf("GPU %s has reached maximum partition count for template %s: %d/%d", + gpu.Name, templateID, templateCount, templateInfo.MaxPartition) + } + + // Check placement slots using bitmask-based tracking + if len(templateInfo.PlacementLimit) > 0 && templateInfo.PlacementOffSet > 0 { + // Build slot occupancy map from existing partitions + occupiedSlots := buildSlotOccupancyMap(gpu, templateConfigs) + + // Check if the new template can find a valid placement + _, found := findAvailableSlotPosition(templateInfo, occupiedSlots) + if !found { + return fmt.Errorf("GPU %s has no available placement slots for template %s: required %d slots starting from positions %v", + gpu.Name, templateID, templateInfo.PlacementOffSet, templateInfo.PlacementLimit) + } + } + + // Calculate required resources from config + requiredTflops, requiredVram, err := CalculatePartitionResourceUsage(gpu.Status.Capacity.Tflops, gpu.Status.GPUModel, templateID) + if err != nil { + return err + } + + // Check TFLOPs availability + if gpu.Status.Available.Tflops.Cmp(requiredTflops) < 0 { + return fmt.Errorf("GPU %s insufficient TFLOPs for partition: available %s, required %s", + gpu.Name, gpu.Status.Available.Tflops.String(), requiredTflops.String()) + } + + // Check VRAM availability + if gpu.Status.Available.Vram.Cmp(requiredVram) < 0 { + return fmt.Errorf("GPU %s insufficient VRAM for partition: available %s, required %s", + gpu.Name, gpu.Status.Available.Vram.String(), requiredVram.String()) + } + + return nil +} diff --git a/internal/gpuallocator/partitioned_scheduling_test.go b/internal/gpuallocator/partitioned_scheduling_test.go new file mode 100644 index 00000000..5d020cf2 --- /dev/null +++ b/internal/gpuallocator/partitioned_scheduling_test.go @@ -0,0 +1,527 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gpuallocator + +import ( + "testing" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/config" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const testGPUModel = "A100_SXM_80G" + +func TestMatchPartitionTemplate(t *testing.T) { + // Setup: Initialize partition template map + gpuModel := testGPUModel + PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{ + "1g.24gb": { + TemplateID: "19", + Name: "1g.24gb", + MemoryGigabytes: 24, // 24GB (function converts to bytes) + ComputePercent: 1.0 / 7.0 * 100, + }, + "4g.94gb": { + TemplateID: "9", + Name: "4g.94gb", + MemoryGigabytes: 94, // 94GB (function converts to bytes) + ComputePercent: 4.0 / 7.0 * 100, + }, + } + // Setup: Initialize GPU capacity map for ComputePercent conversion + // A100_SXM_80G has ~312 TFLOPs capacity + mu.Lock() + GPUCapacityMap[gpuModel] = tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + } + mu.Unlock() + + tests := []struct { + name string + gpuTemplates []tfv1.PartitionTemplate + req *tfv1.AllocRequest + allocatedPartitions map[string]tfv1.AllocatedPartition + expectError bool + expectedTemplateID string + }{ + { + name: "match smallest template that fits", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + {TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + Tflops: resource.MustParse("30"), + Vram: resource.MustParse("20Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: false, + expectedTemplateID: "1g.24gb", // Should match smallest that fits + }, + { + name: "match specific template when required", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + {TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + Tflops: resource.MustParse("30"), + Vram: resource.MustParse("20Gi"), + }, + PartitionTemplateID: "4g.94gb", + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: false, + expectedTemplateID: "4g.94gb", + }, + { + name: "no template matches request", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + Tflops: resource.MustParse("300"), // Too large + Vram: resource.MustParse("100Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: true, + }, + { + name: "no templates available", + gpuTemplates: []tfv1.PartitionTemplate{}, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + Tflops: resource.MustParse("30"), + Vram: resource.MustParse("20Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: true, + }, + { + name: "match with ComputePercent - smallest template that fits", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + {TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, 
+ req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + // 10% of 312 TFLOPs = 31.2 TFLOPs, should match 1g.24gb (50 TFLOPs) + ComputePercent: resource.MustParse("10"), + Vram: resource.MustParse("20Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: false, + expectedTemplateID: "1g.24gb", + }, + { + name: "match with ComputePercent - requires larger template", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + {TemplateID: "4g.94gb", Name: "4g.94gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + // 50% of 312 TFLOPs = 156 TFLOPs, should match 4g.94gb (200 TFLOPs) + ComputePercent: resource.MustParse("50"), + Vram: resource.MustParse("50Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: false, + expectedTemplateID: "4g.94gb", + }, + { + name: "match with ComputePercent - no template matches", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + // 80% of 312 TFLOPs = 249.6 TFLOPs, too large for 1g.24gb (50 TFLOPs) + ComputePercent: resource.MustParse("80"), + Vram: resource.MustParse("100Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: true, + }, + { + name: "match with ComputePercent - missing GPU capacity", + gpuTemplates: []tfv1.PartitionTemplate{ + {TemplateID: "1g.24gb", Name: "1g.24gb"}, + }, + req: &tfv1.AllocRequest{ + Request: tfv1.Resource{ + ComputePercent: resource.MustParse("10"), + Vram: resource.MustParse("20Gi"), + }, + }, + allocatedPartitions: map[string]tfv1.AllocatedPartition{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use different GPU model for missing capacity test + testGPUModel := gpuModel + if tt.name == "match with ComputePercent - missing GPU capacity" { + testGPUModel = "UNKNOWN_GPU_MODEL" + } + + result, err := MatchPartitionTemplate( + tfv1.GPUStatus{ + GPUModel: testGPUModel, + PartitionTemplates: tt.gpuTemplates, + AllocatedPartitions: tt.allocatedPartitions, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + }, + tt.req, + ) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.CanAllocate) + assert.Equal(t, tt.expectedTemplateID, result.TemplateID) + } + }) + } +} + +func TestCalculatePartitionResourceUsage(t *testing.T) { + // Setup + gpuModel := testGPUModel + templateID := "1g.24gb" + PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{ + templateID: { + TemplateID: templateID, + Name: "1g.24gb", + MemoryGigabytes: 24, // 24GB (function converts to bytes) + ComputePercent: 1.0 / 7.0 * 100, + }, + } + + tflops, vram, err := CalculatePartitionResourceUsage(resource.MustParse("312"), gpuModel, templateID) + + assert.NoError(t, err) + // Compare using Cmp to handle different formatting + // 1/7 of 312 TFLOPs = 44.57 TFLOPs + expectedTflops := resource.MustParse("44.57") + assert.Equal(t, 0, tflops.Cmp(expectedTflops), "TFLOPs: got %s, expected %s", tflops.String(), expectedTflops.String()) + // Compare VRAM using Cmp to handle quantity representation differences + assert.Equal(t, 0, vram.Cmp(resource.MustParse("24Gi")), "VRAM: got %s, expected 24Gi", vram.String()) +} + +func TestCheckPartitionAvailability(t *testing.T) { + // Setup: A100 MIG 
constraints based on nvidia-smi mig -lgipp output + // Profile 19 (1g.24gb): Placements {0,1,2,3,4,5,6}:1 - can start at any of 7 positions, occupies 1 slot each + // Profile 9 (4g.94gb): Placements {0,4}:4 - can start at position 0 or 4, occupies 4 slots each + gpuModel := testGPUModel + template1g := "1g.24gb" // Profile 19 + template4g := "4g.94gb" // Profile 9 + + // Clear and setup maps for this test + mu.Lock() + PartitionTemplateMap[gpuModel] = map[string]config.PartitionTemplateInfo{ + template1g: { + TemplateID: template1g, + Name: "1g.24gb", + MemoryGigabytes: 24, // 24GB + ComputePercent: 1.0 / 7.0 * 100, + MaxPartition: 7, // Can allocate up to 7 instances + PlacementLimit: []uint32{0, 1, 2, 3, 4, 5, 6}, // Can start at any of these positions + PlacementOffSet: 1, // Occupies 1 slot + }, + template4g: { + TemplateID: template4g, + Name: "4g.94gb", + MemoryGigabytes: 94, // 94GB + ComputePercent: 4.0 / 7.0 * 100, + MaxPartition: 2, // Can only allocate 2 instances + PlacementLimit: []uint32{0, 4}, // Can start at position 0 or 4 + PlacementOffSet: 4, // Occupies 4 slots (0-3 or 4-7) + }, + } + MaxPartitionsMap[gpuModel] = 7 + MaxPlacementSlotsMap[gpuModel] = 8 // A100 has 8 placement slots (0-7) + mu.Unlock() + + tests := []struct { + name string + gpu *tfv1.GPU + templateID string + expectError bool + errorContains string + }{ + { + name: "happy path - 1g.24gb allocation succeeds", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("100"), + Vram: resource.MustParse("50Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{}, + }, + }, + templateID: template1g, + expectError: false, + }, + { + name: "Profile 19 * 4 allocated - a 5th instance still fits", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("200"), + Vram: resource.MustParse("96Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Profile 19 at position 0 (slot 0) + "pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Profile 19 at position 1 (slot 1) + "pod-3": {TemplateID: template1g, PodUID: "pod-3"}, // Profile 19 at position 2 (slot 2) + "pod-4": {TemplateID: template1g, PodUID: "pod-4"}, // Profile 19 at position 3 (slot 3) + // Positions 4,5,6 are still free, so a 5th 1g.24gb instance can still be placed + }, + }, + }, + templateID: template1g, + expectError: false, // 4 instances occupy positions 0-3, leaving 4,5,6 free for a 5th + }, + { + name: "Profile 9 at 0 + Profile 19 * 4 should fail", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("200"), + Vram: resource.MustParse("96Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-p9": {TemplateID: template4g,
PodUID: "pod-p9", AllocatedAt: metav1.NewTime(metav1.Now().Add(-3 * time.Hour))}, // Profile 9 allocated first at position 0, occupies slots 0,1,2,3 + "pod-1": {TemplateID: template1g, PodUID: "pod-1", AllocatedAt: metav1.NewTime(metav1.Now().Add(-2 * time.Hour))}, // Profile 19 at position 4 (slot 4) + "pod-2": {TemplateID: template1g, PodUID: "pod-2", AllocatedAt: metav1.NewTime(metav1.Now().Add(-1 * time.Hour))}, // Profile 19 at position 5 (slot 5) + "pod-3": {TemplateID: template1g, PodUID: "pod-3", AllocatedAt: metav1.Now()}, // Profile 19 at position 6 (slot 6) + // Trying to allocate 4th Profile 19 instance - should fail + // All valid positions {0,1,2,3,4,5,6} are either occupied or conflict + }, + }, + }, + templateID: template1g, + expectError: true, + errorContains: "placement slots", + }, + { + name: "Profile 9 * 1 + Profile 19 * 3 should work", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("150"), + Vram: resource.MustParse("118Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-p9": {TemplateID: template4g, PodUID: "pod-p9"}, // Profile 9 at position 0, occupies slots 0,1,2,3 + "pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Profile 19 at slot 4 + "pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Profile 19 at slot 5 + // Trying to allocate 3rd Profile 19 instance - should succeed at slot 6 + }, + }, + }, + templateID: template1g, + expectError: false, // 3rd Profile 19 instance should succeed + }, + { + name: "Profile 9 * 1 + Profile 19 * 3 should work (happy case)", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("150"), + Vram: resource.MustParse("118Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-p9": {TemplateID: template4g, PodUID: "pod-p9"}, // Profile 9 at position 0, occupies slots 0,1,2,3 + "pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Profile 19 at slot 4 + "pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Profile 19 at slot 5 + // Trying to allocate 3rd Profile 19 instance - should succeed at slot 6 + }, + }, + }, + templateID: template1g, + expectError: false, + }, + { + name: "Profile 9 - all placement positions occupied", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("200"), + Vram: resource.MustParse("94Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: template4g, PodUID: "pod-1"}, // Profile 9 at position 0, occupies slots 0,1,2,3 + "pod-2": {TemplateID: template4g, PodUID: "pod-2"}, // Profile 9 at position 4, occupies slots 4,5,6,7 + // Both positions {0,4} are now occupied + }, + }, + }, + templateID: template4g, + expectError: true, + errorContains: "maximum partition count", // MaxPartition check happens first (2/2) + }, + { + name: "insufficient TFLOPs", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + 
GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("10"), // Too low + Vram: resource.MustParse("50Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{}, + }, + }, + templateID: template1g, + expectError: true, + errorContains: "insufficient TFLOPs", + }, + { + name: "insufficient VRAM", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("100"), + Vram: resource.MustParse("10Gi"), // Too low for 24Gi required + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{}, + }, + }, + templateID: template1g, + expectError: true, + errorContains: "insufficient VRAM", + }, + { + name: "Profile 9 can allocate at position 4 when Profile 19 uses slots 0-2", + gpu: &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{Name: "gpu-1"}, + Status: tfv1.GPUStatus{ + GPUModel: gpuModel, + Capacity: &tfv1.Resource{ + Tflops: resource.MustParse("312"), + Vram: resource.MustParse("80Gi"), + }, + Available: &tfv1.Resource{ + Tflops: resource.MustParse("200"), + Vram: resource.MustParse("94Gi"), + }, + AllocatedPartitions: map[string]tfv1.AllocatedPartition{ + "pod-1": {TemplateID: template1g, PodUID: "pod-1"}, // Slot 0 + "pod-2": {TemplateID: template1g, PodUID: "pod-2"}, // Slot 1 + "pod-3": {TemplateID: template1g, PodUID: "pod-3"}, // Slot 2 + // Slots 3,4,5,6,7 are free + // Profile 9 can use position 4 (slots 4,5,6,7) or position 0 (slots 0,1,2,3) + // Position 0 conflicts, but position 4 is free + }, + }, + }, + templateID: template4g, + expectError: false, // Profile 9 can use position 4 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckPartitionAvailability(tt.gpu, tt.templateID) + + if tt.expectError { + if !assert.Error(t, err) { + return // Stop if no error when one is expected + } + if tt.errorContains != "" && err != nil { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/gpuallocator/quota_consolidated_test.go b/internal/gpuallocator/quota_consolidated_test.go index b3345ce0..74acbf18 100644 --- a/internal/gpuallocator/quota_consolidated_test.go +++ b/internal/gpuallocator/quota_consolidated_test.go @@ -377,7 +377,7 @@ var _ = Describe("GPUAllocator Quota Integration", func() { Build() ctx := context.Background() - allocator := NewGpuAllocator(ctx, client, 0) + allocator := NewGpuAllocator(ctx, nil, client, 0) initAllocator(allocator) @@ -415,7 +415,7 @@ var _ = Describe("GPUAllocator Quota Integration", func() { Build() ctx := context.Background() - allocator := NewGpuAllocator(ctx, client, 0) + allocator := NewGpuAllocator(ctx, nil, client, 0) initAllocator(allocator) @@ -451,7 +451,7 @@ var _ = Describe("GPUAllocator Concurrent Quota Enforcement", func() { Build() ctx := context.Background() - allocator := NewGpuAllocator(ctx, client, 0) + allocator := NewGpuAllocator(ctx, nil, client, 0) initAllocator(allocator) @@ -539,7 +539,7 @@ var _ = Describe("GPUAllocator Quota Reconciliation", func() { Build() ctx := context.Background() - allocator := NewGpuAllocator(ctx, client, 0) + allocator := NewGpuAllocator(ctx, nil, client, 0) initAllocator(allocator) @@ -576,7 +576,7 @@ var _ = Describe("GPUAllocator Quota 
Deallocation", func() { Build() ctx := context.Background() - allocator := NewGpuAllocator(ctx, client, 0) + allocator := NewGpuAllocator(ctx, nil, client, 0) initAllocator(allocator) diff --git a/internal/hypervisor/api/device_types.go b/internal/hypervisor/api/device_types.go new file mode 100644 index 00000000..201dd7d7 --- /dev/null +++ b/internal/hypervisor/api/device_types.go @@ -0,0 +1,111 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package api + +// DeviceInfo represents discovered GPU device information +// +k8s:deepcopy-gen=true +type DeviceInfo struct { + UUID string + Vendor string + Model string + Index int32 + NUMANode int32 + TotalMemoryBytes uint64 + MaxTflops float64 + Capabilities DeviceCapabilities + Properties map[string]string + Healthy bool + + ParentUUID string + + // Host - Guest device node mapping, eg /dev/nvidia0 -> /dev/nvidia0 + // When multiple device allocated, deduplicated by device node + DeviceNode map[string]string + + // Env to inject to guest + DeviceEnv map[string]string +} + +type NodeInfo struct { + // Extra metadata for centralized management + RAMSizeBytes int64 + DataDiskBytes int64 + + // Aggregated info of whole Node + TotalTFlops float64 + TotalVRAMBytes int64 + DeviceIDs []string + + // TODO: discover and merge extra devices and topology info like: + // Nvlink/IB NICs, etc. + // CXL available or not, PCIe generation etc. +} + +// DeviceCapabilities represents device capabilities +// +k8s:deepcopy-gen=true +type DeviceCapabilities struct { + SupportsPartitioning bool + SupportsSoftIsolation bool + SupportsHardIsolation bool + SupportsSnapshot bool + SupportsMetrics bool + MaxPartitions uint32 + MaxWorkersPerDevice uint32 +} + +// ComputeUtilization represents compute utilization for a process on a device +type ComputeUtilization struct { + ProcessID string + DeviceUUID string + UtilizationPercent float64 +} + +// MemoryUtilization represents memory utilization for a process on a device +type MemoryUtilization struct { + ProcessID string + DeviceUUID string + UsedBytes uint64 + ReservedBytes uint64 +} + +// GPUUsageMetrics represents GPU device metrics +// +k8s:deepcopy-gen=true +type GPUUsageMetrics struct { + DeviceUUID string + MemoryBytes uint64 + MemoryPercentage float64 + ComputePercentage float64 + ComputeTflops float64 + Rx float64 // PCIe RX in KB + Tx float64 // PCIe TX in KB + Temperature float64 + PowerUsage int64 // in watts + ExtraMetrics map[string]float64 +} + +// WorkerMetrics represents worker process metrics on a device +// +k8s:deepcopy-gen=true +type WorkerMetrics struct { + DeviceUUID string + WorkerUID string + ProcessID string + MemoryBytes uint64 + MemoryPercentage float64 + ComputeTflops float64 + ComputePercentage float64 + ExtraMetrics map[string]float64 +} diff --git a/internal/hypervisor/api/http_types.go b/internal/hypervisor/api/http_types.go new file mode 100644 index 00000000..d40c46ab --- /dev/null +++ b/internal/hypervisor/api/http_types.go @@ -0,0 +1,91 @@ +/* +Copyright 2024. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package api + +import ( + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// HTTP API Response Types + +// ErrorResponse represents an error response +type ErrorResponse struct { + Error string `json:"error"` +} + +// DataResponse is a generic response wrapper for data-only responses +// +k8s:deepcopy-gen=false +type DataResponse[T any] struct { + Data T `json:"data"` +} + +// MessageAndDataResponse is a generic response wrapper for responses with message and data +type MessageAndDataResponse[T any] struct { + Message string `json:"message"` + Data T `json:"data"` +} + +// StatusResponse represents a simple status response +type StatusResponse struct { + Status string `json:"status"` +} + +// Types to be compatible with legacy APIs + +// LimiterInfo represents worker limiter information (used in legacy.go) +type LimiterInfo struct { + WorkerUID string `json:"worker_uid"` + Requests *tfv1.Resource `json:"requests,omitempty"` + Limits *tfv1.Resource `json:"limits,omitempty"` +} + +// ListLimitersResponse represents the response from GET /api/v1/limiter (used in legacy.go) +type ListLimitersResponse struct { + Limiters []LimiterInfo `json:"limiters"` +} + +// TrapResponse represents the response from POST /api/v1/trap (used in legacy.go) +type TrapResponse struct { + Message string `json:"message"` + SnapshotCount int `json:"snapshot_count"` +} + +// PodInfo represents pod information for the /api/v1/pod endpoint (used in legacy.go) +type PodInfo struct { + PodName string `json:"pod_name"` + Namespace string `json:"namespace"` + GPUIDs []string `json:"gpu_uuids"` + TflopsLimit *float64 `json:"tflops_limit,omitempty"` + VramLimit *uint64 `json:"vram_limit,omitempty"` + QoSLevel tfv1.QoSLevel `json:"qos_level,omitempty"` +} + +// ListPodsResponse represents the response from GET /api/v1/pod (used in legacy.go) +type ListPodsResponse struct { + Pods []PodInfo `json:"pods"` +} + +// ProcessInfo represents process mapping information (used in legacy.go) +type ProcessInfo struct { + WorkerUID string `json:"worker_uid"` + ProcessMapping map[string]string `json:"process_mapping"` // container PID -> host PID +} + +// ListProcessesResponse represents the response from GET /api/v1/process (used in legacy.go) +type ListProcessesResponse struct { + Processes []ProcessInfo `json:"processes"` +} diff --git a/internal/hypervisor/api/worker_types.go b/internal/hypervisor/api/worker_types.go new file mode 100644 index 00000000..e93feb01 --- /dev/null +++ b/internal/hypervisor/api/worker_types.go @@ -0,0 +1,81 @@ +package api + +import ( + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +// IsolationMode represents the isolation mode for worker processes +type IsolationMode = tfv1.IsolationModeType + +// +k8s:deepcopy-gen=true +type WorkerInfo struct { + WorkerUID string + Namespace string + WorkerName string + AllocatedDevices []string + Status WorkerStatus + + QoS tfv1.QoSLevel + IsolationMode IsolationMode + + Requests tfv1.Resource + Limits tfv1.Resource + + 
WorkloadName string + WorkloadNamespace string + + // Only set for partitioned mode + PartitionTemplateID string + + // Extra information from backend + Labels map[string]string + Annotations map[string]string + + DeletedAt int64 +} + +func (w *WorkerInfo) FilterValue() string { + return w.WorkerUID + " " + w.WorkerName + " " + w.Namespace +} + +type WorkerStatus string + +const ( + WorkerStatusPending WorkerStatus = "Pending" + WorkerStatusDeviceAllocating WorkerStatus = "DeviceAllocating" + WorkerStatusRunning WorkerStatus = "Running" + WorkerStatusTerminated WorkerStatus = "Terminated" +) + +// +k8s:deepcopy-gen=true +type WorkerAllocation struct { + WorkerInfo *WorkerInfo + + // the complete or partitioned device info + DeviceInfos []*DeviceInfo + + Envs map[string]string + + Mounts []*Mount + + Devices []*DeviceSpec +} + +// DeviceSpec specifies a host device to mount into a container. +// +k8s:deepcopy-gen=true +type DeviceSpec struct { + GuestPath string `json:"guestPath,omitempty"` + + HostPath string `json:"hostPath,omitempty"` + + Permissions string `json:"permissions,omitempty"` +} + +// Mount specifies a host volume to mount into a container. +// where device library or tools are installed on host and container +// +k8s:deepcopy-gen=true +type Mount struct { + GuestPath string `json:"guestPath,omitempty"` + + HostPath string `json:"hostPath,omitempty"` +} diff --git a/internal/hypervisor/api/zz_generated.deepcopy.go b/internal/hypervisor/api/zz_generated.deepcopy.go new file mode 100644 index 00000000..3e43cf07 --- /dev/null +++ b/internal/hypervisor/api/zz_generated.deepcopy.go @@ -0,0 +1,245 @@ +//go:build !ignore_autogenerated + +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by controller-gen. DO NOT EDIT. + +package api + +import () + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DeviceCapabilities) DeepCopyInto(out *DeviceCapabilities) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeviceCapabilities. +func (in *DeviceCapabilities) DeepCopy() *DeviceCapabilities { + if in == nil { + return nil + } + out := new(DeviceCapabilities) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *DeviceInfo) DeepCopyInto(out *DeviceInfo) { + *out = *in + out.Capabilities = in.Capabilities + if in.Properties != nil { + in, out := &in.Properties, &out.Properties + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + if in.DeviceNode != nil { + in, out := &in.DeviceNode, &out.DeviceNode + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + if in.DeviceEnv != nil { + in, out := &in.DeviceEnv, &out.DeviceEnv + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeviceInfo. +func (in *DeviceInfo) DeepCopy() *DeviceInfo { + if in == nil { + return nil + } + out := new(DeviceInfo) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DeviceSpec) DeepCopyInto(out *DeviceSpec) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeviceSpec. +func (in *DeviceSpec) DeepCopy() *DeviceSpec { + if in == nil { + return nil + } + out := new(DeviceSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *GPUUsageMetrics) DeepCopyInto(out *GPUUsageMetrics) { + *out = *in + if in.ExtraMetrics != nil { + in, out := &in.ExtraMetrics, &out.ExtraMetrics + *out = make(map[string]float64, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUUsageMetrics. +func (in *GPUUsageMetrics) DeepCopy() *GPUUsageMetrics { + if in == nil { + return nil + } + out := new(GPUUsageMetrics) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Mount) DeepCopyInto(out *Mount) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Mount. +func (in *Mount) DeepCopy() *Mount { + if in == nil { + return nil + } + out := new(Mount) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *WorkerAllocation) DeepCopyInto(out *WorkerAllocation) { + *out = *in + if in.WorkerInfo != nil { + in, out := &in.WorkerInfo, &out.WorkerInfo + *out = new(WorkerInfo) + (*in).DeepCopyInto(*out) + } + if in.DeviceInfos != nil { + in, out := &in.DeviceInfos, &out.DeviceInfos + *out = make([]*DeviceInfo, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(DeviceInfo) + (*in).DeepCopyInto(*out) + } + } + } + if in.Envs != nil { + in, out := &in.Envs, &out.Envs + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + if in.Mounts != nil { + in, out := &in.Mounts, &out.Mounts + *out = make([]*Mount, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(Mount) + **out = **in + } + } + } + if in.Devices != nil { + in, out := &in.Devices, &out.Devices + *out = make([]*DeviceSpec, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(DeviceSpec) + **out = **in + } + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkerAllocation. +func (in *WorkerAllocation) DeepCopy() *WorkerAllocation { + if in == nil { + return nil + } + out := new(WorkerAllocation) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkerInfo) DeepCopyInto(out *WorkerInfo) { + *out = *in + if in.AllocatedDevices != nil { + in, out := &in.AllocatedDevices, &out.AllocatedDevices + *out = make([]string, len(*in)) + copy(*out, *in) + } + in.Requests.DeepCopyInto(&out.Requests) + in.Limits.DeepCopyInto(&out.Limits) + if in.Labels != nil { + in, out := &in.Labels, &out.Labels + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + if in.Annotations != nil { + in, out := &in.Annotations, &out.Annotations + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkerInfo. +func (in *WorkerInfo) DeepCopy() *WorkerInfo { + if in == nil { + return nil + } + out := new(WorkerInfo) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *WorkerMetrics) DeepCopyInto(out *WorkerMetrics) { + *out = *in + if in.ExtraMetrics != nil { + in, out := &in.ExtraMetrics, &out.ExtraMetrics + *out = make(map[string]float64, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkerMetrics. 
+func (in *WorkerMetrics) DeepCopy() *WorkerMetrics { + if in == nil { + return nil + } + out := new(WorkerMetrics) + in.DeepCopyInto(out) + return out +} diff --git a/internal/hypervisor/backend/kubernetes/api_client.go b/internal/hypervisor/backend/kubernetes/api_client.go new file mode 100644 index 00000000..61392e60 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/api_client.go @@ -0,0 +1,173 @@ +package kubernetes + +import ( + "context" + "fmt" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "k8s.io/apimachinery/pkg/api/equality" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + "k8s.io/client-go/util/retry" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +var ( + scheme = runtime.NewScheme() +) + +func init() { + utilruntime.Must(tfv1.AddToScheme(scheme)) +} + +// APIClient provides CRUD operations for GPU resources +type APIClient struct { + client client.Client + ctx context.Context +} + +// NewAPIClient creates a new API client instance with an existing client +func NewAPIClient(ctx context.Context, k8sClient client.Client) *APIClient { + return &APIClient{ + client: k8sClient, + ctx: ctx, + } +} + +// NewAPIClientFromConfig creates a new API client instance from a rest.Config +func NewAPIClientFromConfig(ctx context.Context, restConfig *rest.Config) (*APIClient, error) { + k8sClient, err := client.New(restConfig, client.Options{ + Scheme: scheme, + }) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + return &APIClient{ + client: k8sClient, + ctx: ctx, + }, nil +} + +// GPUInfo contains information needed to create or update a GPU +type GPUInfo struct { + UUID string + DeviceName string + VRAMBytes uint64 + TFlops resource.Quantity + Index int32 + NUMANodeID int32 + NodeName string + Vendor string + IsolationMode tfv1.IsolationModeType +} + +// CreateOrUpdateGPU creates or updates a GPU resource with metadata and status +func (a *APIClient) CreateOrUpdateGPU( + gpuNodeName string, gpuID string, + mutateFn func(gpuNode *tfv1.GPUNode, gpu *tfv1.GPU) error, +) error { + // Fetch the GPUNode info + gpuNode := &tfv1.GPUNode{} + if err := a.client.Get(a.ctx, client.ObjectKey{Name: gpuNodeName}, gpuNode); err != nil { + return fmt.Errorf("failed to get GPUNode %s: %w", gpuNodeName, err) + } + + // Create or update GPU metadata + err := retry.OnError(wait.Backoff{ + Steps: 7, + Duration: time.Second, + Factor: 1.0, + Jitter: 0.1, + }, func(err error) bool { + return true // Retry on all errors + }, func() error { + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: gpuID, + }, + } + _, err := controllerutil.CreateOrPatch(a.ctx, a.client, gpu, func() error { + return mutateFn(gpuNode, gpu) + }) + if err != nil { + return err + } + return nil + }) + return err +} + +// GetGPU retrieves a GPU resource by UUID +func (a *APIClient) GetGPU(uuid string) (*tfv1.GPU, error) { + gpu := &tfv1.GPU{} + if err := a.client.Get(a.ctx, client.ObjectKey{Name: uuid}, gpu); err != nil { + return nil, fmt.Errorf("failed to get GPU %s: %w", uuid, err) + } + return gpu, nil +} + +// UpdateGPUStatus updates the status of a GPU resource using merge patch +func (a *APIClient) UpdateGPUStatus(gpu *tfv1.GPU) error { + 
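// Re-fetch the latest GPU inside the retry loop so the status merge patch is computed against the current object and conflicts are retried +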
return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + current := &tfv1.GPU{} + if err := a.client.Get(a.ctx, client.ObjectKeyFromObject(gpu), current); err != nil { + return err + } + + patch := client.MergeFrom(current.DeepCopy()) + current.Status = gpu.Status + return a.client.Status().Patch(a.ctx, current, patch) + }) +} + +// UpdateGPUNodeStatus updates the status of a GPUNode resource +func (a *APIClient) UpdateGPUNodeStatus(nodeInfo *api.NodeInfo) error { + return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + current := &tfv1.GPUNode{} + if err := a.client.Get(a.ctx, client.ObjectKeyFromObject(current), current); err != nil { + return err + } + + original := current.DeepCopy() + patch := client.MergeFrom(original) + + current.Status.TotalTFlops = resource.MustParse(fmt.Sprintf("%f", nodeInfo.TotalTFlops)) + current.Status.TotalVRAM = resource.MustParse(fmt.Sprintf("%d", nodeInfo.TotalVRAMBytes)) + current.Status.TotalGPUs = int32(len(nodeInfo.DeviceIDs)) + current.Status.ManagedGPUs = current.Status.TotalGPUs + current.Status.ManagedGPUDeviceIDs = nodeInfo.DeviceIDs + current.Status.NodeInfo = tfv1.GPUNodeInfo{ + RAMSize: *resource.NewQuantity(nodeInfo.RAMSizeBytes, resource.DecimalSI), + DataDiskSize: *resource.NewQuantity(nodeInfo.DataDiskBytes, resource.DecimalSI), + } + if current.Status.Phase == "" { + current.Status.Phase = tfv1.TensorFusionGPUNodePhasePending + } + + if equality.Semantic.DeepEqual(original, current) { + return nil + } + return a.client.Status().Patch(a.ctx, current, patch) + }) +} + +// DeleteGPU deletes a GPU resource +func (a *APIClient) DeleteGPU(uuid string) error { + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: uuid, + }, + } + if err := a.client.Delete(a.ctx, gpu); err != nil { + return fmt.Errorf("failed to delete GPU %s: %w", uuid, err) + } + return nil +} diff --git a/internal/hypervisor/backend/kubernetes/deviceplugin.go b/internal/hypervisor/backend/kubernetes/deviceplugin.go new file mode 100644 index 00000000..02e34628 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/deviceplugin.go @@ -0,0 +1,282 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package kubernetes + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/samber/lo" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "k8s.io/klog/v2" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +const ( + // DevicePluginPath is the path where device plugins should register + DevicePluginPath = "/var/lib/kubelet/device-plugins" + // KubeletSocket is the kubelet registration socket + KubeletSocket = "kubelet.sock" + // DevicePluginEndpoint is the endpoint name for this device plugin + DevicePluginEndpoint = "tensor-fusion-index-%d.sock" +) + +// DevicePlugin implements the Kubernetes device plugin interface +type DevicePlugin struct { + pluginapi.UnimplementedDevicePluginServer + + ctx context.Context + deviceController framework.DeviceController + workerController framework.WorkerController + kubeletClient *PodCacheManager + + server *grpc.Server + socketPath string + resourceNameIndex int +} + +// NewDevicePlugins creates a new device plugin instance +func NewDevicePlugins(ctx context.Context, deviceController framework.DeviceController, workerController framework.WorkerController, kubeletClient *PodCacheManager) []*DevicePlugin { + devicePlugins := make([]*DevicePlugin, constants.IndexKeyLength) + for i := range constants.IndexKeyLength { + devicePlugins[i] = &DevicePlugin{ + ctx: ctx, + deviceController: deviceController, + workerController: workerController, + kubeletClient: kubeletClient, + socketPath: filepath.Join(DevicePluginPath, fmt.Sprintf(DevicePluginEndpoint, i)), + resourceNameIndex: i, + } + } + return devicePlugins +} + +// Start starts the device plugin gRPC server and registers with kubelet +func (dp *DevicePlugin) Start() error { + // Clean up any existing socket + // Check if file exists first to avoid permission errors on non-existent files + if _, err := os.Stat(dp.socketPath); err == nil { + // File exists, try to remove it + if err := os.Remove(dp.socketPath); err != nil { + return fmt.Errorf("failed to remove existing socket: %w", err) + } + } else if !os.IsNotExist(err) { + // Some other error checking file existence (e.g., permission denied on parent directory) + // Log warning but continue - net.Listen will handle it + klog.Warningf("Could not check socket file existence: %v", err) + } + + // Create directory if it doesn't exist + if err := os.MkdirAll(DevicePluginPath, 0750); err != nil { + return fmt.Errorf("failed to create device plugin directory: %w", err) + } + + // Create Unix socket listener + listener, err := net.Listen("unix", dp.socketPath) + if err != nil { + return fmt.Errorf("failed to create listener: %w", err) + } + + // Create gRPC server + dp.server = grpc.NewServer() + pluginapi.RegisterDevicePluginServer(dp.server, dp) + + // Start gRPC server + go func() { + klog.Infof("Starting device plugin gRPC server on %s", dp.socketPath) + if err := dp.server.Serve(listener); err != nil { + klog.Errorf("Device plugin gRPC server error: %v", err) + } + }() + + // Wait for server to be ready + conn, err := dp.dial(dp.socketPath, 5*time.Second) + if err != nil { + return fmt.Errorf("failed to dial device plugin socket: %w", err) + } + _ = conn.Close() + + // Register with kubelet + if err := dp.register(); err != nil { + return fmt.Errorf("failed to register with kubelet: %w", err) + 
} + return nil +} + +// Stop stops the device plugin +func (dp *DevicePlugin) Stop() error { + if dp.server != nil { + dp.server.Stop() + } + return os.Remove(dp.socketPath) +} + +// register registers the device plugin with kubelet +func (dp *DevicePlugin) register() error { + kubeletSocketPath := filepath.Join(DevicePluginPath, KubeletSocket) + + // Check if kubelet socket exists + if _, err := os.Stat(kubeletSocketPath); os.IsNotExist(err) { + return fmt.Errorf("kubelet socket does not exist at %s (kubelet may not be running or device plugin support not enabled)", kubeletSocketPath) + } else if err != nil { + return fmt.Errorf("failed to check kubelet socket: %w", err) + } + + conn, err := dp.dial(kubeletSocketPath, 5*time.Second) + if err != nil { + return fmt.Errorf("failed to dial kubelet: %w", err) + } + defer func() { + _ = conn.Close() + }() + + client := pluginapi.NewRegistrationClient(conn) + req := &pluginapi.RegisterRequest{ + Version: pluginapi.Version, + Endpoint: fmt.Sprintf(DevicePluginEndpoint, dp.resourceNameIndex), + ResourceName: fmt.Sprintf("%s%s%d", constants.PodIndexAnnotation, constants.PodIndexDelimiter, dp.resourceNameIndex), + Options: &pluginapi.DevicePluginOptions{ + PreStartRequired: false, + GetPreferredAllocationAvailable: false, + }, + } + + _, err = client.Register(context.Background(), req) + if err != nil { + return fmt.Errorf("failed to register: %w", err) + } + + klog.Infof("Successfully registered device plugin with kubelet: tensor-fusion.ai/index_%d", dp.resourceNameIndex) + return nil +} + +// dial establishes a connection to a Unix socket +func (dp *DevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) { + // Use unix:// prefix for gRPC to recognize it as a Unix socket + // The dialer will receive the full address, so we need to strip the prefix + target := "unix://" + unixSocketPath + conn, err := grpc.NewClient(target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + // Strip unix:// prefix to get the actual socket path + socketPath := addr + if len(addr) > 7 && addr[:7] == "unix://" { + socketPath = addr[7:] + } + return net.DialTimeout("unix", socketPath, timeout) + }), + ) + return conn, err +} + +// GetDevicePluginOptions returns options for the device plugin +func (dp *DevicePlugin) GetDevicePluginOptions(ctx context.Context, req *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { + return &pluginapi.DevicePluginOptions{ + PreStartRequired: false, + GetPreferredAllocationAvailable: false, + }, nil +} + +// ListAndWatch streams device list and health updates +func (dp *DevicePlugin) ListAndWatch(req *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error { + klog.Info("ListAndWatch called") + devices := make([]*pluginapi.Device, constants.IndexModLength) + for i := range constants.IndexModLength { + devices[i] = &pluginapi.Device{ + ID: fmt.Sprintf("%d", i+1), + Health: pluginapi.Healthy, + } + } + if err := stream.Send(&pluginapi.ListAndWatchResponse{Devices: devices}); err != nil { + return fmt.Errorf("failed to send device list: %w", err) + } + return nil +} + +// Allocate handles device allocation requests from kubelet +func (dp *DevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { + responses := make([]*pluginapi.ContainerAllocateResponse, 0, len(req.ContainerRequests)) + + for containerIdx, containerReq := range 
req.ContainerRequests { + podIndex := len(containerReq.DevicesIds) + if podIndex <= 0 || podIndex > constants.IndexModLength { + return nil, fmt.Errorf("container request %d dummy device requests is not valid: (expected index value 1-%d)", containerIdx, constants.IndexModLength) + } + + podIndexFull := podIndex + (dp.resourceNameIndex * constants.IndexModLength) + + klog.V(4).Infof("Processing allocation for container index %d, pod index %d (from DevicesIds)", containerIdx, podIndexFull) + // Get worker info from kubelet client using pod index + // This will automatically check for duplicate indices and fail fast if found + workerInfo, err := dp.kubeletClient.GetWorkerInfoForAllocationByIndex(podIndexFull) + if err != nil { + klog.Errorf("Failed to get worker info for pod index %d: %v", podIndexFull, err) + return nil, fmt.Errorf("failed to get worker info for pod index %d: %w", podIndexFull, err) + } + if workerInfo == nil { + return nil, fmt.Errorf("worker info not found for pod index %d", podIndexFull) + } + // Call worker controller to allocate + allocResp, err := dp.workerController.AllocateWorkerDevices(workerInfo) + if err != nil { + return nil, fmt.Errorf("failed to allocate devices for worker %s %s: %w", workerInfo.WorkerName, workerInfo.WorkerUID, err) + } + + containerResp := &pluginapi.ContainerAllocateResponse{ + Envs: allocResp.Envs, + Mounts: lo.Map(allocResp.Mounts, func(mount *api.Mount, _ int) *pluginapi.Mount { + return &pluginapi.Mount{ + ContainerPath: mount.GuestPath, + HostPath: mount.HostPath, + } + }), + Devices: lo.Map(allocResp.Devices, func(device *api.DeviceSpec, _ int) *pluginapi.DeviceSpec { + return &pluginapi.DeviceSpec{ + ContainerPath: device.GuestPath, + HostPath: device.HostPath, + Permissions: device.Permissions, + } + }), + CdiDevices: []*pluginapi.CDIDevice{}, + } + responses = append(responses, containerResp) + } + + return &pluginapi.AllocateResponse{ + ContainerResponses: responses, + }, nil +} + +// PreStartContainer is called before container start (optional) +func (dp *DevicePlugin) PreStartContainer(ctx context.Context, req *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { + return &pluginapi.PreStartContainerResponse{}, nil +} + +// GetPreferredAllocation returns preferred device allocation (optional) +func (dp *DevicePlugin) GetPreferredAllocation(ctx context.Context, req *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { + return &pluginapi.PreferredAllocationResponse{ + ContainerResponses: []*pluginapi.ContainerPreferredAllocationResponse{}, + }, nil +} diff --git a/internal/hypervisor/backend/kubernetes/deviceplugin_test.go b/internal/hypervisor/backend/kubernetes/deviceplugin_test.go new file mode 100644 index 00000000..3724d120 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/deviceplugin_test.go @@ -0,0 +1,81 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +// TestDevicePluginAllocate_ExtractsIndexFromDevicesIds tests that the device plugin +// correctly extracts the pod index from DevicesIds[0], not from len(req.ContainerRequests) +// This is a key test to verify the device plugin implementation matches the design: +// - DevicesIds[0] contains the index value (1-512) from resource limits +// - len(req.ContainerRequests) is just the number of containers, NOT the pod index +// - CdiDevices must be empty to prevent dummy device allocation +func TestDevicePluginAllocate_ExtractsIndexFromDevicesIds(t *testing.T) { + // This test verifies the key design principle: + // The pod index comes from DevicesIds[0], which contains the value from + // tensor-fusion.ai/index resource limit, NOT from len(req.ContainerRequests) + + req := &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIds: []string{"3"}, // Index "3" from resource limit + }, + }, + } + + // Verify the structure: len(ContainerRequests) = 1, but index is "3" from DevicesIds[0] + assert.Len(t, req.ContainerRequests, 1, "Should have 1 container request") + assert.Equal(t, "3", req.ContainerRequests[0].DevicesIds[0], "Index should come from DevicesIds[0], not from len(ContainerRequests)") + + // This demonstrates that len(req.ContainerRequests) is NOT the pod index + // The pod index is extracted from DevicesIds[0] + assert.NotEqual(t, len(req.ContainerRequests), 3, "len(ContainerRequests) should NOT equal the pod index") +} + +// TestDevicePluginAllocate_MultipleContainers tests that len(req.ContainerRequests) +// is used for iteration, not for pod index identification +func TestDevicePluginAllocate_MultipleContainers(t *testing.T) { + // Create request with 2 containers, both with index "5" + // len(ContainerRequests) = 2, but pod index is still "5" from DevicesIds + req := &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIds: []string{"5"}, // First container: index 5 + }, + { + DevicesIds: []string{"5"}, // Second container: same pod, same index + }, + }, + } + + // Verify: len(ContainerRequests) = 2, but index is "5" from DevicesIds + assert.Len(t, req.ContainerRequests, 2, "Should have 2 container requests") + assert.Equal(t, "5", req.ContainerRequests[0].DevicesIds[0], "First container index from DevicesIds") + assert.Equal(t, "5", req.ContainerRequests[1].DevicesIds[0], "Second container index from DevicesIds") + + // Key verification: len(ContainerRequests) is NOT the pod index + assert.NotEqual(t, len(req.ContainerRequests), 5, "len(ContainerRequests) should NOT equal the pod index") + + // Both containers have the same index because they're in the same pod + assert.Equal(t, req.ContainerRequests[0].DevicesIds[0], req.ContainerRequests[1].DevicesIds[0], + "Both containers should have the same index (same pod)") +} diff --git a/internal/hypervisor/backend/kubernetes/dra.go b/internal/hypervisor/backend/kubernetes/dra.go new file mode 100644 index 00000000..276009a4 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/dra.go @@ -0,0 +1 @@ +package kubernetes diff --git a/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go b/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go new file mode 100644 index 00000000..33ce2e12 --- /dev/null +++ 
b/internal/hypervisor/backend/kubernetes/external_dp/detector_test.go @@ -0,0 +1,283 @@ +package external_dp + +import ( + "context" + "os" + "testing" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// MockAPIServer is a mock implementation of APIServerInterface +type MockAPIServer struct { + mock.Mock +} + +func (m *MockAPIServer) GetGPU(uuid string) (*tfv1.GPU, error) { + args := m.Called(uuid) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*tfv1.GPU), args.Error(1) +} + +func (m *MockAPIServer) UpdateGPUStatus(gpu *tfv1.GPU) error { + args := m.Called(gpu) + return args.Error(0) +} + +// MockKubeletClient is a mock implementation of KubeletClientInterface +type MockKubeletClient struct { + mock.Mock + pods map[string]any +} + +func (m *MockKubeletClient) GetAllPods() map[string]any { + return m.pods +} + +func TestReadCheckpointFile(t *testing.T) { + // Create a temporary checkpoint file with test data + testData := `{ + "Data": { + "PodDeviceEntries": [ + { + "PodUID": "a7461dc1-023a-4bd5-a403-c738bb1d7db4", + "ContainerName": "web", + "ResourceName": "nvidia.com/gpu", + "DeviceIDs": { + "-1": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + }, + "AllocResp": "CkIKFk5WSURJQV9WSVNJQkxFX0RFVklDRVMSKEdQVS03ZDg0MjlkNS01MzFkLWQ2YTYtNjUxMC0zYjY2MjA4MWE3NWEaJAoOL2Rldi9udmlkaWFjdGwSDi9kZXYvbnZpZGlhY3RsGgJydxomCg8vZGV2L252aWRpYS11dm0SDy9kZXYvbnZpZGlhLXV2bRoCcncaMgoVL2Rldi9udmlkaWEtdXZtLXRvb2xzEhUvZGV2L252aWRpYS11dm0tdG9vbHMaAnJ3Gi4KEy9kZXYvbnZpZGlhLW1vZGVzZXQSEy9kZXYvbnZpZGlhLW1vZGVzZXQaAnJ3GiAKDC9kZXYvbnZpZGlhMBIML2Rldi9udmlkaWEwGgJydw==" + } + ], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + }, + "Checksum": 2262205670 +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(testData) + assert.NoError(t, err) + _ = tmpFile.Close() + + detector := &DevicePluginDetector{ + checkpointPath: tmpFile.Name(), + } + + checkpoint, err := detector.readCheckpointFile() + assert.NoError(t, err) + assert.NotNil(t, checkpoint) + assert.Len(t, checkpoint.Data.PodDeviceEntries, 1) + assert.Equal(t, "a7461dc1-023a-4bd5-a403-c738bb1d7db4", checkpoint.Data.PodDeviceEntries[0].PodUID) + assert.Equal(t, "nvidia.com/gpu", checkpoint.Data.PodDeviceEntries[0].ResourceName) + assert.Contains(t, checkpoint.Data.RegisteredDevices, "nvidia.com/gpu") +} + +func TestExtractDeviceIDs(t *testing.T) { + checkpoint := &KubeletCheckpoint{ + Data: CheckpointData{ + PodDeviceEntries: []PodDeviceEntry{ + { + ResourceName: "nvidia.com/gpu", + DeviceIDs: map[string][]string{ + "-1": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + }, + RegisteredDevices: map[string][]string{ + "nvidia.com/gpu": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + } + + detector := &DevicePluginDetector{ + vendorDetectors: map[string]VendorDetector{ + "nvidia.com/gpu": NewNvidiaDevicePluginDetector(), + }, + } + + allocated, registered := detector.extractDeviceIDs(checkpoint) + assert.Contains(t, allocated, "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a") + assert.Contains(t, registered, "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a") +} + +func TestNvidiaDevicePluginDetector(t *testing.T) { + detector := NewNvidiaDevicePluginDetector() + assert.Equal(t, []string{"nvidia.com/gpu", "nvidia.com/mig"}, 
detector.GetResourceNamePrefixes()) + system, realDeviceID := detector.GetUsedBySystemAndRealDeviceID("GPU-8511dc03-7592-b8b7-1a92-582d40da52fb", "nvidia.com/gpu") + assert.Equal(t, string(UsedByNvidiaDevicePlugin), system) + assert.Equal(t, "GPU-8511dc03-7592-b8b7-1a92-582d40da52fb", realDeviceID) + // External device plugin detection only works for nvidia.com/gpu resources with device IDs longer than 40 characters + system, realDeviceID = detector.GetUsedBySystemAndRealDeviceID("GPU-422d6152-4d4b-5b0e-9d3a-b3b44e2742ea-1", "nvidia.com/gpu") + assert.Equal(t, string(UsedBy3rdPartyDevicePlugin), system) + assert.Equal(t, "GPU-422d6152-4d4b-5b0e-9d3a-b3b44e2742ea", realDeviceID) + // nvidia.com/mig always returns nvidia-device-plugin + system, realDeviceID = detector.GetUsedBySystemAndRealDeviceID("MIG-422d6152-4d4b-5b0e-9d3a-b3b44e2742ea", "nvidia.com/mig-1g.5gb") + assert.Equal(t, string(UsedByNvidiaDevicePlugin), system) + assert.Equal(t, "MIG-422d6152-4d4b-5b0e-9d3a-b3b44e2742ea", realDeviceID) +} + +func TestProcessDeviceState_DeviceAdded(t *testing.T) { + mockAPI := new(MockAPIServer) + + checkpointData := `{ + "Data": { + "PodDeviceEntries": [ + { + "PodUID": "a7461dc1-023a-4bd5-a403-c738bb1d7db4", + "ContainerName": "web", + "ResourceName": "nvidia.com/gpu", + "DeviceIDs": { + "-1": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } + ], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(checkpointData) + assert.NoError(t, err) + _ = tmpFile.Close() + + // Mock GPU resource + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a", + }, + Status: tfv1.GPUStatus{ + UsedBy: tfv1.UsedByTensorFusion, + }, + } + + mockAPI.On("GetGPU", "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a").Return(gpu, nil) + mockAPI.On("UpdateGPUStatus", mock.MatchedBy(func(gpu *tfv1.GPU) bool { + return gpu.Status.UsedBy == UsedByNvidiaDevicePlugin + })).Return(nil) + + detector := &DevicePluginDetector{ + ctx: context.Background(), + checkpointPath: tmpFile.Name(), + apiClient: mockAPI, + vendorDetectors: make(map[string]VendorDetector), + previousDeviceIDs: make(map[string]string), + } + // Register vendor detectors properly - use the same pattern as registerVendorDetectors + nvdpDetector := NewNvidiaDevicePluginDetector() + for _, prefix := range nvdpDetector.GetResourceNamePrefixes() { + detector.vendorDetectors[prefix] = nvdpDetector + } + + // Verify checkpoint can be read and devices extracted + checkpoint, err := detector.readCheckpointFile() + assert.NoError(t, err) + allocated, _ := detector.extractDeviceIDs(checkpoint) + assert.Contains(t, allocated, "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a", "Device should be in allocated map") + + err = detector.processDeviceState(false) + assert.NoError(t, err) + mockAPI.AssertExpectations(t) +} + +func TestProcessDeviceState_DeviceRemoved(t *testing.T) { + mockAPI := new(MockAPIServer) + + checkpointData := `{ + "Data": { + "PodDeviceEntries": [], + "RegisteredDevices": { + "nvidia.com/gpu": [ + "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a" + ] + } + } +}` + + tmpFile, err := os.CreateTemp("", "checkpoint-*.json") + assert.NoError(t, err) + defer func() { + _ = os.Remove(tmpFile.Name()) + }() + + _, err = tmpFile.WriteString(checkpointData) + assert.NoError(t, err) + _ = tmpFile.Close() + + 
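// The checkpoint above keeps the device registered but drops its pod allocation, simulating release by the external device plugin +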
// Mock GPU resource that was previously allocated + gpu := &tfv1.GPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a", + }, + Status: tfv1.GPUStatus{ + UsedBy: UsedByNvidiaDevicePlugin, + }, + } + + mockAPI.On("GetGPU", "gpu-7d8429d5-531d-d6a6-6510-3b662081a75a").Return(gpu, nil) + mockAPI.On("UpdateGPUStatus", mock.MatchedBy(func(gpu *tfv1.GPU) bool { + return gpu.Status.UsedBy == tfv1.UsedByTensorFusion + })).Return(nil) + + detector := &DevicePluginDetector{ + ctx: context.Background(), + checkpointPath: tmpFile.Name(), + apiClient: mockAPI, + vendorDetectors: make(map[string]VendorDetector), + previousDeviceIDs: map[string]string{"gpu-7d8429d5-531d-d6a6-6510-3b662081a75a": "nvidia.com/gpu"}, + } + // Register vendor detectors properly - use the same pattern as registerVendorDetectors + nvdpDetector := NewNvidiaDevicePluginDetector() + for _, prefix := range nvdpDetector.GetResourceNamePrefixes() { + detector.vendorDetectors[prefix] = nvdpDetector + } + + err = detector.processDeviceState(false) + assert.NoError(t, err) + mockAPI.AssertExpectations(t) +} + +func TestFindEntryForDevice(t *testing.T) { + checkpoint := &KubeletCheckpoint{ + Data: CheckpointData{ + PodDeviceEntries: []PodDeviceEntry{ + { + ResourceName: "nvidia.com/gpu", + DeviceIDs: map[string][]string{ + "-1": {"GPU-7d8429d5-531d-d6a6-6510-3b662081a75a"}, + }, + }, + }, + }, + } + + detector := &DevicePluginDetector{} + entry := detector.findEntryForDevice(checkpoint, "GPU-7d8429d5-531d-d6a6-6510-3b662081a75a") + assert.Equal(t, "nvidia.com/gpu", entry.ResourceName) +} diff --git a/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go b/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go new file mode 100644 index 00000000..cd0fe841 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/external_dp/kubelet_checkpoint.go @@ -0,0 +1,564 @@ +package external_dp + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "math/rand" + "net" + "os" + "path/filepath" + "slices" + "strings" + "sync" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/fsnotify/fsnotify" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/rest" + "k8s.io/klog/v2" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + // Default kubelet checkpoint file path + defaultKubeletCheckpointPath = "/var/lib/kubelet/device-plugins/kubelet_internal_checkpoint" + + // Default kubelet pod-resources socket path + defaultKubeletPodResourcesSocket = "/var/lib/kubelet/pod-resources/kubelet.sock" + + // Polling intervals + defaultPollInterval = 30 * time.Second + defaultPatchAllInterval = 120 * time.Second + patchAllIntervalJitter = 0.15 // ±15% jitter +) + +var ( + scheme = runtime.NewScheme() +) + +func init() { + utilruntime.Must(tfv1.AddToScheme(scheme)) +} + +// KubeletCheckpoint represents the structure of kubelet device checkpoint file +type KubeletCheckpoint struct { + Data CheckpointData `json:"Data"` +} + +type CheckpointData struct { + PodDeviceEntries []PodDeviceEntry `json:"PodDeviceEntries,omitempty"` + RegisteredDevices map[string][]string `json:"RegisteredDevices,omitempty"` +} + +type PodDeviceEntry struct { + PodUID string `json:"PodUID"` + ContainerName string `json:"ContainerName"` + ResourceName string `json:"ResourceName"` + DeviceIDs 
map[string][]string `json:"DeviceIDs"` +} + +// VendorDetector interface for vendor-specific device plugin detectors +type VendorDetector interface { + // GetResourceName returns the resource name this detector handles (e.g., "nvidia.com/gpu") + GetResourceNamePrefixes() []string + // GetUsedBySystem returns the UsedBy system name for this vendor + GetUsedBySystemAndRealDeviceID(deviceID, resourceName string) (system string, realDeviceID string) +} + +// APIClientInterface defines the interface for GPU API operations +type APIClientInterface interface { + GetGPU(uuid string) (*tfv1.GPU, error) + UpdateGPUStatus(gpu *tfv1.GPU) error +} + +// DevicePluginDetector watches kubelet device checkpoint and manages GPU resource patching +type DevicePluginDetector struct { + ctx context.Context + checkpointPath string + apiClient APIClientInterface + vendorDetectors map[string]VendorDetector // key: resource name + previousDeviceIDs map[string]string + mu sync.RWMutex + watcher *fsnotify.Watcher + stopCh chan struct{} + + k8sClient client.Client +} + +// NewDevicePluginDetector creates a new device plugin detector +func NewDevicePluginDetector( + ctx context.Context, + checkpointPath string, + apiClient APIClientInterface, + restConfig *rest.Config, +) (*DevicePluginDetector, error) { + k8sClient, err := client.New(restConfig, client.Options{ + Scheme: scheme, + }) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + } + if checkpointPath == "" { + checkpointPath = defaultKubeletCheckpointPath + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + klog.Errorf("failed to create filesystem watcher for kubelet CDI checkpoint file: %v", err) + } + + detector := &DevicePluginDetector{ + ctx: ctx, + checkpointPath: checkpointPath, + apiClient: apiClient, + vendorDetectors: make(map[string]VendorDetector), + previousDeviceIDs: make(map[string]string), + watcher: watcher, + k8sClient: k8sClient, + stopCh: make(chan struct{}), + } + + // Register vendor-specific detectors + detector.registerVendorDetectors() + + return detector, nil +} + +// registerVendorDetectors registers all vendor-specific detectors +func (d *DevicePluginDetector) registerVendorDetectors() { + // Register NVIDIA detector + nvdpDetector := NewNvidiaDevicePluginDetector() + resourceNamePrefixes := nvdpDetector.GetResourceNamePrefixes() + for _, resourceNamePrefix := range resourceNamePrefixes { + d.vendorDetectors[resourceNamePrefix] = nvdpDetector + } + + // Add more vendor detectors here as needed + // amdDetector := NewAMDDevicePluginDetector() + // d.vendorDetectors[amdDetector.GetResourceName()] = amdDetector +} + +// Start starts watching the checkpoint file and processing device allocations +func (d *DevicePluginDetector) Start() error { + klog.Info("Starting device plugin detector", "checkpointPath", d.checkpointPath) + + // Setup filesystem watcher + if err := d.setupFilesystemWatcher(); err != nil { + klog.Warningf("Failed to setup filesystem watcher, falling back to polling only: %v", err) + } + + // Start processing loop + go d.run() + + return nil +} + +// Stop stops the detector +func (d *DevicePluginDetector) Stop() { + close(d.stopCh) + if d.watcher != nil { + _ = d.watcher.Close() + } +} + +// setupFilesystemWatcher sets up filesystem watcher for the checkpoint file +func (d *DevicePluginDetector) setupFilesystemWatcher() error { + // Watch the directory containing the checkpoint file + dir := filepath.Dir(d.checkpointPath) + if err := d.watcher.Add(dir); err != nil { + 
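// Start() treats a watcher setup failure as non-fatal and falls back to periodic polling +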
return fmt.Errorf("failed to watch directory %s: %w", dir, err) + } + + // Also watch the file itself if it exists + if _, err := os.Stat(d.checkpointPath); err == nil { + if err := d.watcher.Add(d.checkpointPath); err != nil { + klog.Warningf("Failed to watch checkpoint file directly: %v", err) + } + } + + klog.Infof("Filesystem watcher enabled for checkpoint file: %s", d.checkpointPath) + return nil +} + +// run is the main processing loop +func (d *DevicePluginDetector) run() { + // Create tickers for periodic polling + pollTicker := time.NewTicker(defaultPollInterval) + defer pollTicker.Stop() + + patchAllInterval := d.durationWithJitter(defaultPatchAllInterval, patchAllIntervalJitter) + patchAllTicker := time.NewTicker(patchAllInterval) + defer patchAllTicker.Stop() + + // Process initial state + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process initial device state: %v", err) + } + + for { + select { + case <-d.ctx.Done(): + klog.Info("Device plugin detector shutdown requested") + return + + case <-d.stopCh: + klog.Info("Device plugin detector stopped") + return + + case event, ok := <-d.watcher.Events: + if !ok { + klog.Warning("Filesystem watcher channel closed, restarting watcher") + // Try to restart watcher + if err := d.setupFilesystemWatcher(); err != nil { + klog.Errorf("Failed to restart filesystem watcher: %v", err) + } + continue + } + + // Process checkpoint file changes + if event.Op&(fsnotify.Write|fsnotify.Create) != 0 && + (event.Name == d.checkpointPath || strings.HasSuffix(event.Name, filepath.Base(d.checkpointPath))) { + klog.V(4).Infof("Checkpoint file changed: %s", event.Name) + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process device state after filesystem event: %v", err) + } + } + + case err := <-d.watcher.Errors: + if err != nil { + klog.Errorf("Filesystem watcher error: %v", err) + } + + case <-pollTicker.C: + // Periodic polling fallback + klog.V(4).Info("Periodic polling check") + if err := d.processDeviceState(false); err != nil { + klog.Errorf("Failed to process device state during periodic check: %v", err) + } + + case <-patchAllTicker.C: + // Periodic full patch check to handle deleted pods + klog.V(4).Info("Checking all devices for deleted pods") + if err := d.processDeviceState(true); err != nil { + klog.Errorf("Failed to process device state during patch all check: %v", err) + } + // Reset ticker with new jitter + patchAllTicker.Reset(d.durationWithJitter(defaultPatchAllInterval, patchAllIntervalJitter)) + } + } +} + +// processDeviceState reads and processes the device checkpoint state +func (d *DevicePluginDetector) processDeviceState(patchAllDevices bool) error { + d.mu.Lock() + defer d.mu.Unlock() + // Read checkpoint file + checkpoint, err := d.readCheckpointFile() + if err != nil { + return fmt.Errorf("failed to read checkpoint file: %w", err) + } + + // Extract registered device IDs (for comparison) + allocated, registeredDeviceIDs := d.extractDeviceIDs(checkpoint) + if d.grpcEndpointAvailable() { + // Use kubelet pod-resources gRPC API as SSoT if available, otherwise fallback to checkpoint + allocatedDevices, err := d.getAllocatedDevices() + if err != nil { + klog.Errorf("Failed to get allocated devices from gRPC: %v", err) + } else { + allocated = allocatedDevices + } + } + + // Determine added and removed devices + previousDeviceIDs := make(map[string]string, len(d.previousDeviceIDs)) + maps.Copy(previousDeviceIDs, d.previousDeviceIDs) + + var addedDevices, 
removedDevices map[string]string + + if patchAllDevices { + // Patch all devices: treat all allocated as added, and all registered but not allocated as removed + addedDevices = allocated + removedDevices = make(map[string]string) + for deviceID := range registeredDeviceIDs { + if resName, exists := allocated[deviceID]; !exists { + removedDevices[deviceID] = resName + } + } + } else { + // Only process changes + addedDevices = make(map[string]string) + removedDevices = make(map[string]string) + + for deviceID, resName := range allocated { + if _, exists := previousDeviceIDs[deviceID]; !exists { + addedDevices[deviceID] = resName + } + } + + for deviceID, resName := range previousDeviceIDs { + if _, exists := allocated[deviceID]; !exists { + removedDevices[deviceID] = resName + } + } + } + + // Process added devices using vendor-specific detectors + hasError := false + for deviceID, resName := range addedDevices { + for _, detector := range d.vendorDetectors { + resourceNamePrefixes := detector.GetResourceNamePrefixes() + if slices.Contains(resourceNamePrefixes, resName) { + usedBySystem, realDeviceID := detector.GetUsedBySystemAndRealDeviceID(deviceID, resName) + klog.V(4).Infof("Device added: %s, resource: %s, patching with usedBy: %s, realDeviceID: %s", deviceID, resName, usedBySystem, realDeviceID) + if err := d.patchGPUResource(realDeviceID, usedBySystem); err != nil { + klog.Errorf("Failed to patch GPU resource for added device %s: %v", deviceID, err) + hasError = true + } + } + } + } + + // Process removed devices + for deviceID, resName := range removedDevices { + for _, detector := range d.vendorDetectors { + resourceNamePrefixes := detector.GetResourceNamePrefixes() + if slices.Contains(resourceNamePrefixes, resName) { + klog.V(4).Infof("Device plugin allocated container removed: %s, resource: %s, patching usedBy field to tensor fusion", deviceID, resName) + if err := d.patchGPUResource(deviceID, string(tfv1.UsedByTensorFusion)); err != nil { + klog.Errorf("Failed to patch GPU resource usedBy field to tensor fusion for removed device %s: %v", deviceID, err) + hasError = true + } + } + } + } + + // Update previous state only if no errors occurred + if !hasError { + d.previousDeviceIDs = allocated + } + return nil +} + +// patchGPUResource patches a GPU resource with the specified usedBy value +func (d *DevicePluginDetector) patchGPUResource(deviceID, usedBySystem string) error { + const maxRetries = 3 + + for i := range maxRetries { + // Get current GPU resource + gpu, err := d.apiClient.GetGPU(deviceID) + if err != nil { + if i < maxRetries-1 { + backoff := time.Duration(200*(1< 7 && addr[:7] == "unix://" { + socketPath = addr[7:] + } + return net.DialTimeout("unix", socketPath, timeout) + }), + ) + return conn, err +} diff --git a/internal/hypervisor/backend/kubernetes/external_dp/nvdp_detector.go b/internal/hypervisor/backend/kubernetes/external_dp/nvdp_detector.go new file mode 100644 index 00000000..bd81164b --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/external_dp/nvdp_detector.go @@ -0,0 +1,42 @@ +package external_dp + +import ( + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" +) + +const ( + resourceNvidiaGPU = "nvidia.com/gpu" + resourceNvidiaMIG = "nvidia.com/mig" + realDeviceIDLength = 40 +) + +var UsedByNvidiaDevicePlugin = tfv1.UsedBySystem("nvidia-device-plugin") +var UsedBy3rdPartyDevicePlugin = tfv1.UsedBySystem("3rd-party-device-plugin") + +// NvidiaDevicePluginDetector handles NVIDIA-specific device plugin detection +type NvidiaDevicePluginDetector 
struct{} + +// NewNvidiaDevicePluginDetector creates a new NVIDIA device plugin detector +func NewNvidiaDevicePluginDetector() *NvidiaDevicePluginDetector { + return &NvidiaDevicePluginDetector{} +} + +// GetResourceName returns the resource name this detector handles +func (n *NvidiaDevicePluginDetector) GetResourceNamePrefixes() []string { + return []string{resourceNvidiaGPU, resourceNvidiaMIG} +} + +// GetUsedBySystem returns the UsedBy system name for NVIDIA +func (n *NvidiaDevicePluginDetector) GetUsedBySystemAndRealDeviceID(deviceID, resourceName string) (system string, realDeviceID string) { + if resourceName == resourceNvidiaGPU { + // Some external device plugin's device ID is GPU-(UUID)-0, 1, 2, 3 (e.g. HAMI) + // Need to recover to real device ID + if len(deviceID) > realDeviceIDLength { + return string(UsedBy3rdPartyDevicePlugin), deviceID[:realDeviceIDLength] + } else { + return string(UsedByNvidiaDevicePlugin), deviceID + } + } else { + return string(UsedByNvidiaDevicePlugin), deviceID + } +} diff --git a/internal/hypervisor/backend/kubernetes/kubernetes_backend.go b/internal/hypervisor/backend/kubernetes/kubernetes_backend.go new file mode 100644 index 00000000..6c4cff3b --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/kubernetes_backend.go @@ -0,0 +1,307 @@ +package kubernetes + +import ( + "context" + "fmt" + "os" + "sync" + "time" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/kubernetes/external_dp" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/google/uuid" + "github.com/samber/lo" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/rest" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" +) + +type KubeletBackend struct { + ctx context.Context + + deviceController framework.DeviceController + workerController framework.WorkerController + + apiClient *APIClient + podCacher *PodCacheManager + devicePlugins []*DevicePlugin + deviceDetector *external_dp.DevicePluginDetector + + nodeName string + + workers map[string]*api.WorkerInfo + workersMu sync.RWMutex + + subscribers map[string]struct{} + workerHandler *framework.WorkerChangeHandler +} + +var _ framework.Backend = &KubeletBackend{} + +func NewKubeletBackend(ctx context.Context, deviceController framework.DeviceController, workerController framework.WorkerController, restConfig *rest.Config) (*KubeletBackend, error) { + // Get node name from environment or config + nodeName := os.Getenv(constants.HypervisorGPUNodeNameEnv) + if nodeName == "" { + return nil, fmt.Errorf("node name env var 'GPU_NODE_NAME' for this hypervisor not set") + } + + // Create kubelet client + podCacher, err := NewPodCacheManager(ctx, restConfig, nodeName) + if err != nil { + return nil, err + } + + // Create API server for device detector + apiClient, err := NewAPIClientFromConfig(ctx, restConfig) + if err != nil { + return nil, err + } + + // Create device plugin detector + var deviceDetector *external_dp.DevicePluginDetector + if os.Getenv(constants.HypervisorDetectUsedGPUEnv) == constants.TrueStringValue { + checkpointPath := os.Getenv(constants.HypervisorKubeletCheckpointPathEnv) + // Create adapter for kubelet client to match interface + deviceDetector, err = external_dp.NewDevicePluginDetector(ctx, 
checkpointPath, apiClient, restConfig) + if err != nil { + return nil, err + } + } + + return &KubeletBackend{ + ctx: ctx, + deviceController: deviceController, + workerController: workerController, + podCacher: podCacher, + deviceDetector: deviceDetector, + apiClient: apiClient, + nodeName: nodeName, + workers: make(map[string]*api.WorkerInfo), + subscribers: make(map[string]struct{}), + }, nil +} + +func (b *KubeletBackend) Start() error { + if err := b.podCacher.Start(); err != nil { + return err + } + klog.Info("Kubelet client started, watching pods") + + // Create and start device plugin + b.devicePlugins = NewDevicePlugins(b.ctx, b.deviceController, b.workerController, b.podCacher) + for _, devicePlugin := range b.devicePlugins { + if err := devicePlugin.Start(); err != nil { + return err + } + } + klog.Infof("Device plugins started and registered with kubelet") + + // Start device plugin detector to watch external device plugins + if b.deviceDetector != nil { + if err := b.deviceDetector.Start(); err != nil { + klog.Warningf("Failed to start device plugin detector: %v", err) + } else { + klog.Info("Device plugin detector started") + } + } + return nil +} + +func (b *KubeletBackend) Stop() error { + if b.devicePlugins != nil { + for i, devicePlugin := range b.devicePlugins { + if err := devicePlugin.Stop(); err != nil { + klog.Errorf("Failed to stop device plugin %d: %v", i, err) + } + } + } + + if b.deviceDetector != nil { + b.deviceDetector.Stop() + } + + if b.podCacher != nil { + for subscriberID := range b.subscribers { + b.podCacher.UnregisterWorkerInfoSubscriber(subscriberID) + } + b.subscribers = make(map[string]struct{}) + b.podCacher.Stop() + } + + return nil +} + +// RegisterWorkerUpdateHandler registers a handler for worker updates +func (b *KubeletBackend) RegisterWorkerUpdateHandler(handler framework.WorkerChangeHandler) error { + b.workerHandler = &handler + + // Create a channel bridge to convert channel messages to handler calls + workerCh := make(chan *api.WorkerInfo, 16) + subscriberID := uuid.NewString() + b.podCacher.RegisterWorkerInfoSubscriber(subscriberID, workerCh) + b.subscribers[subscriberID] = struct{}{} + + // Start bridge goroutine + go func() { + defer func() { + b.podCacher.UnregisterWorkerInfoSubscriber(subscriberID) + delete(b.subscribers, subscriberID) + }() + + for { + select { + case <-b.ctx.Done(): + return + case worker, ok := <-workerCh: + if !ok { + return + } + if worker == nil { + continue + } + + // Determine if this is add, update, or remove + b.workersMu.Lock() + oldWorker, exists := b.workers[worker.WorkerUID] + + if worker.DeletedAt > 0 { + // Worker was deleted + if exists && handler.OnRemove != nil { + handler.OnRemove(worker) + } + delete(b.workers, worker.WorkerUID) + } else if !exists { + // New worker + b.workers[worker.WorkerUID] = worker + if handler.OnAdd != nil { + handler.OnAdd(worker) + } + } else { + // Updated worker + b.workers[worker.WorkerUID] = worker + if handler.OnUpdate != nil { + handler.OnUpdate(oldWorker, worker) + } + } + b.workersMu.Unlock() + } + } + }() + return nil +} + +func (b *KubeletBackend) StartWorker(worker *api.WorkerInfo) error { + klog.Warningf("StartWorker not implemented, should be managed by operator") + return nil +} + +func (b *KubeletBackend) StopWorker(workerUID string) error { + klog.Warningf("StopWorker not implemented, should be managed by operator") + return nil +} + +func (b *KubeletBackend) GetProcessMappingInfo(workerUID string, hostPID uint32) (*framework.ProcessMappingInfo, error) { 
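+ // Resolve pod identity from /proc/<hostPID>/environ; the required env vars are injected by the webhook (see ns_mapper.go below)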
+	return GetWorkerInfoFromHostPID(hostPID, workerUID)
+}
+
+func (b *KubeletBackend) GetDeviceChangeHandler() framework.DeviceChangeHandler {
+	return framework.DeviceChangeHandler{
+		OnAdd: func(device *api.DeviceInfo) {
+			if err := b.apiClient.CreateOrUpdateGPU(b.nodeName, device.UUID,
+				func(gpuNode *tfv1.GPUNode, gpu *tfv1.GPU) error {
+					return b.mutateGPUResourceState(device, gpuNode, gpu)
+				}); err != nil {
+				klog.Errorf("Failed to create or update GPU: %v", err)
+			} else {
+				klog.Infof("Device added: %s", device.UUID)
+			}
+		},
+		OnRemove: func(device *api.DeviceInfo) {
+			if err := b.apiClient.DeleteGPU(device.UUID); err != nil {
+				klog.Errorf("Failed to delete GPU: %v", err)
+			} else {
+				klog.Infof("Device removed: %s", device.UUID)
+			}
+		},
+		OnUpdate: func(oldDevice, newDevice *api.DeviceInfo) {
+			if err := b.apiClient.CreateOrUpdateGPU(b.nodeName, newDevice.UUID,
+				func(gpuNode *tfv1.GPUNode, gpu *tfv1.GPU) error {
+					return b.mutateGPUResourceState(newDevice, gpuNode, gpu)
+				}); err != nil {
+				klog.Errorf("Failed to update GPU: %v", err)
+			} else {
+				klog.Infof("Device updated: %s", newDevice.UUID)
+			}
+		},
+		OnDiscoveryComplete: func(nodeInfo *api.NodeInfo) {
+			if err := b.apiClient.UpdateGPUNodeStatus(nodeInfo); err != nil {
+				klog.Errorf("Failed to update GPUNode status: %v", err)
+			}
+		},
+	}
+}
+
+func (b *KubeletBackend) ListWorkers() []*api.WorkerInfo {
+	b.workersMu.RLock()
+	defer b.workersMu.RUnlock()
+	return lo.Values(b.workers)
+}
+
+func (b *KubeletBackend) mutateGPUResourceState(device *api.DeviceInfo, gpuNode *tfv1.GPUNode, gpu *tfv1.GPU) error {
+	// Set metadata fields
+	gpu.Labels = map[string]string{
+		constants.LabelKeyOwner: gpuNode.Name,
+		constants.GpuPoolKey:    gpuNode.OwnerReferences[0].Name,
+	}
+	gpu.Annotations = map[string]string{
+		constants.LastSyncTimeAnnotationKey: time.Now().Format(time.RFC3339),
+	}
+
+	if !metav1.IsControlledBy(gpu, gpuNode) {
+		// Create a new controller ref.
+ gvk, err := apiutil.GVKForObject(gpuNode, scheme) + if err != nil { + return err + } + ref := metav1.OwnerReference{ + APIVersion: gvk.GroupVersion().String(), + Kind: gvk.Kind, + Name: gpuNode.GetName(), + UID: gpuNode.GetUID(), + BlockOwnerDeletion: ptr.To(true), + Controller: ptr.To(true), + } + gpu.OwnerReferences = []metav1.OwnerReference{ref} + } + + // Set status fields + gpu.Status.Capacity = &tfv1.Resource{ + Vram: resource.MustParse(fmt.Sprintf("%dMi", device.TotalMemoryBytes/1024/1024)), + Tflops: resource.MustParse(fmt.Sprintf("%f", device.MaxTflops)), + } + gpu.Status.UUID = device.UUID + gpu.Status.GPUModel = device.Model + gpu.Status.Index = ptr.To(device.Index) + gpu.Status.Vendor = device.Vendor + gpu.Status.NUMANode = ptr.To(device.NUMANode) + gpu.Status.NodeSelector = map[string]string{ + constants.KubernetesHostNameLabel: b.nodeName, + } + if gpu.Status.Available == nil { + gpu.Status.Available = gpu.Status.Capacity.DeepCopy() + } + if gpu.Status.UsedBy == "" { + gpu.Status.UsedBy = tfv1.UsedByTensorFusion + } + if gpu.Status.Phase == "" { + gpu.Status.Phase = tfv1.TensorFusionGPUPhasePending + } + return nil +} diff --git a/internal/hypervisor/backend/kubernetes/ns_mapper.go b/internal/hypervisor/backend/kubernetes/ns_mapper.go new file mode 100644 index 00000000..3d128af5 --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/ns_mapper.go @@ -0,0 +1,112 @@ +package kubernetes + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" +) + +// GetWorkerInfoFromHostPID extracts worker information from a process's environment +// by reading /proc/{hostPID}/environ and /proc/{hostPID}/status +// workerUID (podUID) is provided as input parameter, not extracted from environment +func GetWorkerInfoFromHostPID(hostPID uint32, workerUID string) (*framework.ProcessMappingInfo, error) { + procDir := fmt.Sprintf("/proc/%d", hostPID) + + // Check if process exists + if _, err := os.Stat(procDir); os.IsNotExist(err) { + return nil, fmt.Errorf("process %d does not exist", hostPID) + } + + // Read environment variables from /proc/{pid}/environ + envPath := filepath.Join(procDir, "environ") + envData, err := os.ReadFile(envPath) + if err != nil { + return nil, fmt.Errorf("failed to read environment from %s: %w", envPath, err) + } + + // Parse environment variables (null-separated) + envMap := make(map[string]string) + envPairs := strings.Split(string(envData), "\x00") + for _, pair := range envPairs { + if pair == "" { + continue + } + parts := strings.SplitN(pair, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + // Extract Kubernetes pod information from environment (injected by webhook) + podName := envMap[constants.PodNameEnv] + namespace := envMap[constants.PodNamespaceEnv] + containerName := envMap[constants.ContainerNameEnv] + + // Read container PID (namespaced PID) from /proc/{pid}/status + containerPID, err := getContainerPIDFromStatus(procDir) + if err != nil { + // If we can't get container PID, use host PID as fallback + containerPID = hostPID + } + + // Validate required fields (must exist as they are injected by webhook) + if podName == "" { + return nil, fmt.Errorf("POD_NAME not found in environment for process %d", hostPID) + } + if namespace == "" { + return nil, fmt.Errorf("POD_NAMESPACE not found in environment for process %d", hostPID) + } + if containerName == "" { + return nil, 
fmt.Errorf("CONTAINER_NAME not found in environment for process %d", hostPID) + } + + return &framework.ProcessMappingInfo{ + GuestID: fmt.Sprintf("%s_%s_%s", namespace, podName, containerName), + HostPID: hostPID, + GuestPID: containerPID, + }, nil +} + +// getContainerPIDFromStatus reads the container PID (NSpid) from /proc/{pid}/status +func getContainerPIDFromStatus(procDir string) (uint32, error) { + statusPath := filepath.Join(procDir, "status") + file, err := os.Open(statusPath) + if err != nil { + return 0, fmt.Errorf("failed to open status file: %w", err) + } + defer func() { + _ = file.Close() + }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "NSpid:") { + // NSpid format: "NSpid: 1234 5678" (host PID, then container PID) + // or "NSpid: 1234" (if same namespace) + fields := strings.Fields(line) + if len(fields) >= 2 { + // The last field is typically the container PID + // If there are multiple PIDs, the last one is in the innermost namespace + pidStr := fields[len(fields)-1] + pid, err := strconv.ParseUint(pidStr, 10, 32) + if err != nil { + return 0, fmt.Errorf("failed to parse container PID: %w", err) + } + return uint32(pid), nil + } + } + } + + if err := scanner.Err(); err != nil { + return 0, fmt.Errorf("failed to read status file: %w", err) + } + + return 0, fmt.Errorf("NSpid not found in status file") +} diff --git a/internal/hypervisor/backend/kubernetes/pod_cache.go b/internal/hypervisor/backend/kubernetes/pod_cache.go new file mode 100644 index 00000000..54dec7ad --- /dev/null +++ b/internal/hypervisor/backend/kubernetes/pod_cache.go @@ -0,0 +1,429 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package kubernetes + +import ( + "context" + "fmt" + "strconv" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/utils" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" +) + +// workerInfoSubscriber represents a subscriber waiting for worker info for a specific pod index +type workerInfoSubscriber struct { + ch chan *api.WorkerInfo +} + +const subscriberTimeout = 10 * time.Minute + +// PodCacheManager manages pod watching and worker information extraction +type PodCacheManager struct { + ctx context.Context + clientset *kubernetes.Clientset + restConfig *rest.Config + nodeName string + + mu sync.RWMutex + cachedPod map[string]*corev1.Pod // key: pod UID + indexToWorkerInfo map[int]*api.WorkerInfo // key: pod index annotation + + stopCh chan struct{} + workerChangedCh chan struct{} + + // Pub/Sub mechanism for waiting on worker info by index + subscribersMu sync.RWMutex + indexSubscribers map[int]map[*workerInfoSubscriber]struct{} // key: pod index + + podSubscribersMu sync.RWMutex + podSubscribers map[string]chan<- *api.WorkerInfo +} + +// NewPodCacheManager creates a new pod cache manager +func NewPodCacheManager(ctx context.Context, restConfig *rest.Config, nodeName string) (*PodCacheManager, error) { + clientset, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes clientset: %w", err) + } + + kc := &PodCacheManager{ + ctx: ctx, + clientset: clientset, + restConfig: restConfig, + nodeName: nodeName, + cachedPod: make(map[string]*corev1.Pod, 32), + indexToWorkerInfo: make(map[int]*api.WorkerInfo, 32), + stopCh: make(chan struct{}), + workerChangedCh: make(chan struct{}, 1), + indexSubscribers: make(map[int]map[*workerInfoSubscriber]struct{}), + podSubscribers: make(map[string]chan<- *api.WorkerInfo), + } + + // Start the Pub/Sub event bus goroutine + go kc.runWorkerChangeEventBus() + + return kc, nil +} + +// Start starts watching pods on this node +func (kc *PodCacheManager) Start() error { + // Create a field selector to watch only pods on this node + fieldSelector := fields.OneTermEqualSelector("spec.nodeName", kc.nodeName).String() + + // Create a label selector for pods with tensor-fusion.ai/enabled=true + labelSelector := labels.Set{ + constants.TensorFusionEnabledLabelKey: constants.TrueStringValue, + }.AsSelector().String() + + // Create list watcher + lw := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + options.FieldSelector = fieldSelector + options.LabelSelector = labelSelector + return kc.clientset.CoreV1().Pods(metav1.NamespaceAll).List(kc.ctx, options) + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + options.FieldSelector = fieldSelector + options.LabelSelector = labelSelector + return kc.clientset.CoreV1().Pods(metav1.NamespaceAll).Watch(kc.ctx, options) + }, + } + + // Create informer + _, controller := cache.NewInformerWithOptions(cache.InformerOptions{ + ListerWatcher: lw, + ObjectType: &corev1.Pod{}, + ResyncPeriod: 0, + Handler: cache.ResourceEventHandlerFuncs{ + AddFunc: kc.onPodAdd, + UpdateFunc: kc.onPodUpdate, + 
DeleteFunc: kc.onPodDelete, + }, + }) + + // Start the informer + go controller.Run(kc.stopCh) + + klog.Infof("Started watching pods on node %s with label %s=%s", kc.nodeName, constants.TensorFusionEnabledLabelKey, constants.TrueStringValue) + return nil +} + +// Stop stops the pod cache manager +func (kc *PodCacheManager) Stop() { + klog.Info("Stopping pod cache manager") + close(kc.stopCh) +} + +// onPodAdd handles pod addition events +func (kc *PodCacheManager) onPodAdd(obj any) { + pod := obj.(*corev1.Pod) + kc.mu.Lock() + defer kc.mu.Unlock() + kc.cachedPod[string(pod.UID)] = pod + + workerInfo, index, err := kc.extractWorkerInfo(pod) + if err != nil { + klog.Error(err, "Failed to extract worker info for pod", "pod", pod.Name, "namespace", pod.Namespace) + return + } + if index != "" { + podIndex, err := strconv.Atoi(index) + if err != nil { + klog.Error(err, "Failed to convert node index to int", "node index", index) + return + } + // Make sure indexToWorker only contains device allocating pods (Pod is pending and index was assigned) + if workerInfo.Status == api.WorkerStatusDeviceAllocating { + kc.indexToWorkerInfo[podIndex] = workerInfo + } + } + kc.notifyWorkerChanged(workerInfo) + klog.Infof("Pod %s/%s added to pending, state: %s node index: %s", pod.Namespace, pod.Name, workerInfo.Status, index) +} + +// onPodUpdate handles pod update events +func (kc *PodCacheManager) onPodUpdate(oldObj, newObj any) { + newPod := newObj.(*corev1.Pod) + kc.onPodAdd(newPod) +} + +// onPodDelete handles pod deletion events +func (kc *PodCacheManager) onPodDelete(obj any) { + pod, ok := obj.(*corev1.Pod) + if !ok { + // Handle deleted final state unknown + tombstone, ok := obj.(cache.DeletedFinalStateUnknown) + if !ok { + klog.Errorf("Unexpected object type, can not parsed to Pod: %T", obj) + return + } + pod, ok = tombstone.Obj.(*corev1.Pod) + if !ok { + klog.Errorf("Tombstone contained object that is not a pod: %T", tombstone.Obj) + return + } + } + + kc.mu.Lock() + defer kc.mu.Unlock() + podUID := string(pod.UID) + delete(kc.cachedPod, podUID) + workerInfo, index, err := kc.extractWorkerInfo(pod) + if err != nil { + klog.Error(err, "Failed to extract worker info for pod", "pod", pod.Name, "namespace", pod.Namespace) + return + } + workerInfo.DeletedAt = time.Now().UnixMilli() + kc.notifyWorkerChanged(workerInfo) + + if index != "" { + podIndex, err := strconv.Atoi(index) + if err != nil { + klog.Error(err, "Failed to convert node index to int", "node index", index) + return + } + delete(kc.indexToWorkerInfo, podIndex) + } + klog.Infof("Pod %s/%s (UID: %s) deleted. 
state: %s node index: %s", pod.Namespace, pod.Name, pod.UID, workerInfo.Status, index) + +} + +// runWorkerChangeEventBus runs a standalone goroutine that consumes workerChangedCh +// and notifies all subscribers when worker information changes for their requested index +func (kc *PodCacheManager) runWorkerChangeEventBus() { + for { + select { + case <-kc.stopCh: + return + case <-kc.ctx.Done(): + return + case <-kc.workerChangedCh: + // Worker information changed, check if any subscribers are waiting + kc.notifySubscribers() + } + } +} + +// notifySubscribers checks all subscribers and sends worker info if available +func (kc *PodCacheManager) notifySubscribers() { + kc.subscribersMu.Lock() + defer kc.subscribersMu.Unlock() + + kc.mu.RLock() + defer kc.mu.RUnlock() + + // Iterate through all subscribed indices + for podIndex, subs := range kc.indexSubscribers { + // Check if worker info is now available for this index + if workerInfo, exists := kc.indexToWorkerInfo[podIndex]; exists && workerInfo != nil { + // Notify all subscribers for this index + for sub := range subs { + select { + case sub.ch <- workerInfo: + // Successfully sent, remove subscriber + delete(subs, sub) + close(sub.ch) + default: + // Channel is full or closed, skip + } + } + // Clean up empty subscriber set + if len(subs) == 0 { + delete(kc.indexSubscribers, podIndex) + } + } + } +} + +func (kc *PodCacheManager) notifyWorkerChanged(workerInfo *api.WorkerInfo) { + kc.podSubscribersMu.Lock() + defer kc.podSubscribersMu.Unlock() + for _, subscriber := range kc.podSubscribers { + select { + case subscriber <- workerInfo: + default: + klog.Warningf("Channel is full, skipping notification for worker change %s", workerInfo.WorkerUID) + } + } +} + +func (kc *PodCacheManager) RegisterWorkerInfoSubscriber(name string, subscriber chan<- *api.WorkerInfo) { + kc.podSubscribersMu.Lock() + defer kc.podSubscribersMu.Unlock() + if _, exists := kc.podSubscribers[name]; exists { + klog.Errorf("Worker info subscriber for %s already registered", name) + return + } + kc.podSubscribers[name] = subscriber + klog.Infof("Registered worker info subscriber for %s", name) +} + +func (kc *PodCacheManager) UnregisterWorkerInfoSubscriber(name string) { + kc.podSubscribersMu.Lock() + defer kc.podSubscribersMu.Unlock() + delete(kc.podSubscribers, name) + klog.Infof("Unregistered worker info subscriber for %s", name) +} + +// GetWorkerInfoForAllocationByIndex finds a pod by its index annotation and extracts worker info +// It implements a Pub/Sub pattern where callers subscribe to worker info changes for a specific pod index. +// If worker info is already available, it returns immediately. Otherwise, it waits for up to 10 minutes +// for the worker info to become available. 
+func (kc *PodCacheManager) GetWorkerInfoForAllocationByIndex(podIndex int) (*api.WorkerInfo, error) {
+	kc.subscribersMu.Lock()
+
+	// First, check if worker info is already available (fast path)
+	kc.mu.RLock()
+	if workerInfo, exists := kc.indexToWorkerInfo[podIndex]; exists && workerInfo != nil {
+		kc.mu.RUnlock()
+		kc.subscribersMu.Unlock()
+		return workerInfo, nil
+	}
+	kc.mu.RUnlock()
+
+	// Worker info not available yet, subscribe to changes
+	subscriber := &workerInfoSubscriber{
+		ch: make(chan *api.WorkerInfo, 1),
+	}
+
+	// Register subscriber
+	if _, exists := kc.indexSubscribers[podIndex]; !exists {
+		kc.indexSubscribers[podIndex] = make(map[*workerInfoSubscriber]struct{})
+	}
+	kc.indexSubscribers[podIndex][subscriber] = struct{}{}
+
+	// Release the lock before blocking: holding it across the select would prevent the
+	// event bus from delivering worker info and would deadlock unregisterSubscriber below.
+	kc.subscribersMu.Unlock()
+
+	timeoutTimer := time.NewTimer(subscriberTimeout)
+	defer timeoutTimer.Stop()
+
+	select {
+	case workerInfo := <-subscriber.ch:
+		// Worker info received
+		if workerInfo == nil {
+			return nil, fmt.Errorf("worker info channel closed for pod index %d", podIndex)
+		}
+		return workerInfo, nil
+	case <-timeoutTimer.C:
+		// Timeout reached
+		kc.unregisterSubscriber(podIndex, subscriber)
+		return nil, fmt.Errorf("timeout waiting for worker info for pod index %d after %v", podIndex, subscriberTimeout)
+	case <-kc.ctx.Done():
+		// Context cancelled
+		kc.unregisterSubscriber(podIndex, subscriber)
+		return nil, fmt.Errorf("context cancelled while waiting for worker info for pod index %d", podIndex)
+	case <-kc.stopCh:
+		// Pod cache manager stopped
+		kc.unregisterSubscriber(podIndex, subscriber)
+		return nil, fmt.Errorf("pod cache manager stopped while waiting for worker info for pod index %d", podIndex)
+	}
+}
+
+// unregisterSubscriber removes a subscriber from the subscribers map
+func (kc *PodCacheManager) unregisterSubscriber(podIndex int, sub *workerInfoSubscriber) {
+	kc.subscribersMu.Lock()
+	defer kc.subscribersMu.Unlock()
+
+	if subs, exists := kc.indexSubscribers[podIndex]; exists {
+		if _, stillSubscribed := subs[sub]; stillSubscribed {
+			delete(subs, sub)
+			// Close channel - safe because we just removed it from the map, so the event bus won't close it
+			close(sub.ch)
+		}
+		// Clean up empty subscriber set
+		if len(subs) == 0 {
+			delete(kc.indexSubscribers, podIndex)
+		}
+	}
+}
+
+// GetPodByUID retrieves a pod from the cache by its UID
+func (kc *PodCacheManager) GetPodByUID(podUID string) *corev1.Pod {
+	kc.mu.RLock()
+	defer kc.mu.RUnlock()
+	return kc.cachedPod[podUID]
+}
+
+// extractWorkerInfo extracts worker information from pod annotations using the common utility function
+func (kc *PodCacheManager) extractWorkerInfo(pod *corev1.Pod) (*api.WorkerInfo, string, error) {
+	// Use common utility function to extract pod worker info
+	index := ""
+	allocRequest, msg, err := utils.ComposeAllocationRequest(kc.ctx, pod)
+	if err != nil {
+		klog.ErrorS(err, "Failed to compose allocation request for existing worker Pod, annotation may not be valid", "pod", pod.Name, "msg", msg)
+		return nil, index, err
+	}
+
+	status := api.WorkerStatusPending
+	if utils.IsPodRunning(pod) {
+		status = api.WorkerStatusRunning
+	} else if utils.IsPodStopped(pod) {
+		status = api.WorkerStatusTerminated
+	} else {
+		// Must be PodPending state, check if a device can be allocated (use annotation index to check if index-lock released)
+		if nodeIndex, exists := pod.Annotations[constants.PodIndexAnnotation]; exists {
+			index = nodeIndex
+			status = api.WorkerStatusDeviceAllocating
+		}
+	}
+	info := &api.WorkerInfo{
+		WorkerUID:  string(pod.UID),
+		Status:     status,
+		WorkerName: 
pod.Name, + Namespace: pod.Namespace, + + AllocatedDevices: allocRequest.GPUNames, + IsolationMode: allocRequest.Isolation, + QoS: allocRequest.QoS, + + Requests: allocRequest.Request, + Limits: allocRequest.Limit, + + PartitionTemplateID: allocRequest.PartitionTemplateID, + + WorkloadName: allocRequest.WorkloadNameNamespace.Name, + WorkloadNamespace: allocRequest.WorkloadNameNamespace.Namespace, + + Labels: pod.Labels, + Annotations: pod.Annotations, + } + return info, index, nil +} + +// GetAllPods returns all pods currently in the cache +func (kc *PodCacheManager) GetAllPods() map[string]*corev1.Pod { + kc.mu.RLock() + defer kc.mu.RUnlock() + + result := make(map[string]*corev1.Pod, len(kc.cachedPod)) + for k, v := range kc.cachedPod { + result[k] = v + } + return result +} diff --git a/internal/hypervisor/backend/single_node/filestate.go b/internal/hypervisor/backend/single_node/filestate.go new file mode 100644 index 00000000..2b4ec19a --- /dev/null +++ b/internal/hypervisor/backend/single_node/filestate.go @@ -0,0 +1,198 @@ +package single_node + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +const ( + defaultStateDir = "/tmp/tensor-fusion-state" + workersFile = "workers.json" + devicesFile = "devices.json" +) + +// FileStateManager manages file-based state persistence +type FileStateManager struct { + stateDir string + mu sync.RWMutex +} + +// NewFileStateManager creates a new file state manager +func NewFileStateManager(stateDir string) *FileStateManager { + if stateDir == "" { + stateDir = defaultStateDir + } + return &FileStateManager{ + stateDir: stateDir, + } +} + +// ensureStateDir ensures the state directory exists +func (fsm *FileStateManager) ensureStateDir() error { + return os.MkdirAll(fsm.stateDir, 0755) +} + +// SaveWorkers saves workers to JSON file +func (fsm *FileStateManager) SaveWorkers(workers map[string]*api.WorkerInfo) error { + fsm.mu.Lock() + defer fsm.mu.Unlock() + + if err := fsm.ensureStateDir(); err != nil { + return err + } + + // Convert map to slice for JSON + workersList := make([]*api.WorkerInfo, 0, len(workers)) + for _, worker := range workers { + workersList = append(workersList, worker) + } + + data, err := json.MarshalIndent(workersList, "", " ") + if err != nil { + return err + } + + filePath := filepath.Join(fsm.stateDir, workersFile) + tmpPath := filePath + ".tmp" + if err := os.WriteFile(tmpPath, data, 0644); err != nil { + return err + } + + return os.Rename(tmpPath, filePath) +} + +// LoadWorkers loads workers from JSON file +func (fsm *FileStateManager) LoadWorkers() (map[string]*api.WorkerInfo, error) { + fsm.mu.RLock() + defer fsm.mu.RUnlock() + + filePath := filepath.Join(fsm.stateDir, workersFile) + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return make(map[string]*api.WorkerInfo), nil + } + return nil, err + } + + var workersList []*api.WorkerInfo + if err := json.Unmarshal(data, &workersList); err != nil { + return nil, err + } + + workers := make(map[string]*api.WorkerInfo, len(workersList)) + for _, worker := range workersList { + if worker != nil { + workers[worker.WorkerUID] = worker + } + } + + return workers, nil +} + +// SaveDevices saves devices to JSON file +func (fsm *FileStateManager) SaveDevices(devices map[string]*api.DeviceInfo) error { + fsm.mu.Lock() + defer fsm.mu.Unlock() + + if err := fsm.ensureStateDir(); err != nil { + return err + } + + // Convert map to slice for JSON + devicesList := 
make([]*api.DeviceInfo, 0, len(devices)) + for _, device := range devices { + devicesList = append(devicesList, device) + } + + data, err := json.MarshalIndent(devicesList, "", " ") + if err != nil { + return err + } + + filePath := filepath.Join(fsm.stateDir, devicesFile) + tmpPath := filePath + ".tmp" + if err := os.WriteFile(tmpPath, data, 0644); err != nil { + return err + } + + return os.Rename(tmpPath, filePath) +} + +// LoadDevices loads devices from JSON file +func (fsm *FileStateManager) LoadDevices() (map[string]*api.DeviceInfo, error) { + fsm.mu.RLock() + defer fsm.mu.RUnlock() + + filePath := filepath.Join(fsm.stateDir, devicesFile) + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return make(map[string]*api.DeviceInfo), nil + } + return nil, err + } + + var devicesList []*api.DeviceInfo + if err := json.Unmarshal(data, &devicesList); err != nil { + return nil, err + } + + devices := make(map[string]*api.DeviceInfo, len(devicesList)) + for _, device := range devicesList { + if device != nil { + devices[device.UUID] = device + } + } + + return devices, nil +} + +// AddWorker adds a worker to the state +func (fsm *FileStateManager) AddWorker(worker *api.WorkerInfo) error { + workers, err := fsm.LoadWorkers() + if err != nil { + return err + } + workers[worker.WorkerUID] = worker + return fsm.SaveWorkers(workers) +} + +// RemoveWorker removes a worker from the state +func (fsm *FileStateManager) RemoveWorker(workerUID string) error { + workers, err := fsm.LoadWorkers() + if err != nil { + return err + } + delete(workers, workerUID) + return fsm.SaveWorkers(workers) +} + +// AddDevice adds a device to the state +func (fsm *FileStateManager) AddDevice(device *api.DeviceInfo) error { + devices, err := fsm.LoadDevices() + if err != nil { + return err + } + devices[device.UUID] = device + return fsm.SaveDevices(devices) +} + +// RemoveDevice removes a device from the state +func (fsm *FileStateManager) RemoveDevice(deviceUUID string) error { + devices, err := fsm.LoadDevices() + if err != nil { + return err + } + delete(devices, deviceUUID) + return fsm.SaveDevices(devices) +} + +// UpdateDevice updates a device in the state +func (fsm *FileStateManager) UpdateDevice(device *api.DeviceInfo) error { + return fsm.AddDevice(device) +} diff --git a/internal/hypervisor/backend/single_node/single_node_backend.go b/internal/hypervisor/backend/single_node/single_node_backend.go new file mode 100644 index 00000000..adf53a48 --- /dev/null +++ b/internal/hypervisor/backend/single_node/single_node_backend.go @@ -0,0 +1,299 @@ +package single_node + +import ( + "context" + "os" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/google/uuid" + "github.com/samber/lo" + "k8s.io/klog/v2" +) + +type SingleNodeBackend struct { + ctx context.Context + deviceController framework.DeviceController + fileState *FileStateManager + mu sync.RWMutex + workers map[string]*api.WorkerInfo + stopCh chan struct{} + stopOnce sync.Once + + // Worker watching + subscribersMu sync.RWMutex + subscribers map[string]chan *api.WorkerInfo + workerHandler *framework.WorkerChangeHandler +} + +func NewSingleNodeBackend(ctx context.Context, deviceController framework.DeviceController) *SingleNodeBackend { + stateDir := os.Getenv("TENSOR_FUSION_STATE_DIR") + if stateDir == "" { + stateDir = "/tmp/tensor-fusion-state" + } + return &SingleNodeBackend{ + ctx: ctx, + deviceController: 
deviceController, + fileState: NewFileStateManager(stateDir), + workers: make(map[string]*api.WorkerInfo), + stopCh: make(chan struct{}), + subscribers: make(map[string]chan *api.WorkerInfo), + } +} + +func (b *SingleNodeBackend) Start() error { + // Load initial state from files + if err := b.loadState(); err != nil { + klog.Warningf("Failed to load initial state: %v", err) + } + + // Start periodic worker discovery + go b.periodicWorkerDiscovery() + return nil +} + +func (b *SingleNodeBackend) Stop() error { + // Use sync.Once to ensure stopCh is only closed once + b.stopOnce.Do(func() { + close(b.stopCh) + }) + + // Close all subscriber channels + b.subscribersMu.Lock() + for id, ch := range b.subscribers { + close(ch) + delete(b.subscribers, id) + } + b.subscribersMu.Unlock() + + return nil +} + +// loadState loads workers and devices from file state +func (b *SingleNodeBackend) loadState() error { + workers, err := b.fileState.LoadWorkers() + if err != nil { + return err + } + + b.mu.Lock() + b.workers = workers + b.mu.Unlock() + + return nil +} + +// discoverWorkers discovers workers from file state and notifies subscribers of changes +func (b *SingleNodeBackend) discoverWorkers() { + workers, err := b.fileState.LoadWorkers() + if err != nil { + klog.Errorf("Failed to load workers from file state: %v", err) + return + } + + b.mu.Lock() + // Find new and updated workers + for uid, worker := range workers { + oldWorker, exists := b.workers[uid] + if !exists { + // New worker + b.workers[uid] = worker + b.mu.Unlock() + b.notifySubscribers(worker) + b.mu.Lock() + } else if !workersEqual(oldWorker, worker) { + // Updated worker + b.workers[uid] = worker + b.mu.Unlock() + b.notifySubscribers(worker) + b.mu.Lock() + } + } + + // Find removed workers + for uid := range b.workers { + if _, exists := workers[uid]; !exists { + delete(b.workers, uid) + } + } + b.mu.Unlock() +} + +// notifySubscribers notifies all subscribers of a worker change +func (b *SingleNodeBackend) notifySubscribers(worker *api.WorkerInfo) { + b.subscribersMu.RLock() + defer b.subscribersMu.RUnlock() + + for _, ch := range b.subscribers { + select { + case ch <- worker: + default: + klog.Warningf("Channel is full, skipping notification for worker change %s", worker.WorkerUID) + } + } +} + +// workersEqual checks if two workers are equal (simple comparison) +func workersEqual(w1, w2 *api.WorkerInfo) bool { + if w1 == nil && w2 == nil { + return true + } + if w1 == nil || w2 == nil { + return false + } + return w1.WorkerUID == w2.WorkerUID && + w1.Status == w2.Status && + len(w1.AllocatedDevices) == len(w2.AllocatedDevices) +} + +func (b *SingleNodeBackend) periodicWorkerDiscovery() { + // Run initial discovery immediately + b.discoverWorkers() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-b.stopCh: + return + case <-b.ctx.Done(): + return + case <-ticker.C: + b.discoverWorkers() + } + } +} + +func (b *SingleNodeBackend) RegisterWorkerUpdateHandler(handler framework.WorkerChangeHandler) error { + b.workerHandler = &handler + + // Create channel for this subscriber + workerCh := make(chan *api.WorkerInfo, 16) + subscriberID := uuid.NewString() + + // Register subscriber + b.subscribersMu.Lock() + b.subscribers[subscriberID] = workerCh + b.subscribersMu.Unlock() + + // Start bridge goroutine to convert channel messages to handler calls + go func() { + defer func() { + b.subscribersMu.Lock() + delete(b.subscribers, subscriberID) + b.subscribersMu.Unlock() + }() + + for { + 
select {
+			case <-b.ctx.Done():
+				return
+			case <-b.stopCh:
+				return
+			case worker, ok := <-workerCh:
+				if !ok {
+					return
+				}
+				if worker == nil {
+					continue
+				}
+
+				// Determine if this is add, update, or remove
+				b.mu.Lock()
+				oldWorker, exists := b.workers[worker.WorkerUID]
+
+				if worker.DeletedAt > 0 {
+					// Worker was deleted
+					if exists && handler.OnRemove != nil {
+						handler.OnRemove(worker)
+					}
+					delete(b.workers, worker.WorkerUID)
+				} else if !exists {
+					// New worker
+					b.workers[worker.WorkerUID] = worker
+					if handler.OnAdd != nil {
+						handler.OnAdd(worker)
+					}
+				} else {
+					// Updated worker
+					b.workers[worker.WorkerUID] = worker
+					if handler.OnUpdate != nil {
+						handler.OnUpdate(oldWorker, worker)
+					}
+				}
+				b.mu.Unlock()
+			}
+		}
+	}()
+	return nil
+}
+
+func (b *SingleNodeBackend) StartWorker(worker *api.WorkerInfo) error {
+	if err := b.fileState.AddWorker(worker); err != nil {
+		return err
+	}
+
+	b.mu.Lock()
+	b.workers[worker.WorkerUID] = worker
+	b.mu.Unlock()
+
+	b.notifySubscribers(worker)
+	klog.Infof("Worker started: %s", worker.WorkerUID)
+	return nil
+}
+
+func (b *SingleNodeBackend) StopWorker(workerUID string) error {
+	if err := b.fileState.RemoveWorker(workerUID); err != nil {
+		return err
+	}
+
+	b.mu.Lock()
+	delete(b.workers, workerUID)
+	b.mu.Unlock()
+
+	klog.Infof("Worker stopped: %s", workerUID)
+	return nil
+}
+
+func (b *SingleNodeBackend) GetProcessMappingInfo(workerUID string, hostPID uint32) (*framework.ProcessMappingInfo, error) {
+	return &framework.ProcessMappingInfo{
+		GuestID:  workerUID,
+		HostPID:  hostPID,
+		GuestPID: hostPID,
+	}, nil
+}
+
+func (b *SingleNodeBackend) GetDeviceChangeHandler() framework.DeviceChangeHandler {
+	return framework.DeviceChangeHandler{
+		OnAdd: func(device *api.DeviceInfo) {
+			if err := b.fileState.AddDevice(device); err != nil {
+				klog.Errorf("Failed to save device to file state: %v", err)
+			} else {
+				klog.Infof("Device added: %s", device.UUID)
+			}
+		},
+		OnRemove: func(device *api.DeviceInfo) {
+			if err := b.fileState.RemoveDevice(device.UUID); err != nil {
+				klog.Errorf("Failed to remove device from file state: %v", err)
+			} else {
+				klog.Infof("Device removed: %s", device.UUID)
+			}
+		},
+		OnUpdate: func(oldDevice, newDevice *api.DeviceInfo) {
+			if err := b.fileState.UpdateDevice(newDevice); err != nil {
+				klog.Errorf("Failed to update device in file state: %v", err)
+			} else {
+				klog.Infof("Device updated: %s", newDevice.UUID)
+			}
+		},
+	}
+}
+
+func (b *SingleNodeBackend) ListWorkers() []*api.WorkerInfo {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return lo.Values(b.workers)
+}
diff --git a/internal/hypervisor/device/accelerator.go b/internal/hypervisor/device/accelerator.go
new file mode 100644
index 00000000..df14d5ef
--- /dev/null
+++ b/internal/hypervisor/device/accelerator.go
@@ -0,0 +1,460 @@
+package device
+
+/*
+#cgo CFLAGS: -I../../../provider
+#cgo LDFLAGS: -ldl
+#include "../../../provider/accelerator.h"
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <dlfcn.h>
+
+// Forward declarations from wrapper.c
+extern int loadAcceleratorLibrary(const char* libPath);
+extern void unloadAcceleratorLibrary(void);
+extern Result GetDeviceCountWrapper(size_t* deviceCount);
+extern Result GetAllDevicesWrapper(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount);
+extern Result GetPartitionTemplatesWrapper(int32_t deviceIndex, PartitionTemplate* templates, size_t maxCount, size_t* templateCount);
+extern bool AssignPartitionWrapper(PartitionAssignment* assignment);
+extern bool RemovePartitionWrapper(const 
char* templateId, const char* deviceUUID); +extern Result SetMemHardLimitWrapper(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes); +extern Result SetComputeUnitHardLimitWrapper(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit); +extern Result GetProcessComputeUtilizationWrapper(ComputeUtilization* utilizations, size_t maxCount, size_t* utilizationCount); +extern Result GetProcessMemoryUtilizationWrapper(MemoryUtilization* utilizations, size_t maxCount, size_t* utilizationCount); +extern Result GetDeviceMetricsWrapper(const char** deviceUUIDArray, size_t deviceCount, DeviceMetrics* metrics, size_t maxExtraMetricsPerDevice); +extern Result GetVendorMountLibsWrapper(Mount* mounts, size_t maxCount, size_t* mountCount); +extern const char* getDlError(void); +*/ +import "C" +import ( + "fmt" + "sync" + "unsafe" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +// AcceleratorInterface provides Go bindings for the C accelerator library +type AcceleratorInterface struct { + libPath string + // deviceProcesses maps device UUID to list of process IDs + deviceProcesses map[string][]string + mu sync.RWMutex + loaded bool +} + +// NewAcceleratorInterface creates a new accelerator interface and loads the library +func NewAcceleratorInterface(libPath string) (*AcceleratorInterface, error) { + accel := &AcceleratorInterface{ + libPath: libPath, + deviceProcesses: make(map[string][]string), + loaded: false, + } + + // Load the library + if err := accel.Load(); err != nil { + return nil, fmt.Errorf("failed to load accelerator library from %s: %w", libPath, err) + } + + return accel, nil +} + +// Load loads the accelerator library dynamically +func (a *AcceleratorInterface) Load() error { + if a.libPath == "" { + return fmt.Errorf("library path is empty") + } + + cLibPath := C.CString(a.libPath) + defer C.free(unsafe.Pointer(cLibPath)) + + result := C.loadAcceleratorLibrary(cLibPath) + if result != 0 { + var errMsg string + if dlErr := C.getDlError(); dlErr != nil { + errMsg = C.GoString(dlErr) + } else { + errMsg = "unknown error" + } + + switch result { + case -1: + return fmt.Errorf("failed to load library: %s", errMsg) + case -2: + return fmt.Errorf("missing required symbols in library: %s", errMsg) + } + return fmt.Errorf("failed to load library (code %d): %s", result, errMsg) + } + + a.loaded = true + return nil +} + +// Close unloads the accelerator library +func (a *AcceleratorInterface) Close() error { + if a.loaded { + C.unloadAcceleratorLibrary() + a.loaded = false + } + return nil +} + +// GetTotalProcessCount returns the total number of processes across all devices +func (a *AcceleratorInterface) GetTotalProcessCount() int { + a.mu.RLock() + defer a.mu.RUnlock() + + total := 0 + for _, processes := range a.deviceProcesses { + total += len(processes) + } + return total +} + +// GetDeviceMetrics retrieves device metrics for the specified device UUIDs +func (a *AcceleratorInterface) GetDeviceMetrics(deviceUUIDs []string) ([]*api.GPUUsageMetrics, error) { + if len(deviceUUIDs) == 0 { + return []*api.GPUUsageMetrics{}, nil + } + + const maxStackDevices = 64 + deviceCount := len(deviceUUIDs) + if deviceCount > maxStackDevices { + deviceCount = maxStackDevices + } + + // Allocate C strings for device UUIDs + cDeviceUUIDs := make([]*C.char, deviceCount) + for i := 0; i < deviceCount; i++ { + cDeviceUUIDs[i] = C.CString(deviceUUIDs[i]) + } + defer func() { + for _, cDeviceUUID := range cDeviceUUIDs { + if cDeviceUUID != nil { + 
C.free(unsafe.Pointer(cDeviceUUID)) + } + } + }() + + // Convert Go slice to C array pointer + // In CGO, we can directly use the slice's underlying array pointer + var cUUIDArray **C.char + if deviceCount > 0 { + cUUIDArray = (**C.char)(unsafe.Pointer(&cDeviceUUIDs[0])) + } + + // Allocate stack buffer for metrics + const maxExtraMetricsPerDevice = 32 + var cMetrics [maxStackDevices]C.DeviceMetrics + var cExtraMetrics [maxStackDevices][maxExtraMetricsPerDevice]C.ExtraMetric + + // Initialize extraMetrics pointers + for i := 0; i < deviceCount; i++ { + cMetrics[i].extraMetrics = &cExtraMetrics[i][0] + cMetrics[i].extraMetricsCount = 0 + } + + //nolint:staticcheck + result := C.GetDeviceMetricsWrapper(cUUIDArray, C.size_t(deviceCount), &cMetrics[0], C.size_t(maxExtraMetricsPerDevice)) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get device metrics: %d", result) + } + + // Convert C metrics to Go metrics + metrics := make([]*api.GPUUsageMetrics, deviceCount) + for i := 0; i < deviceCount; i++ { + cm := &cMetrics[i] + memoryTotal := uint64(cm.memoryTotalBytes) + memoryUsed := uint64(cm.memoryUsedBytes) + var memoryPercentage float64 + if memoryTotal > 0 { + memoryPercentage = float64(memoryUsed) / float64(memoryTotal) * 100.0 + } + + // Convert extra metrics from C to Go map + extraMetrics := make(map[string]float64, int(cm.extraMetricsCount)+1) + // Always include tensorCoreUsagePercent as it's a standard field + extraMetrics["tensorCoreUsagePercent"] = float64(cm.tensorCoreUsagePercent) + + // Add other extra metrics from C array + if cm.extraMetrics != nil && cm.extraMetricsCount > 0 { + // Convert C pointer to Go slice for indexing + extraMetricsSlice := (*[maxExtraMetricsPerDevice]C.ExtraMetric)(unsafe.Pointer(cm.extraMetrics)) + for j := 0; j < int(cm.extraMetricsCount); j++ { + em := &extraMetricsSlice[j] + key := C.GoString(&em.key[0]) + if key != "" { + extraMetrics[key] = float64(em.value) + } + } + } + + metrics[i] = &api.GPUUsageMetrics{ + DeviceUUID: C.GoString(&cm.deviceUUID[0]), + MemoryBytes: memoryUsed, + MemoryPercentage: memoryPercentage, + ComputePercentage: float64(cm.smActivePercent), + ComputeTflops: 0, // Not available in DeviceMetrics + Rx: float64(cm.pcieRxBytes) / 1024.0, // Convert bytes to KB + Tx: float64(cm.pcieTxBytes) / 1024.0, // Convert bytes to KB + Temperature: float64(cm.temperatureCelsius), + PowerUsage: int64(cm.powerUsageWatts), + ExtraMetrics: extraMetrics, + } + } + + return metrics, nil +} + +// GetAllDevices retrieves all available devices from the accelerator library +func (a *AcceleratorInterface) GetAllDevices() ([]*api.DeviceInfo, error) { + // First, get the device count + var cDeviceCount C.size_t + //nolint:staticcheck + result := C.GetDeviceCountWrapper(&cDeviceCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get device count: %d", result) + } + + if cDeviceCount == 0 { + return []*api.DeviceInfo{}, nil + } + + // Allocate stack buffer (max 64 devices to avoid stack overflow) + const maxStackDevices = 64 + var stackDevices [maxStackDevices]C.ExtendedDeviceInfo + maxDevices := int(cDeviceCount) + if maxDevices > maxStackDevices { + maxDevices = maxStackDevices + } + + var cCount C.size_t + //nolint:staticcheck + result = C.GetAllDevicesWrapper(&stackDevices[0], C.size_t(maxDevices), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get all devices: %d", result) + } + + if cCount == 0 { + return []*api.DeviceInfo{}, nil + } + + devices := make([]*api.DeviceInfo, 
int(cCount))
+
+	for i := 0; i < int(cCount); i++ {
+		cInfo := &stackDevices[i]
+		devices[i] = &api.DeviceInfo{
+			UUID:             C.GoString(&cInfo.basic.uuid[0]),
+			Vendor:           C.GoString(&cInfo.basic.vendor[0]),
+			Model:            C.GoString(&cInfo.basic.model[0]),
+			Index:            int32(cInfo.basic.index),
+			NUMANode:         int32(cInfo.basic.numaNode),
+			TotalMemoryBytes: uint64(cInfo.basic.totalMemoryBytes),
+			MaxTflops:        float64(cInfo.basic.maxTflops),
+			Capabilities: api.DeviceCapabilities{
+				SupportsPartitioning:  bool(cInfo.capabilities.supportsPartitioning),
+				SupportsSoftIsolation: bool(cInfo.capabilities.supportsSoftIsolation),
+				SupportsHardIsolation: bool(cInfo.capabilities.supportsHardIsolation),
+				SupportsSnapshot:      bool(cInfo.capabilities.supportsSnapshot),
+				SupportsMetrics:       bool(cInfo.capabilities.supportsMetrics),
+				MaxPartitions:         uint32(cInfo.capabilities.maxPartitions),
+				MaxWorkersPerDevice:   uint32(cInfo.capabilities.maxWorkersPerDevice),
+			},
+			Properties: make(map[string]string, 0),
+		}
+	}
+
+	return devices, nil
+}
+
+// AssignPartition assigns a partition to a device
+func (a *AcceleratorInterface) AssignPartition(templateID, deviceUUID string) (string, error) {
+	var assignment C.PartitionAssignment
+	// Reject inputs that do not fit the fixed-size buffers in PartitionAssignment
+	// (leaving room for the trailing NUL byte) so strncpy cannot overflow the C struct.
+	if len(templateID) >= len(assignment.templateId) {
+		return "", fmt.Errorf("template ID %q is too long", templateID)
+	}
+	if len(deviceUUID) >= len(assignment.deviceUUID) {
+		return "", fmt.Errorf("device UUID %q is too long", deviceUUID)
+	}
+
+	cTemplateID := C.CString(templateID)
+	defer C.free(unsafe.Pointer(cTemplateID))
+
+	cDeviceUUID := C.CString(deviceUUID)
+	defer C.free(unsafe.Pointer(cDeviceUUID))
+
+	C.strncpy(&assignment.templateId[0], cTemplateID, C.size_t(len(templateID)))
+	C.strncpy(&assignment.deviceUUID[0], cDeviceUUID, C.size_t(len(deviceUUID)))
+
+	//nolint:staticcheck
+	result := C.AssignPartitionWrapper(&assignment)
+	if !result {
+		return "", fmt.Errorf("failed to assign partition")
+	}
+
+	partitionUUID := C.GoString(&assignment.partitionUUID[0])
+	return partitionUUID, nil
+}
+
+// RemovePartition removes a partition from a device
+func (a *AcceleratorInterface) RemovePartition(partitionUUID, deviceUUID string) error {
+	cPartitionUUID := C.CString(partitionUUID)
+	defer C.free(unsafe.Pointer(cPartitionUUID))
+
+	cDeviceUUID := C.CString(deviceUUID)
+	defer C.free(unsafe.Pointer(cDeviceUUID))
+
+	//nolint:staticcheck
+	result := C.RemovePartitionWrapper(cPartitionUUID, cDeviceUUID)
+	if !result {
+		return fmt.Errorf("failed to remove partition")
+	}
+
+	return nil
+}
+
+// SetMemHardLimit sets hard memory limit for a worker
+func (a *AcceleratorInterface) SetMemHardLimit(workerID, deviceUUID string, memoryLimitBytes uint64) error {
+	cWorkerID := C.CString(workerID)
+	defer C.free(unsafe.Pointer(cWorkerID))
+
+	cDeviceUUID := C.CString(deviceUUID)
+	defer C.free(unsafe.Pointer(cDeviceUUID))
+
+	//nolint:staticcheck
+	result := C.SetMemHardLimitWrapper(cWorkerID, cDeviceUUID, C.uint64_t(memoryLimitBytes))
+	if result != C.RESULT_SUCCESS {
+		return fmt.Errorf("failed to set memory hard limit: %d", result)
+	}
+
+	return nil
+}
+
+// SetComputeUnitHardLimit sets hard compute unit limit for a worker
+func (a *AcceleratorInterface) SetComputeUnitHardLimit(workerID, deviceUUID string, computeUnitLimit uint32) error {
+	cWorkerID := C.CString(workerID)
+	defer C.free(unsafe.Pointer(cWorkerID))
+
+	cDeviceUUID := C.CString(deviceUUID)
+	defer C.free(unsafe.Pointer(cDeviceUUID))
+
+	//nolint:staticcheck
+	result := C.SetComputeUnitHardLimitWrapper(cWorkerID, cDeviceUUID, C.uint32_t(computeUnitLimit))
+	if result != C.RESULT_SUCCESS {
+		return fmt.Errorf("failed to set compute unit hard limit: %d", result)
+	}
+
+	return nil
+}
+
+// GetProcessComputeUtilization retrieves compute utilization for all tracked processes
+func (a 
*AcceleratorInterface) GetProcessComputeUtilization() ([]api.ComputeUtilization, error) { + // Get total process count from the map + totalCount := a.GetTotalProcessCount() + if totalCount == 0 { + return []api.ComputeUtilization{}, nil + } + + // Allocate stack buffer (max 1024 to avoid stack overflow) + const maxStackUtilizations = 1024 + var stackUtilizations [maxStackUtilizations]C.ComputeUtilization + maxCount := totalCount + if maxCount > maxStackUtilizations { + maxCount = maxStackUtilizations + } + + var cCount C.size_t + //nolint:staticcheck + result := C.GetProcessComputeUtilizationWrapper(&stackUtilizations[0], C.size_t(maxCount), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get process compute utilization: %d", result) + } + + if cCount == 0 { + return []api.ComputeUtilization{}, nil + } + + utilizations := make([]api.ComputeUtilization, int(cCount)) + for i := 0; i < int(cCount); i++ { + cu := &stackUtilizations[i] + utilizations[i] = api.ComputeUtilization{ + ProcessID: C.GoString(&cu.processId[0]), + DeviceUUID: C.GoString(&cu.deviceUUID[0]), + UtilizationPercent: float64(cu.utilizationPercent), + // Note: ActiveSMs, TotalSMs, and TFLOPsUsed will be added to ComputeUtilization if needed + } + } + + return utilizations, nil +} + +// GetProcessMemoryUtilization retrieves memory utilization for all tracked processes +func (a *AcceleratorInterface) GetProcessMemoryUtilization() ([]api.MemoryUtilization, error) { + // Get total process count from the map + totalCount := a.GetTotalProcessCount() + if totalCount == 0 { + return []api.MemoryUtilization{}, nil + } + + // Allocate stack buffer (max 1024 to avoid stack overflow) + const maxStackUtilizations = 1024 + var stackUtilizations [maxStackUtilizations]C.MemoryUtilization + maxCount := totalCount + if maxCount > maxStackUtilizations { + maxCount = maxStackUtilizations + } + + var cCount C.size_t + //nolint:staticcheck + result := C.GetProcessMemoryUtilizationWrapper(&stackUtilizations[0], C.size_t(maxCount), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get process memory utilization: %d", result) + } + + if cCount == 0 { + return []api.MemoryUtilization{}, nil + } + + utilizations := make([]api.MemoryUtilization, int(cCount)) + for i := 0; i < int(cCount); i++ { + mu := &stackUtilizations[i] + utilizations[i] = api.MemoryUtilization{ + ProcessID: C.GoString(&mu.processId[0]), + DeviceUUID: C.GoString(&mu.deviceUUID[0]), + UsedBytes: uint64(mu.usedBytes), + ReservedBytes: uint64(mu.reservedBytes), + // Note: UtilizationPercent will be calculated separately if needed + } + } + + return utilizations, nil +} + +// GetVendorMountLibs retrieves vendor mount libs +func (a *AcceleratorInterface) GetVendorMountLibs() ([]*api.Mount, error) { + const maxStackMounts = 64 + var stackMounts [maxStackMounts]C.Mount + var cCount C.size_t + + result := C.GetVendorMountLibsWrapper(&stackMounts[0], C.size_t(maxStackMounts), &cCount) + if result != C.RESULT_SUCCESS { + return nil, fmt.Errorf("failed to get vendor mount libs: %d", result) + } + + if cCount == 0 { + return []*api.Mount{}, nil + } + + mounts := make([]*api.Mount, int(cCount)) + for i := 0; i < int(cCount); i++ { + cm := &stackMounts[i] + var hostPath, guestPath string + if cm.hostPath != nil { + hostPath = C.GoString(cm.hostPath) + } + if cm.guestPath != nil { + guestPath = C.GoString(cm.guestPath) + } + mounts[i] = &api.Mount{ + HostPath: hostPath, + GuestPath: guestPath, + } + } + + return mounts, nil +} diff --git 
a/internal/hypervisor/device/accelerator_suite_test.go b/internal/hypervisor/device/accelerator_suite_test.go
new file mode 100644
index 00000000..c09cb22b
--- /dev/null
+++ b/internal/hypervisor/device/accelerator_suite_test.go
@@ -0,0 +1,13 @@
+package device
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestAccelerator(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Accelerator Suite")
+}
diff --git a/internal/hypervisor/device/accelerator_test.go b/internal/hypervisor/device/accelerator_test.go
new file mode 100644
index 00000000..b4b119e1
--- /dev/null
+++ b/internal/hypervisor/device/accelerator_test.go
@@ -0,0 +1,313 @@
+package device
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("AcceleratorInterface", func() {
+	var (
+		accel       *AcceleratorInterface
+		stubLibPath string
+	)
+
+	BeforeEach(func() {
+		// Try to find stub library
+		stubLibPath = "./provider/build/libaccelerator_stub.so"
+		if _, err := os.Stat(stubLibPath); os.IsNotExist(err) {
+			// Try alternative path
+			stubLibPath = filepath.Join("..", "..", "..", "provider", "build", "libaccelerator_stub.so")
+			if _, err := os.Stat(stubLibPath); os.IsNotExist(err) {
+				Skip("Stub library not found, skipping tests")
+			}
+		}
+	})
+
+	AfterEach(func() {
+		if accel != nil {
+			Expect(accel.Close()).To(Succeed())
+		}
+	})
+
+	Describe("Library Loading", func() {
+		It("should load stub library successfully", func() {
+			var err error
+			accel, err = NewAcceleratorInterface(stubLibPath)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(accel).NotTo(BeNil())
+			Expect(accel.loaded).To(BeTrue())
+		})
+
+		It("should fail to load non-existent library", func() {
+			badAccel, err := NewAcceleratorInterface("/non/existent/library.so")
+			Expect(err).To(HaveOccurred())
+			Expect(badAccel).To(BeNil())
+		})
+
+		It("should handle multiple load/unload cycles", func() {
+			var err error
+			accel, err = NewAcceleratorInterface(stubLibPath)
+			Expect(err).NotTo(HaveOccurred())
+
+			// Reload
+			Expect(accel.Load()).To(Succeed())
+			Expect(accel.Close()).To(Succeed())
+			Expect(accel.Load()).To(Succeed())
+		})
+	})
+
+	Describe("GetDeviceMetrics", func() {
+		BeforeEach(func() {
+			var err error
+			accel, err = NewAcceleratorInterface(stubLibPath)
+			Expect(err).NotTo(HaveOccurred())
+		})
+
+		It("should return empty slice for empty input", func() {
+			metrics, err := accel.GetDeviceMetrics([]string{})
+			Expect(err).NotTo(HaveOccurred())
+			Expect(metrics).To(BeEmpty())
+		})
+
+		It("should retrieve metrics for single device with ExtraMetrics", func() {
+			deviceUUIDs := []string{"test-device-001"}
+			metrics, err := accel.GetDeviceMetrics(deviceUUIDs)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(metrics).To(HaveLen(1))
+
+			m := metrics[0]
+			Expect(m.DeviceUUID).To(Equal(deviceUUIDs[0]))
+			Expect(m.MemoryBytes).To(BeNumerically(">", 0))
+			Expect(m.MemoryPercentage).To(BeNumerically(">=", 0))
+			Expect(m.MemoryPercentage).To(BeNumerically("<=", 100))
+			Expect(m.PowerUsage).To(BeNumerically(">", 0))
+			Expect(m.Temperature).To(BeNumerically(">", 0))
+
+			// Verify ExtraMetrics are populated
+			Expect(m.ExtraMetrics).NotTo(BeEmpty())
+			Expect(m.ExtraMetrics).To(HaveKey("tensorCoreUsagePercent"))
+			Expect(m.ExtraMetrics).To(HaveKey("gpuUtilization"))
+			Expect(m.ExtraMetrics).To(HaveKey("memoryBandwidthMBps"))
+			Expect(m.ExtraMetrics["gpuUtilization"]).To(BeNumerically(">=", 0))
+			Expect(m.ExtraMetrics["gpuUtilization"]).To(BeNumerically("<=", 100))
+		})
+
+		
It("should handle multiple devices", func() { + deviceUUIDs := []string{"device-1", "device-2", "device-3"} + metrics, err := accel.GetDeviceMetrics(deviceUUIDs) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(3)) + + for i, m := range metrics { + Expect(m.DeviceUUID).To(Equal(deviceUUIDs[i])) + Expect(m.ExtraMetrics).NotTo(BeEmpty()) + } + }) + + It("should correctly convert PCIe bytes to KB", func() { + metrics, err := accel.GetDeviceMetrics([]string{"test-device"}) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(1)) + + // Rx and Tx should be in KB (bytes / 1024) + Expect(metrics[0].Rx).To(BeNumerically(">", 0)) + Expect(metrics[0].Tx).To(BeNumerically(">", 0)) + }) + + It("should calculate memory percentage correctly", func() { + metrics, err := accel.GetDeviceMetrics([]string{"test-device"}) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(1)) + + m := metrics[0] + // Memory percentage should be between 0 and 100 + Expect(m.MemoryPercentage).To(BeNumerically(">=", 0)) + Expect(m.MemoryPercentage).To(BeNumerically("<=", 100)) + }) + }) + + Describe("GetAllDevices", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should retrieve device list", func() { + devices, err := accel.GetAllDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).NotTo(BeNil()) + + // Stub may return 0 or more devices + if len(devices) > 0 { + for _, d := range devices { + Expect(d.UUID).NotTo(BeEmpty()) + Expect(d.TotalMemoryBytes).To(BeNumerically(">", 0)) + } + } + }) + }) + + Describe("Process Utilization", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should return empty slices when no processes tracked", func() { + // Test both compute and memory utilization + computeUtil, err := accel.GetProcessComputeUtilization() + Expect(err).NotTo(HaveOccurred()) + Expect(computeUtil).To(BeEmpty()) + + memoryUtil, err := accel.GetProcessMemoryUtilization() + Expect(err).NotTo(HaveOccurred()) + Expect(memoryUtil).To(BeEmpty()) + + // Verify GetTotalProcessCount returns 0 + Expect(accel.GetTotalProcessCount()).To(Equal(0)) + }) + }) + + Describe("GetVendorMountLibs", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should retrieve mount libs", func() { + mounts, err := accel.GetVendorMountLibs() + Expect(err).NotTo(HaveOccurred()) + Expect(mounts).NotTo(BeNil()) + // Stub may return empty or populated mounts + }) + }) + + Describe("Memory Management", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should not leak memory on repeated GetDeviceMetrics calls", func() { + deviceUUIDs := []string{"device-1", "device-2"} + + // Call multiple times to check for memory leaks + for i := 0; i < 10; i++ { + metrics, err := accel.GetDeviceMetrics(deviceUUIDs) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(2)) + } + }) + + It("should handle large number of devices (up to limit)", func() { + // Create 64 device UUIDs (maxStackDevices limit) + deviceUUIDs := make([]string, 64) + for i := range deviceUUIDs { + deviceUUIDs[i] = fmt.Sprintf("device-%d", i) + } + + metrics, err := accel.GetDeviceMetrics(deviceUUIDs) + Expect(err).NotTo(HaveOccurred()) + 
Expect(metrics).To(HaveLen(64)) + }) + }) + + Describe("Edge Cases", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should handle various device UUID formats", func() { + // Test different UUID formats that might be encountered + uuidVariants := []string{ + "device-1", + "device-2_@#$", + "device-3-中文", + "12345678-1234-1234-1234-123456789abc", // UUID format + } + metrics, err := accel.GetDeviceMetrics(uuidVariants) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(len(uuidVariants))) + }) + + It("should handle empty strings in device UUIDs", func() { + metrics, err := accel.GetDeviceMetrics([]string{""}) + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(1)) + }) + }) + + Describe("AssignPartition", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should assign partition successfully", func() { + partitionUUID, err := accel.AssignPartition("mig-1g.7gb", "stub-device-0") + Expect(err).NotTo(HaveOccurred()) + Expect(partitionUUID).NotTo(BeEmpty()) + }) + + It("should reject template ID that is too long", func() { + longTemplateID := make([]byte, 100) + for i := range longTemplateID { + longTemplateID[i] = 'a' + } + _, err := accel.AssignPartition(string(longTemplateID), "stub-device-0") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too long")) + }) + + It("should reject device UUID that is too long", func() { + longDeviceUUID := make([]byte, 100) + for i := range longDeviceUUID { + longDeviceUUID[i] = 'a' + } + _, err := accel.AssignPartition("mig-1g.7gb", string(longDeviceUUID)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too long")) + }) + }) + + Describe("RemovePartition", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should remove partition successfully", func() { + err := accel.RemovePartition("partition-123", "stub-device-0") + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Describe("SetLimits", func() { + BeforeEach(func() { + var err error + accel, err = NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should set memory hard limit successfully", func() { + err := accel.SetMemHardLimit("worker-1", "stub-device-0", 1024*1024*1024) // 1GB + Expect(err).NotTo(HaveOccurred()) + }) + + It("should set compute unit hard limit successfully", func() { + err := accel.SetComputeUnitHardLimit("worker-1", "stub-device-0", 50) // 50% + Expect(err).NotTo(HaveOccurred()) + }) + }) +}) diff --git a/internal/hypervisor/device/controller.go b/internal/hypervisor/device/controller.go new file mode 100644 index 00000000..a7df706b --- /dev/null +++ b/internal/hypervisor/device/controller.go @@ -0,0 +1,358 @@ +package device + +import ( + "context" + "fmt" + "maps" + "os" + "strings" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/metrics" + "github.com/samber/lo" + "k8s.io/apimachinery/pkg/api/equality" + "k8s.io/klog/v2" +) + +var tmpDir = os.TempDir() + +// Controller manages GPU device discovery, allocation, and lifecycle +type Controller struct { + ctx context.Context + mu sync.RWMutex + devices map[string]*api.DeviceInfo 
// key: device UUID + deviceAllocations map[string][]*api.WorkerAllocation + + accelerator *AcceleratorInterface + acceleratorVendor string + discoveryInterval time.Duration + deviceUpdateHandlers []framework.DeviceChangeHandler + isolationMode string +} + +var _ framework.DeviceController = &Controller{} + +// NewController creates a new device manager +func NewController(ctx context.Context, acceleratorLibPath string, acceleratorVendor string, discoveryInterval time.Duration, isolationMode string) (framework.DeviceController, error) { + accel, err := NewAcceleratorInterface(acceleratorLibPath) + if err != nil { + return nil, fmt.Errorf("failed to create accelerator interface: %w", err) + } + return &Controller{ + ctx: ctx, + devices: make(map[string]*api.DeviceInfo), + deviceAllocations: make(map[string][]*api.WorkerAllocation, 32), + accelerator: accel, + acceleratorVendor: acceleratorVendor, + discoveryInterval: discoveryInterval, + deviceUpdateHandlers: make([]framework.DeviceChangeHandler, 2), + isolationMode: isolationMode, + }, nil +} + +// DiscoverDevices discovers all available GPU devices +func (m *Controller) StartDiscoverDevices() error { + // Initial device discovery + if err := m.discoverDevices(); err != nil { + return fmt.Errorf("initial device discovery failed: %w", err) + } + + go m.periodicDiscovery() + return nil +} + +// discoverDevices discovers all available GPU devices +func (m *Controller) discoverDevices() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Get all devices at once + devices, err := m.accelerator.GetAllDevices() + if err != nil { + return fmt.Errorf("failed to get all devices: %w", err) + } + + // Build a map of newly fetched devices by UUID + newDevicesMap := make(map[string]*api.DeviceInfo, len(devices)) + for _, device := range devices { + // Convert UUID to lowercase for case-insensitive comparison + // Kubernetes resource name has to be lowercase + device.UUID = strings.ToLower(device.UUID) + newDevicesMap[device.UUID] = device + } + + // Diff logic: compare new devices with existing devices (K8s reconcile pattern) + // First, identify all changes without modifying state + var addedDevices []*api.DeviceInfo + var removedDevices []*api.DeviceInfo + var updatedDevices []struct { + old *api.DeviceInfo + new *api.DeviceInfo + } + + // Find added devices (in new but not in old) + for uuid, newDevice := range newDevicesMap { + if _, exists := m.devices[uuid]; !exists { + addedDevices = append(addedDevices, newDevice) + } + } + + // Find removed devices (in old but not in new) + for uuid, oldDevice := range m.devices { + if _, exists := newDevicesMap[uuid]; !exists { + removedDevices = append(removedDevices, oldDevice) + } + } + + // Find updated devices (in both but changed) + for uuid, newDevice := range newDevicesMap { + if oldDevice, exists := m.devices[uuid]; exists { + // Check if device has changed + if !equality.Semantic.DeepEqual(oldDevice, newDevice) { + updatedDevices = append(updatedDevices, struct { + old *api.DeviceInfo + new *api.DeviceInfo + }{old: oldDevice, new: newDevice}) + } + } + } + + // Notify handlers for all changes (similar to K8s reconcile) + for _, device := range addedDevices { + m.notifyHandlers(func(handler framework.DeviceChangeHandler) { + if handler.OnAdd != nil { + handler.OnAdd(device) + } + }) + klog.V(4).Infof("Device added: %s (UUID: %s)", device.Model, device.UUID) + } + + for _, device := range removedDevices { + m.notifyHandlers(func(handler framework.DeviceChangeHandler) { + if handler.OnRemove != nil { + 
handler.OnRemove(device) + } + }) + klog.V(4).Infof("Device removed: %s (UUID: %s)", device.Model, device.UUID) + } + + for _, update := range updatedDevices { + m.notifyHandlers(func(handler framework.DeviceChangeHandler) { + if handler.OnUpdate != nil { + handler.OnUpdate(update.old, update.new) + } + }) + klog.V(4).Infof("Device updated: %s (UUID: %s)", update.new.Model, update.new.UUID) + } + + // Update state after notifying handlers + for _, device := range addedDevices { + m.devices[device.UUID] = device + } + for _, device := range removedDevices { + delete(m.devices, device.UUID) + } + for _, update := range updatedDevices { + m.devices[update.new.UUID] = update.new + } + + nodeInfo := m.AggregateNodeInfo() + + if metrics.ShouldSendTelemetry() { + sampleGPUModel := "" + if len(m.devices) > 0 { + for _, device := range m.devices { + if device.Model != "" { + sampleGPUModel = device.Model + break + } + } + } + workersCount := 0 + for _, allocations := range m.deviceAllocations { + workersCount += len(allocations) + } + + go metrics.SendAnonymousTelemetry( + nodeInfo, m.acceleratorVendor, sampleGPUModel, workersCount, m.isolationMode, + ) + } + m.notifyHandlers(func(handler framework.DeviceChangeHandler) { + if handler.OnDiscoveryComplete != nil { + handler.OnDiscoveryComplete(nodeInfo) + } + }) + return nil +} + +// notifyHandlers calls the provided function for each registered handler +func (m *Controller) notifyHandlers(fn func(framework.DeviceChangeHandler)) { + for _, handler := range m.deviceUpdateHandlers { + fn(handler) + } +} + +// periodicDiscovery periodically discovers devices +func (m *Controller) periodicDiscovery() { + ticker := time.NewTicker(m.discoveryInterval) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + if err := m.discoverDevices(); err != nil { + // Log error but continue + continue + } + } + } +} + +// GetDevices returns all discovered devices +func (m *Controller) GetDevices() []*api.DeviceInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + devices := make([]*api.DeviceInfo, 0, len(m.devices)) + for _, device := range m.devices { + devices = append(devices, device) + } + return devices +} + +// Start implements framework.DeviceController +func (m *Controller) Start() error { + // Start device discovery + return m.StartDiscoverDevices() +} + +func (m *Controller) Stop() error { + return m.accelerator.Close() +} + +// DiscoverDevices implements framework.DeviceController +func (m *Controller) DiscoverDevices() error { + return m.discoverDevices() +} + +// ListDevices implements framework.DeviceController +func (m *Controller) ListDevices() ([]*api.DeviceInfo, error) { + return m.GetDevices(), nil +} + +// GetDevice implements framework.DeviceController +func (m *Controller) GetDevice(deviceUUID string) (*api.DeviceInfo, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + device, exists := m.devices[deviceUUID] + return device, exists +} + +// GetDeviceMetrics implements framework.DeviceController +func (m *Controller) GetDeviceMetrics() (map[string]*api.GPUUsageMetrics, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]*api.GPUUsageMetrics, len(m.devices)) + metrics, err := m.accelerator.GetDeviceMetrics(lo.Keys(m.devices)) + if err != nil { + return nil, fmt.Errorf("failed to get device metrics: %w", err) + } + for _, metric := range metrics { + result[metric.DeviceUUID] = metric + } + return result, nil +} + +func (m *Controller) GetVendorMountLibs() ([]*api.Mount, error) { + return 
m.accelerator.GetVendorMountLibs()
+}
+
+func (m *Controller) SplitDevice(partitionTemplateID string, deviceUUID string) (*api.DeviceInfo, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	existingDevice, exists := m.devices[deviceUUID]
+	if !exists {
+		return nil, fmt.Errorf("device %s not found, cannot partition", deviceUUID)
+	}
+	// Copy the parent device only after the existence check so a missing device cannot cause a nil dereference
+	newPartitionedDevice := *existingDevice
+	partitionUUID, err := m.accelerator.AssignPartition(partitionTemplateID, deviceUUID)
+	if err != nil {
+		return nil, err
+	}
+	newPartitionedDevice.ParentUUID = newPartitionedDevice.UUID
+	newPartitionedDevice.UUID = partitionUUID
+	m.devices[partitionUUID] = &newPartitionedDevice
+	return &newPartitionedDevice, nil
+}
+
+func (m *Controller) RemovePartitionedDevice(partitionUUID, deviceUUID string) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	_, exists := m.devices[partitionUUID]
+	if !exists {
+		return fmt.Errorf("partition %s not found, cannot remove", partitionUUID)
+	}
+
+	err := m.accelerator.RemovePartition(partitionUUID, deviceUUID)
+	if err != nil {
+		return err
+	}
+	klog.Infof("removed partition %s from device %s", partitionUUID, deviceUUID)
+	delete(m.devices, partitionUUID)
+	return nil
+}
+
+func (m *Controller) RegisterDeviceUpdateHandler(handler framework.DeviceChangeHandler) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.deviceUpdateHandlers = append(m.deviceUpdateHandlers, handler)
+}
+
+func (m *Controller) GetAcceleratorVendor() string {
+	return m.acceleratorVendor
+}
+
+func (m *Controller) AggregateNodeInfo() *api.NodeInfo {
+	info := &api.NodeInfo{
+		RAMSizeBytes:  GetTotalHostRAMBytes(),
+		DataDiskBytes: GetDiskInfo(tmpDir),
+	}
+	for _, device := range m.devices {
+		info.TotalTFlops += device.MaxTflops
+		info.TotalVRAMBytes += int64(device.TotalMemoryBytes)
+		info.DeviceIDs = append(info.DeviceIDs, device.UUID)
+	}
+	return info
+}
+
+func (m *Controller) GetDeviceAllocations() map[string][]*api.WorkerAllocation {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return maps.Clone(m.deviceAllocations)
+}
+
+func (m *Controller) AddDeviceAllocation(deviceUUID string, allocation *api.WorkerAllocation) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if _, exists := m.deviceAllocations[deviceUUID]; !exists {
+		m.deviceAllocations[deviceUUID] = make([]*api.WorkerAllocation, 0, 8)
+	}
+	m.deviceAllocations[deviceUUID] = append(m.deviceAllocations[deviceUUID], allocation)
+}
+
+func (m *Controller) RemoveDeviceAllocation(deviceUUID string, allocation *api.WorkerAllocation) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if _, exists := m.deviceAllocations[deviceUUID]; !exists {
+		return
+	}
+	m.deviceAllocations[deviceUUID] = lo.Filter(m.deviceAllocations[deviceUUID], func(wa *api.WorkerAllocation, _ int) bool {
+		return wa.WorkerInfo.WorkerUID != allocation.WorkerInfo.WorkerUID
+	})
+}
diff --git a/internal/hypervisor/device/host_discovery.go b/internal/hypervisor/device/host_discovery.go
new file mode 100644
index 00000000..7c7f730f
--- /dev/null
+++ b/internal/hypervisor/device/host_discovery.go
@@ -0,0 +1,51 @@
+package device
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"syscall"
+
+	"github.com/shirou/gopsutil/mem"
+)
+
+func GetTotalHostRAMBytes() int64 {
+	v, err := mem.VirtualMemory()
+	if err != nil {
+		fmt.Printf("[warning] getting memory info failed: %v\n", err)
+		return 0
+	}
+	return int64(v.Total)
+}
+
+func GetDiskInfo(path string) (total int64) {
+	absPath, err := filepath.Abs(path)
+	if err != nil {
+		fmt.Printf("[warning] getting disk path failed: %v\n", err)
+		return 0
+	}
+
+	var stat
syscall.Statfs_t + err = syscall.Statfs(absPath, &stat) + if err != nil { + if errors.Is(err, syscall.ENOENT) { + err = os.MkdirAll(absPath, 0o755) + if err != nil { + fmt.Printf("[warning] creating folder to discover disk space failed: %s, err: %v\n", absPath, err) + return 0 + } + err = syscall.Statfs(absPath, &stat) + if err != nil { + fmt.Printf("[warning] getting disk stats after creation failed: %v\n", err) + return 0 + } + } else { + fmt.Printf("[warning] getting disk stats failed: %v\n", err) + return 0 + } + } + + total = int64(stat.Blocks * uint64(stat.Bsize)) + return total +} diff --git a/internal/hypervisor/device/provider_log.go b/internal/hypervisor/device/provider_log.go new file mode 100644 index 00000000..aa425e78 --- /dev/null +++ b/internal/hypervisor/device/provider_log.go @@ -0,0 +1,56 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package device + +/* +#cgo CFLAGS: -I../../../provider +#include +*/ +import "C" +import ( + "k8s.io/klog/v2" +) + +// GoLog is exported to C code via //export directive +// This function is called by C code (wrapper.c) to log messages using klog +// +//export GoLog +func GoLog(level *C.char, message *C.char) { + if level == nil || message == nil { + return + } + + levelStr := C.GoString(level) + messageStr := C.GoString(message) + + // Map C log levels to klog levels + switch levelStr { + case "DEBUG", "debug": + klog.V(4).Info(messageStr) + case "INFO", "info": + klog.Info(messageStr) + case "WARN", "warn", "WARNING", "warning": + klog.Warning(messageStr) + case "ERROR", "error": + klog.Error(messageStr) + case "FATAL", "fatal": + klog.Fatal(messageStr) + default: + // Default to Info level for unknown levels + klog.Info(messageStr) + } +} diff --git a/internal/hypervisor/device/wrapper.c b/internal/hypervisor/device/wrapper.c new file mode 100644 index 00000000..791fdbda --- /dev/null +++ b/internal/hypervisor/device/wrapper.c @@ -0,0 +1,227 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../../../provider/accelerator.h" +#include +#include +#include +#include +#include +#include +#include + +// Forward declaration of Go Log function +extern void GoLog(const char* level, const char* message); + +// Function pointer types for dynamic loading +typedef Result (*GetDeviceCountFunc)(size_t*); +typedef Result (*GetAllDevicesFunc)(ExtendedDeviceInfo*, size_t, size_t*); +typedef Result (*GetPartitionTemplatesFunc)(int32_t, PartitionTemplate*, size_t, size_t*); +typedef bool (*AssignPartitionFunc)(PartitionAssignment*); +typedef bool (*RemovePartitionFunc)(const char*, const char*); +typedef Result (*SetMemHardLimitFunc)(const char*, const char*, uint64_t); +typedef Result (*SetComputeUnitHardLimitFunc)(const char*, const char*, uint32_t); +typedef Result (*GetProcessComputeUtilizationFunc)(ComputeUtilization*, size_t, size_t*); +typedef Result (*GetProcessMemoryUtilizationFunc)(MemoryUtilization*, size_t, size_t*); +typedef Result (*GetDeviceMetricsFunc)(const char**, size_t, DeviceMetrics*, size_t); +typedef Result (*GetVendorMountLibsFunc)(Mount*, size_t, size_t*); +typedef Result (*LogFunc)(const char*, const char*); + +// Global handle for the loaded library +static void* libHandle = NULL; + +// Function pointers +static GetDeviceCountFunc getDeviceCountFunc = NULL; +static GetAllDevicesFunc getAllDevicesFunc = NULL; +static GetPartitionTemplatesFunc getPartitionTemplatesFunc = NULL; +static AssignPartitionFunc assignPartitionFunc = NULL; +static RemovePartitionFunc removePartitionFunc = NULL; +static SetMemHardLimitFunc setMemHardLimitFunc = NULL; +static SetComputeUnitHardLimitFunc setComputeUnitHardLimitFunc = NULL; +static GetProcessComputeUtilizationFunc getProcessComputeUtilizationFunc = NULL; +static GetProcessMemoryUtilizationFunc getProcessMemoryUtilizationFunc = NULL; +static GetDeviceMetricsFunc getDeviceMetricsFunc = NULL; +static GetVendorMountLibsFunc getVendorMountLibsFunc = NULL; +static LogFunc logFunc = NULL; + +// Load library dynamically +int loadAcceleratorLibrary(const char* libPath) { + if (libHandle != NULL) { + dlclose(libHandle); + } + + libHandle = dlopen(libPath, RTLD_LAZY | RTLD_LOCAL); + if (libHandle == NULL) { + return -1; // Failed to load + } + + // Load function symbols + getDeviceCountFunc = (GetDeviceCountFunc)dlsym(libHandle, "GetDeviceCount"); + getAllDevicesFunc = (GetAllDevicesFunc)dlsym(libHandle, "GetAllDevices"); + getPartitionTemplatesFunc = (GetPartitionTemplatesFunc)dlsym(libHandle, "GetPartitionTemplates"); + assignPartitionFunc = (AssignPartitionFunc)dlsym(libHandle, "AssignPartition"); + removePartitionFunc = (RemovePartitionFunc)dlsym(libHandle, "RemovePartition"); + setMemHardLimitFunc = (SetMemHardLimitFunc)dlsym(libHandle, "SetMemHardLimit"); + setComputeUnitHardLimitFunc = (SetComputeUnitHardLimitFunc)dlsym(libHandle, "SetComputeUnitHardLimit"); + getProcessComputeUtilizationFunc = (GetProcessComputeUtilizationFunc)dlsym(libHandle, "GetProcessComputeUtilization"); + getProcessMemoryUtilizationFunc = (GetProcessMemoryUtilizationFunc)dlsym(libHandle, "GetProcessMemoryUtilization"); + getDeviceMetricsFunc = (GetDeviceMetricsFunc)dlsym(libHandle, "GetDeviceMetrics"); + getVendorMountLibsFunc = (GetVendorMountLibsFunc)dlsym(libHandle, "GetVendorMountLibs"); + logFunc = (LogFunc)dlsym(libHandle, "Log"); + + // Check if all required functions are loaded (Log is optional) + if (!getDeviceCountFunc || !getAllDevicesFunc || !getPartitionTemplatesFunc || + !assignPartitionFunc || !removePartitionFunc || 
!setMemHardLimitFunc || + !setComputeUnitHardLimitFunc || !getProcessComputeUtilizationFunc || + !getProcessMemoryUtilizationFunc || !getDeviceMetricsFunc || !getVendorMountLibsFunc) { + dlclose(libHandle); + libHandle = NULL; + return -2; // Missing symbols + } + + // If the library has a Log function, we can't directly replace it, + // but we provide our own Log function that the library can use. + // The library's internal Log calls will use its own implementation, + // but if the library is designed to call Log via function pointer or + // if it doesn't have its own Log, it will use our implementation. + + return 0; // Success +} + +// Unload library +void unloadAcceleratorLibrary(void) { + if (libHandle != NULL) { + dlclose(libHandle); + libHandle = NULL; + getDeviceCountFunc = NULL; + getAllDevicesFunc = NULL; + getPartitionTemplatesFunc = NULL; + assignPartitionFunc = NULL; + removePartitionFunc = NULL; + setMemHardLimitFunc = NULL; + setComputeUnitHardLimitFunc = NULL; + getProcessComputeUtilizationFunc = NULL; + getProcessMemoryUtilizationFunc = NULL; + getDeviceMetricsFunc = NULL; + getVendorMountLibsFunc = NULL; + logFunc = NULL; + } +} + +// Wrapper functions that call the dynamically loaded functions +Result GetDeviceCountWrapper(size_t* deviceCount) { + if (getDeviceCountFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getDeviceCountFunc(deviceCount); +} + +Result GetAllDevicesWrapper(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount) { + if (getAllDevicesFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getAllDevicesFunc(devices, maxCount, deviceCount); +} + +Result GetPartitionTemplatesWrapper(int32_t deviceIndex, PartitionTemplate* templates, size_t maxCount, size_t* templateCount) { + if (getPartitionTemplatesFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getPartitionTemplatesFunc(deviceIndex, templates, maxCount, templateCount); +} + +bool AssignPartitionWrapper(PartitionAssignment* assignment) { + if (assignPartitionFunc == NULL) { + return false; + } + return assignPartitionFunc(assignment); +} + +bool RemovePartitionWrapper(const char* templateId, const char* deviceUUID) { + if (removePartitionFunc == NULL) { + return false; + } + return removePartitionFunc(templateId, deviceUUID); +} + +Result SetMemHardLimitWrapper(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes) { + if (setMemHardLimitFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return setMemHardLimitFunc(workerId, deviceUUID, memoryLimitBytes); +} + +Result SetComputeUnitHardLimitWrapper(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit) { + if (setComputeUnitHardLimitFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return setComputeUnitHardLimitFunc(workerId, deviceUUID, computeUnitLimit); +} + +Result GetProcessComputeUtilizationWrapper(ComputeUtilization* utilizations, size_t maxCount, size_t* utilizationCount) { + if (getProcessComputeUtilizationFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getProcessComputeUtilizationFunc(utilizations, maxCount, utilizationCount); +} + +Result GetProcessMemoryUtilizationWrapper(MemoryUtilization* utilizations, size_t maxCount, size_t* utilizationCount) { + if (getProcessMemoryUtilizationFunc == NULL) { + return RESULT_ERROR_INTERNAL; + } + return getProcessMemoryUtilizationFunc(utilizations, maxCount, utilizationCount); +} + +Result GetDeviceMetricsWrapper(const char** deviceUUIDArray, size_t deviceCount, DeviceMetrics* metrics, size_t 
maxExtraMetricsPerDevice) {
+    if (getDeviceMetricsFunc == NULL) {
+        return RESULT_ERROR_INTERNAL;
+    }
+    return getDeviceMetricsFunc(deviceUUIDArray, deviceCount, metrics, maxExtraMetricsPerDevice);
+}
+
+Result GetVendorMountLibsWrapper(Mount* mounts, size_t maxCount, size_t* mountCount) {
+    if (getVendorMountLibsFunc == NULL) {
+        return RESULT_ERROR_INTERNAL;
+    }
+    return getVendorMountLibsFunc(mounts, maxCount, mountCount);
+}
+
+// Get error message from dlopen
+const char* getDlError(void) {
+    return dlerror();
+}
+
+// Log wrapper that calls Go's Log function
+// This function provides a Log implementation that the dynamically loaded library can use
+// When the library calls Log(), it will call this function which forwards to Go's klog
+Result LogWrapper(const char* level, const char* message) {
+    if (level == NULL || message == NULL) {
+        return RESULT_ERROR_INVALID_PARAM;
+    }
+
+    // Call Go's Log function
+    GoLog(level, message);
+
+    return RESULT_SUCCESS;
+}
+
+// Provide a Log function that can be called by the dynamically loaded library
+// This is the Log function that accelerator.h defines - we provide an implementation
+// that forwards to Go's klog via GoLog
+Result Log(const char* level, const char* message) {
+    return LogWrapper(level, message);
+}
+
diff --git a/internal/hypervisor/framework/framework.go b/internal/hypervisor/framework/framework.go
new file mode 100644
index 00000000..fa0a3e75
--- /dev/null
+++ b/internal/hypervisor/framework/framework.go
@@ -0,0 +1,109 @@
+package framework
+
+import (
+	"github.com/NexusGPU/tensor-fusion/internal/hypervisor/api"
+)
+
+type DeviceController interface {
+	Start() error
+
+	Stop() error
+
+	DiscoverDevices() error
+
+	ListDevices() ([]*api.DeviceInfo, error)
+
+	GetDevice(deviceUUID string) (*api.DeviceInfo, bool)
+
+	SplitDevice(partitionTemplateID string, deviceUUID string) (*api.DeviceInfo, error)
+
+	RemovePartitionedDevice(partitionUUID, deviceUUID string) error
+
+	GetDeviceMetrics() (map[string]*api.GPUUsageMetrics, error)
+
+	GetVendorMountLibs() ([]*api.Mount, error)
+
+	RegisterDeviceUpdateHandler(handler DeviceChangeHandler)
+
+	GetAcceleratorVendor() string
+
+	GetDeviceAllocations() map[string][]*api.WorkerAllocation
+
+	AddDeviceAllocation(deviceUUID string, allocation *api.WorkerAllocation)
+
+	RemoveDeviceAllocation(deviceUUID string, allocation *api.WorkerAllocation)
+}
+
+type WorkerController interface {
+	Start() error
+
+	Stop() error
+
+	AllocateWorkerDevices(request *api.WorkerInfo) (*api.WorkerAllocation, error)
+
+	DeallocateWorker(workerUID string) error
+
+	ListWorkers() ([]*api.WorkerInfo, error)
+
+	GetWorkerAllocation(workerUID string) (*api.WorkerAllocation, bool)
+
+	// GetWorkerMetrics returns current worker metrics for all workers
+	// Returns a map keyed by device UUID, then by worker UID, then by process ID
+	GetWorkerMetrics() (map[string]map[string]map[string]*api.WorkerMetrics, error)
+}
+
+type QuotaController interface {
+	// SetQuota sets quota for a worker
+	SetQuota(workerUID string) error
+
+	StartSoftQuotaLimiter() error
+
+	StopSoftQuotaLimiter() error
+
+	// GetWorkerQuotaStatus gets quota status for a worker
+	GetWorkerQuotaStatus(workerUID string) error
+}
+
+// Backend is the interface the hypervisor uses to interact with the underlying infrastructure
+type Backend interface {
+	Start() error
+
+	Stop() error
+
+	// RegisterWorkerUpdateHandler registers a handler for worker updates
+	// The handler will be called for all existing workers (OnAdd) and all future worker changes (add, update, remove)
+	RegisterWorkerUpdateHandler(handler WorkerChangeHandler) error
+
+	// StartWorker spawns the worker process
+	StartWorker(worker *api.WorkerInfo) error
+
+	// StopWorker stops the worker process
+	StopWorker(workerUID string) error
+
+	// GetProcessMappingInfo gets process mapping information for a worker
+	GetProcessMappingInfo(workerUID string, hostPID uint32) (*ProcessMappingInfo, error)
+
+	GetDeviceChangeHandler() DeviceChangeHandler
+
+	ListWorkers() []*api.WorkerInfo
+}
+
+// ProcessMappingInfo contains the worker ID and the host/guest process IDs extracted for a worker process
+type ProcessMappingInfo struct {
+	GuestID  string
+	HostPID  uint32
+	GuestPID uint32
+}
+
+type DeviceChangeHandler struct {
+	OnAdd               func(device *api.DeviceInfo)
+	OnRemove            func(device *api.DeviceInfo)
+	OnUpdate            func(oldDevice, newDevice *api.DeviceInfo)
+	OnDiscoveryComplete func(nodeInfo *api.NodeInfo)
+}
+
+type WorkerChangeHandler struct {
+	OnAdd    func(worker *api.WorkerInfo)
+	OnRemove func(worker *api.WorkerInfo)
+	OnUpdate func(oldWorker, newWorker *api.WorkerInfo)
+}
diff --git a/internal/hypervisor/hypervisor_suite_test.go b/internal/hypervisor/hypervisor_suite_test.go
new file mode 100644
index 00000000..31b9efe4
--- /dev/null
+++ b/internal/hypervisor/hypervisor_suite_test.go
@@ -0,0 +1,552 @@
+/*
+Copyright 2024.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package hypervisor
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	. "github.com/onsi/ginkgo/v2"
+	. 
"github.com/onsi/gomega" + "k8s.io/apimachinery/pkg/api/resource" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/backend/single_node" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/metrics" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/server" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker" +) + +func TestHypervisor(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Hypervisor Suite") +} + +var _ = Describe("Hypervisor Integration Tests", func() { + var ( + ctx context.Context + cancel context.CancelFunc + deviceController framework.DeviceController + backend framework.Backend + workerController framework.WorkerController + metricsRecorder *metrics.HypervisorMetricsRecorder + httpServer *server.Server + stubLibPath string + tempMetricsFile string + ) + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + + // Find stub library path + // Try relative path first (from provider/build) + stubLibPath = filepath.Join("..", "..", "provider", "build", "libaccelerator_stub.so") + if _, err := os.Stat(stubLibPath); os.IsNotExist(err) { + // Try absolute path from workspace root + workspaceRoot := os.Getenv("WORKSPACE_ROOT") + if workspaceRoot == "" { + // Try to find it relative to current directory + cwd, _ := os.Getwd() + stubLibPath = filepath.Join(cwd, "..", "..", "provider", "build", "libaccelerator_stub.so") + } else { + stubLibPath = filepath.Join(workspaceRoot, "provider", "build", "libaccelerator_stub.so") + } + } + + // Create temp file for metrics + tempFile, err := os.CreateTemp("", "hypervisor-metrics-*.log") + Expect(err).NotTo(HaveOccurred()) + tempMetricsFile = tempFile.Name() + _ = tempFile.Close() + }) + + AfterEach(func() { + if cancel != nil { + cancel() + } + if httpServer != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer shutdownCancel() + _ = httpServer.Stop(shutdownCtx) + } + if workerController != nil { + _ = workerController.Stop() + } + if backend != nil { + _ = backend.Stop() + } + if deviceController != nil { + if closer, ok := deviceController.(interface{ Close() error }); ok { + _ = closer.Close() + } + } + _ = os.Remove(tempMetricsFile) + }) + + Context("With stub device library", func() { + BeforeEach(func() { + // Check if stub library exists, skip if not + if _, err := os.Stat(stubLibPath); os.IsNotExist(err) { + Skip("Stub library not found. 
Run 'make stub' in provider directory first.") + } + + var err error + deviceController, err = device.NewController(ctx, stubLibPath, "stub", 1*time.Hour, tfv1.IsolationModeShared) + Expect(err).NotTo(HaveOccurred()) + Expect(deviceController).NotTo(BeNil()) + + backend = single_node.NewSingleNodeBackend(ctx, deviceController) + Expect(backend).NotTo(BeNil()) + + workerController = worker.NewWorkerController(deviceController, tfv1.IsolationModeShared, backend) + Expect(workerController).NotTo(BeNil()) + + metricsRecorder = metrics.NewHypervisorMetricsRecorder(ctx, tempMetricsFile, deviceController, workerController) + Expect(metricsRecorder).NotTo(BeNil()) + + httpServer = server.NewServer(ctx, deviceController, workerController, metricsRecorder, backend, 0) + Expect(httpServer).NotTo(BeNil()) + }) + + Describe("C Stub Library Integration", func() { + It("should load stub accelerator library", func() { + // Verify library can be loaded + accel, err := device.NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + Expect(accel).NotTo(BeNil()) + + // Test device discovery through C library + devices, err := accel.GetAllDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + // Verify stub device properties + device := devices[0] + Expect(device.UUID).To(ContainSubstring("stub-device")) + Expect(device.Vendor).To(Equal("STUB")) + Expect(device.TotalMemoryBytes).To(Equal(uint64(16 * 1024 * 1024 * 1024))) // 16GB + + _ = accel.Close() + }) + + It("should get process utilization from stub library", func() { + accel, err := device.NewAcceleratorInterface(stubLibPath) + Expect(err).NotTo(HaveOccurred()) + defer func() { + _ = accel.Close() + }() + + // Get compute utilization (may be empty for stub) + computeUtils, err := accel.GetProcessComputeUtilization() + Expect(err).NotTo(HaveOccurred()) + Expect(computeUtils).NotTo(BeNil()) + + // Get memory utilization (may be empty for stub) + memUtils, err := accel.GetProcessMemoryUtilization() + Expect(err).NotTo(HaveOccurred()) + Expect(memUtils).NotTo(BeNil()) + }) + }) + + Describe("Device Controller", func() { + It("should start and discover devices", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + // Wait a bit for discovery + time.Sleep(100 * time.Millisecond) + + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty(), "Should discover at least one stub device") + + // Verify device properties + device := devices[0] + Expect(device.UUID).NotTo(BeEmpty()) + Expect(device.Vendor).To(Equal("STUB")) + Expect(device.TotalMemoryBytes).To(BeNumerically(">", 0)) + }) + + It("should allocate devices", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + time.Sleep(100 * time.Millisecond) + + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + deviceUUID := devices[0].UUID + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{deviceUUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + + resp, err := workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + // TODO verify the mounts/envs + + // Verify allocation exists through worker controller + allocation, found := workerController.GetWorkerAllocation("test-worker-1") + Expect(found).To(BeTrue()) + Expect(allocation).NotTo(BeNil()) + 
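+				// Descriptive note (editor-added comment): the tracked allocation should point back to the
+				// worker UID that requested it; the mounts/envs returned in resp are still covered only by the
+				// TODO above, since their exact field names are not asserted here.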
Expect(allocation.WorkerInfo.WorkerUID).To(Equal("test-worker-1")) + }) + + It("should get GPU metrics", func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + + time.Sleep(100 * time.Millisecond) + + metrics, err := deviceController.GetDeviceMetrics() + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).NotTo(BeNil()) + + // Should have metrics for all discovered devices + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(metrics).To(HaveLen(len(devices))) + }) + }) + + Describe("Single Node Backend", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = backend.Start() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should start and stop", func() { + Expect(backend).NotTo(BeNil()) + }) + + It("should list workers from allocations", func() { + // Create an allocation + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{devices[0].UUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + _, err = workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + + // Start the worker in the backend + err = backend.StartWorker(req) + Expect(err).NotTo(HaveOccurred()) + + // Wait a bit for state to sync + time.Sleep(500 * time.Millisecond) + + // Register a handler to receive updates and track initial workers + var found bool + handler := framework.WorkerChangeHandler{ + OnAdd: func(worker *api.WorkerInfo) { + if worker.WorkerUID == "test-worker-1" { + found = true + } + }, + OnRemove: func(worker *api.WorkerInfo) {}, + OnUpdate: func(oldWorker, newWorker *api.WorkerInfo) {}, + } + err = backend.RegisterWorkerUpdateHandler(handler) + Expect(err).NotTo(HaveOccurred()) + + // Wait a bit for OnAdd callbacks to be invoked + time.Sleep(100 * time.Millisecond) + Expect(found).To(BeTrue(), "Should find test-worker-1 via OnAdd callback") + }) + + It("should track worker to process mapping", func() { + // Start a worker + worker := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{}, + IsolationMode: tfv1.IsolationModeSoft, + } + err := backend.StartWorker(worker) + Expect(err).NotTo(HaveOccurred()) + + // Test process mapping + processInfo, err := backend.GetProcessMappingInfo("test-worker-1", 12345) + Expect(err).NotTo(HaveOccurred()) + Expect(processInfo).NotTo(BeNil()) + Expect(processInfo.GuestID).To(Equal("test-worker-1")) + Expect(processInfo.HostPID).To(Equal(uint32(12345))) + }) + }) + + Describe("Worker Controller", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should start and stop", func() { + Expect(workerController).NotTo(BeNil()) + }) + + It("should list workers", func() { + // Create an allocation + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{devices[0].UUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + _, err = workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + + workers, err := workerController.ListWorkers() + Expect(err).NotTo(HaveOccurred()) + found := false + for _, 
worker := range workers { + if worker.WorkerUID == "test-worker-1" { + found = true + break + } + } + Expect(found).To(BeTrue()) + }) + + It("should get worker allocation", func() { + // Create an allocation + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{devices[0].UUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + _, err = workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + + allocation, found := workerController.GetWorkerAllocation("test-worker-1") + Expect(found).To(BeTrue()) + Expect(allocation).NotTo(BeNil()) + Expect(allocation.WorkerInfo.WorkerUID).To(Equal("test-worker-1")) + }) + + It("should get worker metrics", func() { + // Create an allocation + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + + req := &api.WorkerInfo{ + WorkerUID: "test-worker-1", + AllocatedDevices: []string{devices[0].UUID}, + IsolationMode: tfv1.IsolationModeSoft, + } + _, err = workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + + metrics, err := workerController.GetWorkerMetrics() + Expect(err).NotTo(HaveOccurred()) + // Metrics may be empty for stub devices, which is okay + // Just verify we got a valid response (nil or empty map is acceptable) + _ = metrics + }) + }) + + Describe("Metrics Recorder", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + + metricsRecorder.Start() + }) + + It("should record metrics", func() { + // Wait for metrics to be recorded + time.Sleep(2 * time.Second) + + // Check if metrics file was created and has content + info, err := os.Stat(tempMetricsFile) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Size()).To(BeNumerically(">=", 0)) + }) + }) + + Describe("HTTP Server", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + + metricsRecorder.Start() + }) + + It("should start HTTP server", func() { + // Start server in background + go func() { + err := httpServer.Start() + Expect(err).To(Or(BeNil(), MatchError("http: Server closed"))) + }() + + // Wait for server to start + time.Sleep(500 * time.Millisecond) + + // Server should be running (we can't easily test HTTP endpoints without knowing the port) + // But we can verify the server object is created + Expect(httpServer).NotTo(BeNil()) + }) + }) + + Describe("Full Integration", func() { + BeforeEach(func() { + err := deviceController.Start() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + + err = backend.Start() + Expect(err).NotTo(HaveOccurred()) + + err = workerController.Start() + Expect(err).NotTo(HaveOccurred()) + + metricsRecorder.Start() + + // Start HTTP server in background + go func() { + _ = httpServer.Start() + }() + time.Sleep(500 * time.Millisecond) + }) + + It("should handle complete workflow: discover -> allocate -> track -> metrics", func() { + // 1. Discover devices + devices, err := deviceController.ListDevices() + Expect(err).NotTo(HaveOccurred()) + Expect(devices).ToNot(BeEmpty()) + deviceUUID := devices[0].UUID + + // 2. 
Allocate device + req := &api.WorkerInfo{ + WorkerUID: "integration-worker-1", + AllocatedDevices: []string{deviceUUID}, + IsolationMode: tfv1.IsolationModeSoft, + Requests: tfv1.Resource{ + Tflops: resource.MustParse("1000"), + Vram: resource.MustParse("1Gi"), + }, + } + resp, err := workerController.AllocateWorkerDevices(req) + Expect(err).NotTo(HaveOccurred()) + Expect(resp).To(Not(BeNil())) + + // Start worker in backend + err = backend.StartWorker(req) + Expect(err).NotTo(HaveOccurred()) + + // 3. Verify allocation through worker controller + allocation, found := workerController.GetWorkerAllocation("integration-worker-1") + Expect(found).To(BeTrue()) + Expect(allocation).NotTo(BeNil()) + Expect(allocation.WorkerInfo.WorkerUID).To(Equal("integration-worker-1")) + + // 4. Backend should list worker + time.Sleep(500 * time.Millisecond) + // Register a handler to receive updates and track initial workers + var foundInList bool + handler := framework.WorkerChangeHandler{ + OnAdd: func(worker *api.WorkerInfo) { + if worker.WorkerUID == "integration-worker-1" { + foundInList = true + } + }, + OnRemove: func(worker *api.WorkerInfo) {}, + OnUpdate: func(oldWorker, newWorker *api.WorkerInfo) {}, + } + err = backend.RegisterWorkerUpdateHandler(handler) + Expect(err).NotTo(HaveOccurred()) + + // Wait a bit for OnAdd callbacks to be invoked + time.Sleep(100 * time.Millisecond) + Expect(foundInList).To(BeTrue(), "Should find integration-worker-1 via OnAdd callback") + + // 5. Worker controller should list worker + workerList, err := workerController.ListWorkers() + Expect(err).NotTo(HaveOccurred()) + foundInWorkerList := false + for _, worker := range workerList { + if worker.WorkerUID == "integration-worker-1" { + foundInWorkerList = true + break + } + } + Expect(foundInWorkerList).To(BeTrue()) + + // 6. Get worker allocation + allocation, found = workerController.GetWorkerAllocation("integration-worker-1") + Expect(found).To(BeTrue()) + Expect(allocation).NotTo(BeNil()) + Expect(allocation.WorkerInfo.WorkerUID).To(Equal("integration-worker-1")) + + // 7. Get metrics + gpuMetrics, err := deviceController.GetDeviceMetrics() + Expect(err).NotTo(HaveOccurred()) + Expect(gpuMetrics).NotTo(BeNil()) + Expect(gpuMetrics[deviceUUID]).NotTo(BeNil()) + + workerMetrics, err := workerController.GetWorkerMetrics() + Expect(err).NotTo(HaveOccurred()) + Expect(workerMetrics).NotTo(BeNil()) + + // 8. Deallocate worker + err = workerController.DeallocateWorker("integration-worker-1") + Expect(err).NotTo(HaveOccurred()) + + // 9. 
Verify deallocation + _, found = workerController.GetWorkerAllocation("integration-worker-1") + Expect(found).To(BeFalse()) + }) + }) + }) +}) diff --git a/internal/hypervisor/metrics/metrics.go b/internal/hypervisor/metrics/metrics.go new file mode 100644 index 00000000..b2b2f685 --- /dev/null +++ b/internal/hypervisor/metrics/metrics.go @@ -0,0 +1,394 @@ +package metrics + +import ( + "context" + "encoding/json" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/metrics" + "github.com/NexusGPU/tensor-fusion/internal/utils" + "github.com/NexusGPU/tensor-fusion/internal/version" + "github.com/posthog/posthog-go" + "golang.org/x/sys/unix" + "gopkg.in/natefinch/lumberjack.v2" + "k8s.io/klog/v2" +) + +type HypervisorMetricsRecorder struct { + ctx context.Context + outputPath string + nodeName string + gpuPool string + deviceController framework.DeviceController + workerController framework.WorkerController + extraLabelsMap map[string]string // podLabelKey -> tagName mapping from env config +} + +const ( + defaultNodeName = "unknown" + defaultGPUPool = "unknown" +) + +var ( + startTime = time.Now() + telemetryClient posthog.Client + telemetryClientMu sync.Once + telemetryLockMu sync.Mutex + telemetryMinInterval = 24 * time.Hour +) + +func NewHypervisorMetricsRecorder( + ctx context.Context, outputPath string, + deviceController framework.DeviceController, + workerController framework.WorkerController, +) *HypervisorMetricsRecorder { + nodeName := os.Getenv(constants.HypervisorGPUNodeNameEnv) + if nodeName == "" { + nodeName = defaultNodeName + } + gpuPool := os.Getenv(constants.HypervisorPoolNameEnv) + if gpuPool == "" { + gpuPool = defaultGPUPool + } + + // Parse extra labels config once at initialization + extraLabelsMap := make(map[string]string) + extraLabelsConfig := os.Getenv(constants.HypervisorMetricsExtraLabelsEnv) + if extraLabelsConfig != "" { + if err := json.Unmarshal([]byte(extraLabelsConfig), &extraLabelsMap); err != nil { + // Log error but continue without extra labels + extraLabelsMap = make(map[string]string) + } + } + + return &HypervisorMetricsRecorder{ + ctx: ctx, + outputPath: outputPath, + nodeName: nodeName, + gpuPool: gpuPool, + deviceController: deviceController, + workerController: workerController, + extraLabelsMap: extraLabelsMap, + } +} + +func (h *HypervisorMetricsRecorder) Start() { + writer := &lumberjack.Logger{ + Filename: h.outputPath, + MaxSize: 100, + MaxBackups: 10, + MaxAge: 14, + } + + // Record device and worker metrics + deviceMetricsTicker := time.NewTicker(10 * time.Second) + go func() { + for { + select { + case <-h.ctx.Done(): + return + case <-deviceMetricsTicker.C: + h.RecordDeviceMetrics(writer) + h.RecordWorkerMetrics(writer) + } + } + }() +} + +func (h *HypervisorMetricsRecorder) RecordDeviceMetrics(writer io.Writer) { + gpuMetrics, err := h.deviceController.GetDeviceMetrics() + if err != nil { + return + } + + // Output GPU metrics directly + now := time.Now() + enc := metrics.NewEncoder(os.Getenv(constants.HypervisorMetricsFormatEnv)) + + for gpuUUID, metrics := range gpuMetrics { + enc.StartLine("tf_gpu_usage") + enc.AddTag("uuid", gpuUUID) + enc.AddTag("node", h.nodeName) + enc.AddTag("pool", h.gpuPool) + + enc.AddField("rx", metrics.Rx) + enc.AddField("tx", metrics.Tx) + 
enc.AddField("temperature", metrics.Temperature) + enc.AddField("memory_bytes", int64(metrics.MemoryBytes)) + enc.AddField("memory_percentage", metrics.MemoryPercentage) + enc.AddField("compute_percentage", metrics.ComputePercentage) + enc.AddField("compute_tflops", metrics.ComputeTflops) + enc.AddField("power_usage", float64(metrics.PowerUsage)) + if metrics.ExtraMetrics != nil { + for key, value := range metrics.ExtraMetrics { + enc.AddField(key, value) + } + } + enc.EndLine(now) + } + + if err := enc.Err(); err == nil { + _, _ = writer.Write(enc.Bytes()) + } +} + +func (h *HypervisorMetricsRecorder) RecordWorkerMetrics(writer io.Writer) { + workerMetrics, err := h.workerController.GetWorkerMetrics() + if err != nil { + return + } + + workers, err := h.workerController.ListWorkers() + if err != nil { + return + } + + // Get worker allocations for metadata + workerAllocations := make(map[string]*api.WorkerAllocation) + for _, worker := range workers { + allocation, found := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if found && allocation != nil { + workerAllocations[worker.WorkerUID] = allocation + } + } + + // Output worker metrics directly + now := time.Now() + enc := metrics.NewEncoder(os.Getenv(constants.HypervisorMetricsFormatEnv)) + + for deviceUUID, workerMap := range workerMetrics { + for workerUID, processMap := range workerMap { + allocation, ok := workerAllocations[workerUID] + if !ok { + continue + } + + var memoryBytes uint64 + var computePercentage float64 + var computeTflops float64 + var memoryPercentage float64 + + // Sum up metrics from all processes for this worker + for _, metrics := range processMap { + memoryBytes += metrics.MemoryBytes + computePercentage += metrics.ComputePercentage + computeTflops += metrics.ComputeTflops + + // Calculate memory percentage + vramLimit := float64(0) + if allocation.WorkerInfo != nil { + vramLimit = float64(allocation.WorkerInfo.Limits.Vram.Value()) + } + if vramLimit > 0 { + memoryPercentage += float64(metrics.MemoryBytes) / vramLimit * 100.0 + } + } + + enc.StartLine("tf_worker_usage") + enc.AddTag("uuid", deviceUUID) + enc.AddTag("node", h.nodeName) + enc.AddTag("pool", h.gpuPool) + if allocation.WorkerInfo != nil { + enc.AddTag("pod_name", allocation.WorkerInfo.WorkerName) + enc.AddTag("namespace", allocation.WorkerInfo.Namespace) + } + + workloadName := "unknown" + // Try to get workload name from worker ID or pod name + if allocation.WorkerInfo != nil && allocation.WorkerInfo.WorkerUID != "" { + workloadName = allocation.WorkerInfo.WorkerUID + } + enc.AddTag("workload", workloadName) + enc.AddTag("worker", workerUID) + + // Add extra labels if configured + h.addExtraLabels(enc, allocation) + + enc.AddField("memory_bytes", int64(memoryBytes)) + enc.AddField("compute_percentage", computePercentage) + enc.AddField("compute_tflops", computeTflops) + enc.AddField("memory_percentage", memoryPercentage) + + enc.EndLine(now) + } + } + + if err := enc.Err(); err == nil { + _, _ = writer.Write(enc.Bytes()) + } +} + +// addExtraLabels adds dynamic tags based on HypervisorMetricsExtraLabelsEnv configuration +// The config is a JSON map where keys are tag names and values are pod label keys to extract +// Labels are read directly from allocation.Labels which is populated by the backend +func (h *HypervisorMetricsRecorder) addExtraLabels(enc metrics.Encoder, allocation *api.WorkerAllocation) { + if len(h.extraLabelsMap) == 0 { + return + } + + if allocation.WorkerInfo == nil || len(allocation.WorkerInfo.Annotations) == 0 { 
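+		// No worker metadata or pod annotations available to map, so there are no extra tags to add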
+ return + } + + // Add tags based on the mapping + for podLabelKey, tagName := range h.extraLabelsMap { + if labelValue, exists := allocation.WorkerInfo.Annotations[podLabelKey]; exists && labelValue != "" { + enc.AddTag(tagName, labelValue) + } + } +} + +// TelemetryConfig contains optional telemetry parameters +type TelemetryConfig struct { + WorkersCount int + IsolationMode string + SampleGPUModel string + DeviceController framework.DeviceController +} + +// getPostHogClient initializes and returns the PostHog client (singleton) +func getPostHogClient() posthog.Client { + telemetryClientMu.Do(func() { + endpoint := os.Getenv(constants.TelemetryEndpointEnvVar) + if endpoint == "" { + endpoint = constants.DefaultTelemetryEndpoint + } + + pubKey := os.Getenv(constants.TelemetryPublicKeyEnvVar) + if pubKey == "" { + pubKey = constants.DefaultTelemetryPublicKey + } + + client, err := posthog.NewWithConfig(pubKey, posthog.Config{ + Endpoint: endpoint, + }) + if err != nil { + klog.V(4).Infof("Failed to initialize PostHog client: %v", err) + return + } + telemetryClient = client + }) + return telemetryClient +} + +// fileLock and fileUnlock use flock for file locking on Unix-like systems +func fileLock(fd uintptr) error { + return unix.Flock(int(fd), unix.LOCK_EX|unix.LOCK_NB) +} + +func fileUnlock(fd uintptr) error { + return unix.Flock(int(fd), unix.LOCK_UN) +} + +func ShouldSendTelemetry() bool { + if os.Getenv("DISABLE_TENSOR_FUSION_TELEMETRY") != "" { + return false + } + if utils.IsTestMode { + return false + } + + telemetryLockMu.Lock() + defer telemetryLockMu.Unlock() + + // Try to open or create the lock file + telemetryLockFile := filepath.Join(os.TempDir(), "tensor-fusion-telemetry.lock") + file, err := os.OpenFile(telemetryLockFile, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + klog.V(4).Infof("Failed to open telemetry lock file: %v", err) + return false + } + defer func() { + if err := file.Close(); err != nil { + klog.V(4).Infof("Failed to close telemetry lock file: %v", err) + } + }() + + // Try to acquire an exclusive lock (non-blocking) + err = fileLock(file.Fd()) + if err != nil { + klog.V(4).Infof("Failed to acquire telemetry lock: %v", err) + // Lock is already held by another process + return false + } + defer func() { + if err := fileUnlock(file.Fd()); err != nil { + klog.V(4).Infof("Failed to release telemetry lock: %v", err) + } + }() + + // Read and parse the timestamp from the file + var lastSentTime time.Time + if data, err := io.ReadAll(file); err == nil { + if timestamp, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64); err == nil { + lastSentTime = time.Unix(timestamp, 0) + } + } + if !lastSentTime.IsZero() && time.Since(lastSentTime) < telemetryMinInterval { + return false + } + + // Write current timestamp to the file + now := time.Now() + timestampStr := strconv.FormatInt(now.Unix(), 10) + if _, err := file.Seek(0, 0); err != nil { + klog.V(4).Infof("Failed to seek telemetry lock file: %v", err) + return false + } + if err := file.Truncate(0); err != nil { + klog.V(4).Infof("Failed to truncate telemetry lock file: %v", err) + return false + } + if _, err := file.WriteString(timestampStr); err != nil { + klog.V(4).Infof("Failed to write telemetry lock file: %v", err) + return false + } + if err := file.Sync(); err != nil { + klog.V(4).Infof("Failed to sync telemetry lock file: %v", err) + return false + } + return true +} + +// SendAnonymousTelemetry sends Anonymous telemetry data without ANY sensitive data +func 
SendAnonymousTelemetry(nodeInfo *api.NodeInfo, hardwareVendor string, sampleGPUModel string, workersCount int, isolationMode string) { + // Get PostHog client + client := getPostHogClient() + if client == nil { + klog.V(4).Infof("PostHog client not available, skipping telemetry") + return + } + + // Prepare event properties + properties := posthog.NewProperties(). + Set("ramSizeBytes", nodeInfo.RAMSizeBytes). + Set("totalTFlops", nodeInfo.TotalTFlops). + Set("totalVRAMBytes", nodeInfo.TotalVRAMBytes). + Set("totalDevices", len(nodeInfo.DeviceIDs)). + Set("brand", constants.Domain). + Set("version", version.BuildVersion). + Set("uptime", time.Since(startTime).String()). + Set("workersCount", workersCount). + Set("isolationMode", isolationMode). + Set("vendor", hardwareVendor). + Set("sampleGPUModel", sampleGPUModel) + + // Send event to PostHog + err := client.Enqueue(posthog.Capture{ + Event: "hypervisor_telemetry", + Properties: properties, + }) + if err != nil { + klog.V(4).Infof("Failed to send telemetry: %v", err) + return + } +} diff --git a/internal/hypervisor/server/handlers/device.go b/internal/hypervisor/server/handlers/device.go new file mode 100644 index 00000000..d417fe45 --- /dev/null +++ b/internal/hypervisor/server/handlers/device.go @@ -0,0 +1,67 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// DeviceHandler handles device-related endpoints +type DeviceHandler struct { + deviceController framework.DeviceController +} + +// NewDeviceHandler creates a new device handler +func NewDeviceHandler(deviceController framework.DeviceController) *DeviceHandler { + return &DeviceHandler{ + deviceController: deviceController, + } +} + +// HandleGetDevices handles GET /api/v1/devices +func (h *DeviceHandler) HandleGetDevices(c *gin.Context) { + devices, err := h.deviceController.ListDevices() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + c.JSON(http.StatusOK, api.DataResponse[[]*api.DeviceInfo]{Data: devices}) +} + +// HandleGetDevice handles GET /api/v1/devices/:uuid +func (h *DeviceHandler) HandleGetDevice(c *gin.Context) { + uuid := c.Param("uuid") + device, exists := h.deviceController.GetDevice(uuid) + if !exists { + c.JSON(http.StatusNotFound, api.ErrorResponse{Error: "Device not found"}) + return + } + c.JSON(http.StatusOK, api.DataResponse[*api.DeviceInfo]{Data: device}) +} + +// HandleDiscoverDevices handles POST /api/v1/devices/discover +func (h *DeviceHandler) HandleDiscoverDevices(c *gin.Context) { + if err := h.deviceController.DiscoverDevices(); err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + c.JSON(http.StatusOK, api.StatusResponse{Status: "Device discovery triggered"}) +} diff --git a/internal/hypervisor/server/handlers/health.go b/internal/hypervisor/server/handlers/health.go new file mode 100644 index 00000000..2ccd1167 --- /dev/null +++ b/internal/hypervisor/server/handlers/health.go @@ -0,0 +1,47 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// HealthHandler handles health check endpoints +type HealthHandler struct{} + +// NewHealthHandler creates a new health handler +func NewHealthHandler() *HealthHandler { + return &HealthHandler{} +} + +// HandleHealthz handles GET /healthz +func (h *HealthHandler) HandleHealthz(c *gin.Context) { + c.JSON(http.StatusOK, api.StatusResponse{Status: "ok"}) +} + +// HandleReadyz handles GET /readyz +func (h *HealthHandler) HandleReadyz(c *gin.Context, deviceController framework.DeviceController, workerController framework.WorkerController) { + if deviceController == nil || workerController == nil { + c.JSON(http.StatusServiceUnavailable, api.StatusResponse{Status: "not ready"}) + return + } + c.JSON(http.StatusOK, api.StatusResponse{Status: "ready"}) +} diff --git a/internal/hypervisor/server/handlers/legacy.go b/internal/hypervisor/server/handlers/legacy.go new file mode 100644 index 00000000..e024bf1e --- /dev/null +++ b/internal/hypervisor/server/handlers/legacy.go @@ -0,0 +1,166 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "net/http" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" + "k8s.io/utils/ptr" +) + +// LegacyHandler handles legacy endpoints +type LegacyHandler struct { + workerController framework.WorkerController + backend framework.Backend +} + +// NewLegacyHandler creates a new legacy handler +func NewLegacyHandler(workerController framework.WorkerController, backend framework.Backend) *LegacyHandler { + return &LegacyHandler{ + workerController: workerController, + backend: backend, + } +} + +// HandleGetLimiter handles GET /api/v1/limiter +func (h *LegacyHandler) HandleGetLimiter(c *gin.Context) { + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + limiterInfos := make([]api.LimiterInfo, 0, len(workers)) + for _, worker := range workers { + allocation, exists := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if !exists || allocation == nil { + continue + } + + var requests, limits *tfv1.Resource + if allocation.WorkerInfo != nil { + requests = &allocation.WorkerInfo.Requests + limits = &allocation.WorkerInfo.Limits + } + + limiterInfos = append(limiterInfos, api.LimiterInfo{ + WorkerUID: worker.WorkerUID, + Requests: requests, + Limits: limits, + }) + } + + c.JSON(http.StatusOK, api.ListLimitersResponse{Limiters: limiterInfos}) +} + +// HandleTrap handles POST /api/v1/trap +func (h *LegacyHandler) HandleTrap(c *gin.Context) { + // Trap endpoint: start snapshot low QoS workers to release VRAM + workers, err := h.workerController.ListWorkers() 
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()})
+		return
+	}
+
+	snapshotCount := 0
+	for _, worker := range workers {
+		allocation, exists := h.workerController.GetWorkerAllocation(worker.WorkerUID)
+		if !exists || allocation == nil {
+			continue
+		}
+
+		// TODO: Check QoS level and snapshot low QoS workers
+		// For now, snapshot all workers (this should be filtered by QoS)
+		snapshotCount++
+	}
+
+	c.JSON(http.StatusOK, api.TrapResponse{
+		Message:       "trap initiated",
+		SnapshotCount: snapshotCount,
+	})
+}
+
+// HandleGetPods handles GET /api/v1/pod
+func (h *LegacyHandler) HandleGetPods(c *gin.Context) {
+	// Only available when k8s backend is enabled
+	if h.backend == nil {
+		c.JSON(http.StatusServiceUnavailable, api.ErrorResponse{Error: "kubernetes backend not enabled"})
+		return
+	}
+
+	workers, err := h.workerController.ListWorkers()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()})
+		return
+	}
+
+	pods := make([]api.PodInfo, 0)
+	for _, worker := range workers {
+		allocation, exists := h.workerController.GetWorkerAllocation(worker.WorkerUID)
+		if !exists || allocation == nil {
+			continue
+		}
+
+		var vramLimit *uint64
+		var tflopsLimit *float64
+		if allocation.WorkerInfo != nil {
+			if allocation.WorkerInfo.Limits.Vram.Value() > 0 {
+				vramLimit = ptr.To(uint64(allocation.WorkerInfo.Limits.Vram.Value()))
+			}
+			if allocation.WorkerInfo.Limits.Tflops.Value() > 0 {
+				tflopsLimit = ptr.To(allocation.WorkerInfo.Limits.Tflops.AsApproximateFloat64())
+			}
+		}
+		// Guard the QoS lookup as well, consistent with the nil checks above
+		pod := api.PodInfo{
+			PodName:     getAllocationPodName(allocation),
+			Namespace:   getAllocationNamespace(allocation),
+			GPUIDs:      getDeviceUUIDs(allocation),
+			TflopsLimit: tflopsLimit,
+			VramLimit:   vramLimit,
+		}
+		if allocation.WorkerInfo != nil {
+			pod.QoSLevel = allocation.WorkerInfo.QoS
+		}
+		pods = append(pods, pod)
+	}
+
+	c.JSON(http.StatusOK, api.ListPodsResponse{Pods: pods})
+}
+
+// Helper functions for WorkerAllocation field access
+func getAllocationPodName(allocation *api.WorkerAllocation) string {
+	if allocation.WorkerInfo != nil {
+		return allocation.WorkerInfo.WorkerName
+	}
+	return ""
+}
+
+func getAllocationNamespace(allocation *api.WorkerAllocation) string {
+	if allocation.WorkerInfo != nil {
+		return allocation.WorkerInfo.Namespace
+	}
+	return ""
+}
+
+func getDeviceUUIDs(allocation *api.WorkerAllocation) []string {
+	uuids := make([]string, 0, len(allocation.DeviceInfos))
+	for _, device := range allocation.DeviceInfos {
+		uuids = append(uuids, device.UUID)
+	}
+	return uuids
+}
diff --git a/internal/hypervisor/server/handlers/worker.go b/internal/hypervisor/server/handlers/worker.go
new file mode 100644
index 00000000..6e72051e
--- /dev/null
+++ b/internal/hypervisor/server/handlers/worker.go
@@ -0,0 +1,109 @@
+/*
+Copyright 2024.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package handlers + +import ( + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/gin-gonic/gin" +) + +// WorkerHandler handles worker-related endpoints +type WorkerHandler struct { + workerController framework.WorkerController +} + +// NewWorkerHandler creates a new worker handler +func NewWorkerHandler(workerController framework.WorkerController) *WorkerHandler { + return &WorkerHandler{ + workerController: workerController, + } +} + +// HandleGetWorkers handles GET /api/v1/workers +func (h *WorkerHandler) HandleGetWorkers(c *gin.Context) { + workers, err := h.workerController.ListWorkers() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + // Get worker details + workerDetails := make([]*api.WorkerAllocation, 0, len(workers)) + for _, worker := range workers { + allocation, exists := h.workerController.GetWorkerAllocation(worker.WorkerUID) + if !exists || allocation == nil { + continue + } + workerDetails = append(workerDetails, allocation) + } + + c.JSON(http.StatusOK, api.DataResponse[[]*api.WorkerAllocation]{Data: workerDetails}) +} + +// HandleGetWorker handles GET /api/v1/workers/:id +func (h *WorkerHandler) HandleGetWorker(c *gin.Context) { + workerID := c.Param("id") + allocation, exists := h.workerController.GetWorkerAllocation(workerID) + if !exists || allocation == nil { + c.JSON(http.StatusNotFound, api.ErrorResponse{Error: "worker not found"}) + return + } + + // Get worker metrics + workerMetrics, err := h.workerController.GetWorkerMetrics() + if err != nil { + c.JSON(http.StatusInternalServerError, api.ErrorResponse{Error: err.Error()}) + return + } + + metrics, exists := workerMetrics[workerID] + if !exists || metrics == nil { + c.JSON(http.StatusOK, api.DataResponse[map[string]any]{ + Data: map[string]any{ + "worker_uid": workerID, + "allocation": allocation, + }, + }) + return + } + // TODO +} + +// HandleSnapshotWorker handles POST /api/v1/workers/:id/snapshot +func (h *WorkerHandler) HandleSnapshotWorker(c *gin.Context) { + workerID := c.Param("id") + // TODO: Implement actual snapshot logic using accelerator interface + // For now, return success + c.JSON(http.StatusOK, api.MessageAndDataResponse[string]{ + Message: "worker snapshot initiated", + Data: workerID, + }) +} + +// HandleResumeWorker handles POST /api/v1/workers/:id/resume +func (h *WorkerHandler) HandleResumeWorker(c *gin.Context) { + workerID := c.Param("id") + // TODO: Implement actual resume logic using accelerator interface + // For now, return success + c.JSON(http.StatusOK, api.MessageAndDataResponse[string]{ + Message: "worker resume initiated", + Data: workerID, + }) +} diff --git a/internal/hypervisor/server/server.go b/internal/hypervisor/server/server.go new file mode 100644 index 00000000..61cea575 --- /dev/null +++ b/internal/hypervisor/server/server.go @@ -0,0 +1,131 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package server + +import ( + "context" + "fmt" + "net/http" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/server/handlers" + "github.com/gin-gonic/gin" + "k8s.io/klog/v2" +) + +// MetricsRecorder interface for metrics +type MetricsRecorder interface { + Start() +} + +// Server represents the hypervisor HTTP server +type Server struct { + deviceController framework.DeviceController + workerController framework.WorkerController + metricsRecorder MetricsRecorder + backend framework.Backend + ctx context.Context + router *gin.Engine + httpServer *http.Server + + // Handlers + healthHandler *handlers.HealthHandler + deviceHandler *handlers.DeviceHandler + workerHandler *handlers.WorkerHandler + legacyHandler *handlers.LegacyHandler +} + +// NewServer creates a new hypervisor HTTP server +func NewServer( + ctx context.Context, + deviceController framework.DeviceController, + workerController framework.WorkerController, + metricsRecorder MetricsRecorder, + backend framework.Backend, + port int, +) *Server { + gin.SetMode(gin.ReleaseMode) + router := gin.New() + router.Use(gin.Logger(), gin.Recovery()) + + // Initialize handlers + healthHandler := handlers.NewHealthHandler() + deviceHandler := handlers.NewDeviceHandler(deviceController) + workerHandler := handlers.NewWorkerHandler(workerController) + legacyHandler := handlers.NewLegacyHandler(workerController, backend) + + s := &Server{ + deviceController: deviceController, + workerController: workerController, + metricsRecorder: metricsRecorder, + backend: backend, + ctx: ctx, + router: router, + httpServer: &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: router, + }, + healthHandler: healthHandler, + deviceHandler: deviceHandler, + workerHandler: workerHandler, + legacyHandler: legacyHandler, + } + + s.setupRoutes() + return s +} + +func (s *Server) setupRoutes() { + // Health check routes + s.router.GET("/healthz", s.healthHandler.HandleHealthz) + s.router.GET("/readyz", func(c *gin.Context) { + s.healthHandler.HandleReadyz(c, s.deviceController, s.workerController) + }) + + // RESTful API routes + // TODO: add authentication and authorization for worker APIs + apiV1 := s.router.Group("/api/v1") + { + // Device routes + apiV1.GET("/devices", s.deviceHandler.HandleGetDevices) + apiV1.GET("/devices/:uuid", s.deviceHandler.HandleGetDevice) + apiV1.POST("/devices/discover", s.deviceHandler.HandleDiscoverDevices) + + // Worker routes + apiV1.GET("/workers", s.workerHandler.HandleGetWorkers) + apiV1.GET("/workers/:id", s.workerHandler.HandleGetWorker) + apiV1.POST("/workers/:id/snapshot", s.workerHandler.HandleSnapshotWorker) + apiV1.POST("/workers/:id/resume", s.workerHandler.HandleResumeWorker) + + // Legacy routes + apiV1.GET("/limiter", s.legacyHandler.HandleGetLimiter) + apiV1.POST("/trap", s.legacyHandler.HandleTrap) + apiV1.GET("/pod", s.legacyHandler.HandleGetPods) + // TODO: should eliminate this API from limiter: apiV1.GET("/process", s.legacyHandler.HandleGetProcesses) + } +} + +// Start starts the HTTP server +func (s *Server) Start() error { + klog.Infof("Starting hypervisor HTTP server on %s", s.httpServer.Addr) + return s.httpServer.ListenAndServe() +} + +// Stop stops the HTTP server +func (s *Server) Stop(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} diff --git a/internal/hypervisor/tui/chart.go b/internal/hypervisor/tui/chart.go new file mode 100644 index 00000000..ed5f1fb4 --- /dev/null +++ 
b/internal/hypervisor/tui/chart.go @@ -0,0 +1,219 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "fmt" + "strings" +) + +const ( + maxHistorySize = 60 // Keep 60 data points for ~2 minutes at 2s intervals +) + +// TimeSeriesChart represents a time-series chart for metrics +type TimeSeriesChart struct { + data []float64 + width int + height int + maxValue float64 + minValue float64 + label string +} + +// NewTimeSeriesChart creates a new time-series chart +func NewTimeSeriesChart(width, height int, label string) *TimeSeriesChart { + return &TimeSeriesChart{ + data: make([]float64, 0, maxHistorySize), + width: width, + height: height, + maxValue: 100.0, // Default max for percentages + minValue: 0.0, + label: label, + } +} + +// AddDataPoint adds a new data point to the chart +func (c *TimeSeriesChart) AddDataPoint(value float64) { + c.data = append(c.data, value) + if len(c.data) > maxHistorySize { + c.data = c.data[1:] // Remove oldest point + } + + // Auto-scale max value + if value > c.maxValue { + c.maxValue = value * 1.1 // Add 10% padding + } + if value < c.minValue { + c.minValue = value + } +} + +// SetMaxValue sets the maximum value for the chart scale +func (c *TimeSeriesChart) SetMaxValue(max float64) { + c.maxValue = max +} + +// SetDimensions sets the width and height of the chart +func (c *TimeSeriesChart) SetDimensions(width, height int) { + c.width = width + c.height = height +} + +// Render renders the time-series chart as a string +// +//nolint:gocyclo // Complex rendering logic with multiple conditional branches +func (c *TimeSeriesChart) Render() string { + if len(c.data) == 0 { + return fmt.Sprintf("%s: No data\n", c.label) + } + + var result strings.Builder + result.WriteString(fmt.Sprintf("%s (max: %.1f)\n", c.label, c.maxValue)) + + if c.height < 2 { + // Single line mode - just show current value + lastValue := c.data[len(c.data)-1] + result.WriteString(renderBarChart(lastValue, c.width)) + return result.String() + } + + // Multi-line chart + chartHeight := c.height - 1 // Reserve one line for label + if chartHeight < 1 { + chartHeight = 1 + } + + // Create a grid for the chart + grid := make([][]rune, chartHeight) + for i := range grid { + grid[i] = make([]rune, c.width) + for j := range grid[i] { + grid[i][j] = ' ' + } + } + + // Handle edge case: maxValue == minValue + valueRange := c.maxValue - c.minValue + if valueRange == 0 { + valueRange = 1.0 // Avoid division by zero + } + + // Draw the data + dataLen := len(c.data) + if dataLen > c.width { + // Downsample if we have more data points than width + step := float64(dataLen) / float64(c.width) + for x := 0; x < c.width; x++ { + idx := int(float64(x) * step) + if idx >= dataLen { + idx = dataLen - 1 + } + value := c.data[idx] + y := int((c.maxValue - value) / valueRange * float64(chartHeight-1)) + if y < 0 { + y = 0 + } + if y >= chartHeight { + y = chartHeight - 1 + } + grid[y][x] = '█' + + // Draw line connecting to previous point + if x > 0 { + prevIdx := int(float64(x-1) * 
step) + if prevIdx >= dataLen { + prevIdx = dataLen - 1 + } + prevValue := c.data[prevIdx] + prevY := int((c.maxValue - prevValue) / valueRange * float64(chartHeight-1)) + if prevY < 0 { + prevY = 0 + } + if prevY >= chartHeight { + prevY = chartHeight - 1 + } + + // Draw connecting line + startY, endY := prevY, y + if startY > endY { + startY, endY = endY, startY + } + for lineY := startY; lineY <= endY; lineY++ { + if lineY < chartHeight { + if grid[lineY][x] == ' ' { + grid[lineY][x] = '│' + } + } + } + } + } + } else { + // Draw all data points + for x, value := range c.data { + if x >= c.width { + break + } + y := int((c.maxValue - value) / valueRange * float64(chartHeight-1)) + if y < 0 { + y = 0 + } + if y >= chartHeight { + y = chartHeight - 1 + } + grid[y][x] = '█' + + // Draw connecting line + if x > 0 { + prevValue := c.data[x-1] + prevY := int((c.maxValue - prevValue) / valueRange * float64(chartHeight-1)) + if prevY < 0 { + prevY = 0 + } + if prevY >= chartHeight { + prevY = chartHeight - 1 + } + + startY, endY := prevY, y + if startY > endY { + startY, endY = endY, startY + } + for lineY := startY; lineY <= endY; lineY++ { + if lineY < chartHeight { + if grid[lineY][x] == ' ' { + grid[lineY][x] = '│' + } + } + } + } + } + } + + // Render the grid + for _, row := range grid { + result.WriteString(ChartBarStyle.Render(string(row))) + result.WriteString("\n") + } + + // Add current value + if len(c.data) > 0 { + lastValue := c.data[len(c.data)-1] + result.WriteString(fmt.Sprintf("Current: %.1f", lastValue)) + } + + return result.String() +} diff --git a/internal/hypervisor/tui/client.go b/internal/hypervisor/tui/client.go new file mode 100644 index 00000000..a6368118 --- /dev/null +++ b/internal/hypervisor/tui/client.go @@ -0,0 +1,181 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" +) + +// Client is an HTTP client for fetching data from the hypervisor server +type Client struct { + baseURL string + httpClient *http.Client +} + +// NewClient creates a new HTTP client for the hypervisor +func NewClient(host string, port int) *Client { + return &Client{ + baseURL: fmt.Sprintf("http://%s:%d/api/v1", host, port), + httpClient: &http.Client{ + Timeout: 5 * time.Second, + }, + } +} + +// doRequest performs an HTTP request and decodes the JSON response +// +//nolint:unparam // method parameter is kept for API consistency, even though it's always "GET" +func (c *Client) doRequest(ctx context.Context, method, path string, result any) error { + url := fmt.Sprintf("%s/%s", c.baseURL, path) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("execute request: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + if err := json.NewDecoder(resp.Body).Decode(result); err != nil { + return fmt.Errorf("decode response: %w", err) + } + + return nil +} + +// ListDevices fetches all devices from the hypervisor +func (c *Client) ListDevices(ctx context.Context) ([]*api.DeviceInfo, error) { + var result api.DataResponse[[]*api.DeviceInfo] + if err := c.doRequest(ctx, "GET", "devices", &result); err != nil { + return nil, fmt.Errorf("list devices: %w", err) + } + return result.Data, nil +} + +// GetDevice fetches a specific device by UUID +func (c *Client) GetDevice(ctx context.Context, uuid string) (*api.DeviceInfo, error) { + var result api.DataResponse[*api.DeviceInfo] + if err := c.doRequest(ctx, "GET", fmt.Sprintf("devices/%s", uuid), &result); err != nil { + return nil, fmt.Errorf("get device %s: %w", uuid, err) + } + return result.Data, nil +} + +// GetDeviceAllocations fetches allocations for a specific device +func (c *Client) GetDeviceAllocations(ctx context.Context, uuid string) ([]*api.WorkerAllocation, error) { + workers, err := c.ListWorkers(ctx) + if err != nil { + return nil, fmt.Errorf("list workers: %w", err) + } + + allocations := make([]*api.WorkerAllocation, 0) + for _, worker := range workers { + // Check if any device in the allocation matches the UUID + for _, device := range worker.DeviceInfos { + if device.UUID == uuid { + allocations = append(allocations, worker) + break + } + } + } + + return allocations, nil +} + +// GetGPUMetrics fetches GPU metrics for all devices +// Note: This is a placeholder until a dedicated metrics endpoint is available +func (c *Client) GetGPUMetrics(ctx context.Context) (map[string]*api.GPUUsageMetrics, error) { + // TODO: Implement when metrics endpoint is available + // For now, return empty metrics to avoid errors + return make(map[string]*api.GPUUsageMetrics), nil +} + +// ListWorkers fetches all workers from the hypervisor +func (c *Client) ListWorkers(ctx context.Context) ([]*api.WorkerAllocation, error) { + var result api.DataResponse[[]*api.WorkerAllocation] + if err := c.doRequest(ctx, "GET", "workers", &result); err != nil { + return nil, fmt.Errorf("list workers: %w", err) + } + return result.Data, nil +} + +// GetWorker fetches a 
specific worker by ID +func (c *Client) GetWorker(ctx context.Context, workerID string) (*api.WorkerAllocation, map[string]map[string]map[string]*api.WorkerMetrics, error) { + type WorkerDetail struct { + WorkerUID string `json:"worker_uid"` + Allocation *api.WorkerAllocation `json:"allocation"` + Metrics map[string]map[string]map[string]*api.WorkerMetrics `json:"metrics,omitempty"` + } + + var result api.DataResponse[WorkerDetail] + if err := c.doRequest(ctx, "GET", fmt.Sprintf("workers/%s", workerID), &result); err != nil { + return nil, nil, fmt.Errorf("get worker %s: %w", workerID, err) + } + return result.Data.Allocation, result.Data.Metrics, nil +} + +// GetWorkerMetrics fetches worker metrics for all workers +// This is optimized to batch requests when possible +func (c *Client) GetWorkerMetrics(ctx context.Context) (map[string]map[string]map[string]*api.WorkerMetrics, error) { + workers, err := c.ListWorkers(ctx) + if err != nil { + return nil, err + } + + metrics := make(map[string]map[string]map[string]*api.WorkerMetrics) + for _, worker := range workers { + // Get WorkerUID from WorkerInfo + if worker.WorkerInfo == nil { + continue + } + workerUID := worker.WorkerInfo.WorkerUID + _, workerMetrics, err := c.GetWorker(ctx, workerUID) + if err != nil { + // Continue on individual worker errors to get as much data as possible + continue + } + + // Merge metrics by device UUID + for deviceUUID, deviceMetrics := range workerMetrics { + if metrics[deviceUUID] == nil { + metrics[deviceUUID] = make(map[string]map[string]*api.WorkerMetrics) + } + // Copy worker metrics for this device + for wUID, wMetrics := range deviceMetrics { + metrics[deviceUUID][wUID] = wMetrics + } + } + } + + return metrics, nil +} diff --git a/internal/hypervisor/tui/device_view.go b/internal/hypervisor/tui/device_view.go new file mode 100644 index 00000000..6238d4ef --- /dev/null +++ b/internal/hypervisor/tui/device_view.go @@ -0,0 +1,148 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "context" + "fmt" + "strings" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" +) + +// deviceItem represents a device in the list +type deviceItem struct { + uuid string + model string + index int32 +} + +func (d deviceItem) FilterValue() string { + return fmt.Sprintf("%s %s %d", d.uuid, d.model, d.index) +} + +func (d deviceItem) Title() string { + return fmt.Sprintf("[%d] %s", d.index, d.model) +} + +func (d deviceItem) Description() string { + return d.uuid +} + +func newDeviceDelegate() list.DefaultDelegate { + d := list.NewDefaultDelegate() + d.Styles.SelectedTitle = SelectedStyle + d.Styles.SelectedDesc = SelectedStyle + d.Styles.NormalTitle = NormalStyle + d.Styles.NormalDesc = NormalStyle + return d +} + +// updateDeviceList updates the device list with current devices +func updateDeviceList(deviceList *list.Model, devices []*api.DeviceInfo) { + deviceItems := make([]list.Item, len(devices)) + for i, device := range devices { + deviceItems[i] = deviceItem{ + uuid: device.UUID, + model: device.Model, + index: device.Index, + } + } + deviceList.SetItems(deviceItems) +} + +// updateDeviceDetail updates the device detail viewport +func updateDeviceDetail( + ctx context.Context, + client *Client, + deviceDetail *viewport.Model, + selectedDeviceUUID string, + devices []*api.DeviceInfo, + metrics map[string]*api.GPUUsageMetrics, + deviceMetricsHistory map[string]*DeviceMetricsHistory, +) { + var device *api.DeviceInfo + for _, d := range devices { + if d.UUID == selectedDeviceUUID { + device = d + break + } + } + if device == nil { + deviceDetail.SetContent("Device not found") + return + } + + deviceMetrics, hasMetrics := metrics[device.UUID] + + var content strings.Builder + content.WriteString(TitleStyle.Render("Device Details\n\n")) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("UUID"), MetricValueStyle.Render(device.UUID))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Vendor"), MetricValueStyle.Render(device.Vendor))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Model"), MetricValueStyle.Render(device.Model))) + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Index"), device.Index)) + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("NUMA Node"), device.NUMANode)) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Total Memory"), formatBytes(device.TotalMemoryBytes))) + content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n\n", MetricLabelStyle.Render("Max TFLOPS"), device.MaxTflops)) + + if hasMetrics && deviceMetrics != nil { + content.WriteString(TitleStyle.Render("Current Metrics\n\n")) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Memory Usage"), deviceMetrics.MemoryPercentage)) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Used"), formatBytes(deviceMetrics.MemoryBytes))) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Compute Usage"), deviceMetrics.ComputePercentage)) + content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n", MetricLabelStyle.Render("Compute TFLOPS"), deviceMetrics.ComputeTflops)) + content.WriteString(fmt.Sprintf("%s: %.1f°C\n", MetricLabelStyle.Render("Temperature"), deviceMetrics.Temperature)) + content.WriteString(fmt.Sprintf("%s: %d W\n", MetricLabelStyle.Render("Power Usage"), deviceMetrics.PowerUsage)) + // 
TODO: handle extra metrics + + // Time-series charts + if history, exists := deviceMetricsHistory[selectedDeviceUUID]; exists && history != nil { + content.WriteString("\n") + content.WriteString(history.MemoryChart.Render()) + content.WriteString("\n") + content.WriteString(history.ComputeChart.Render()) + content.WriteString("\n") + content.WriteString(history.TempChart.Render()) + content.WriteString("\n") + content.WriteString(history.PowerChart.Render()) + content.WriteString("\n") + } + } + + // Get allocations for this device + allocations, err := client.GetDeviceAllocations(ctx, device.UUID) + if err == nil && len(allocations) > 0 { + content.WriteString(TitleStyle.Render("Allocations\n\n")) + for _, alloc := range allocations { + content.WriteString(fmt.Sprintf(" Worker: %s\n", alloc.WorkerInfo.WorkerUID)) + content.WriteString(fmt.Sprintf(" Pod: %s/%s\n", alloc.WorkerInfo.Namespace, alloc.WorkerInfo.WorkerName)) + content.WriteString(fmt.Sprintf(" Mode: %s\n", alloc.WorkerInfo.IsolationMode)) + if alloc.WorkerInfo.Limits.Vram.Value() > 0 { + content.WriteString(fmt.Sprintf(" Memory Limit: %s\n", formatBytes(uint64(alloc.WorkerInfo.Limits.Vram.Value())))) + } + if alloc.WorkerInfo.Limits.Tflops.Value() > 0 { + content.WriteString(fmt.Sprintf(" Compute Limit: %.2f\n", alloc.WorkerInfo.Limits.Tflops.AsApproximateFloat64())) + } + content.WriteString("\n") + } + } + + deviceDetail.SetContent(content.String()) +} diff --git a/internal/hypervisor/tui/metrics_view.go b/internal/hypervisor/tui/metrics_view.go new file mode 100644 index 00000000..1c0b97d1 --- /dev/null +++ b/internal/hypervisor/tui/metrics_view.go @@ -0,0 +1,83 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "fmt" + "strings" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/viewport" +) + +// updateMetricsView updates the metrics viewport +func updateMetricsView( + metricsView *viewport.Model, + devices []*api.DeviceInfo, + workers []*api.WorkerInfo, + metrics map[string]*api.GPUUsageMetrics, + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics, + lastUpdate time.Time, +) { + var content strings.Builder + content.WriteString(TitleStyle.Render("System Metrics\n\n")) + content.WriteString(fmt.Sprintf("Last Update: %s\n\n", lastUpdate.Format(time.RFC3339))) + + // Device metrics overview + content.WriteString(TitleStyle.Render("Device Metrics Overview\n\n")) + for _, device := range devices { + metrics, hasMetrics := metrics[device.UUID] + content.WriteString(fmt.Sprintf("%s [%s]\n", device.Model, device.UUID[:8])) + if hasMetrics && metrics != nil { + content.WriteString(fmt.Sprintf(" Memory: %.1f%% %s\n", metrics.MemoryPercentage, renderBarChart(metrics.MemoryPercentage, 20))) + content.WriteString(fmt.Sprintf(" Compute: %.1f%% %s\n", metrics.ComputePercentage, renderBarChart(metrics.ComputePercentage, 20))) + content.WriteString(fmt.Sprintf(" Temperature: %.1f°C Power: %dW\n", metrics.Temperature, metrics.PowerUsage)) + } else { + content.WriteString(" No metrics available\n") + } + content.WriteString("\n") + } + + // Worker metrics overview + content.WriteString(TitleStyle.Render("Worker Metrics Overview\n\n")) + for _, worker := range workers { + content.WriteString(fmt.Sprintf("%s/%s\n", worker.Namespace, worker.WorkerName)) + for _, deviceUUID := range worker.AllocatedDevices { + content.WriteString(fmt.Sprintf(" Device: %s\n", deviceUUID)) + if workerMetrics, exists := workerMetrics[deviceUUID]; exists { + if wm, exists := workerMetrics[worker.WorkerUID]; exists { + var totalMemory uint64 + var totalCompute float64 + for _, metrics := range wm { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + } + content.WriteString(fmt.Sprintf(" Memory: %s\n", formatBytes(totalMemory))) + content.WriteString(fmt.Sprintf(" Compute: %.1f%% %s\n", totalCompute, renderBarChart(totalCompute, 20))) + } else { + content.WriteString(" No metrics available\n") + } + } else { + content.WriteString(" No metrics available\n") + } + content.WriteString("\n") + } + } + + metricsView.SetContent(content.String()) +} diff --git a/internal/hypervisor/tui/model.go b/internal/hypervisor/tui/model.go new file mode 100644 index 00000000..66307260 --- /dev/null +++ b/internal/hypervisor/tui/model.go @@ -0,0 +1,552 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "context" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +const ( + viewDevices = iota + viewWorkers + viewMetrics + viewDeviceDetail + viewWorkerDetail +) + +// Model represents the TUI model +type Model struct { + ctx context.Context + client *Client + + currentView int + devices []*api.DeviceInfo + workers []*api.WorkerInfo + metrics map[string]*api.GPUUsageMetrics + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics + + // Metrics history for time-series charts + deviceMetricsHistory map[string]*DeviceMetricsHistory + workerMetricsHistory map[string]*WorkerMetricsHistory + + deviceList list.Model + workerList list.Model + deviceDetail viewport.Model + workerDetail viewport.Model + metricsView viewport.Model + + shmDialog *ShmDialogModel + + selectedDeviceUUID string + selectedWorkerUID string + + width int + height int + + lastUpdate time.Time +} + +// DeviceMetricsHistory tracks historical metrics for a device +type DeviceMetricsHistory struct { + MemoryChart *TimeSeriesChart + ComputeChart *TimeSeriesChart + TempChart *TimeSeriesChart + PowerChart *TimeSeriesChart +} + +// WorkerMetricsHistory tracks historical metrics for a worker +type WorkerMetricsHistory struct { + MemoryChart *TimeSeriesChart + ComputeChart *TimeSeriesChart +} + +type tickMsg time.Time +type updateDataMsg struct { + devices []*api.DeviceInfo + workers []*api.WorkerInfo + metrics map[string]*api.GPUUsageMetrics + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics +} + +// NewModel creates a new TUI model +func NewModel(ctx context.Context, client *Client) *Model { + m := &Model{ + ctx: ctx, + client: client, + currentView: viewDevices, + metrics: make(map[string]*api.GPUUsageMetrics), + workerMetrics: make(map[string]map[string]map[string]*api.WorkerMetrics), + deviceMetricsHistory: make(map[string]*DeviceMetricsHistory), + workerMetricsHistory: make(map[string]*WorkerMetricsHistory), + } + + // Initialize device list + deviceItems := []list.Item{} + m.deviceList = list.New(deviceItems, newDeviceDelegate(), 0, 0) + m.deviceList.Title = "GPU Devices" + m.deviceList.SetShowStatusBar(false) + m.deviceList.SetFilteringEnabled(true) + m.deviceList.Styles.Title = TitleStyle + m.deviceList.Styles.FilterPrompt = SubtitleStyle + m.deviceList.Styles.FilterCursor = SelectedStyle + + // Initialize worker list + workerItems := []list.Item{} + m.workerList = list.New(workerItems, newWorkerDelegate(), 0, 0) + m.workerList.Title = "Workers" + m.workerList.SetShowStatusBar(false) + m.workerList.SetFilteringEnabled(true) + m.workerList.Styles.Title = TitleStyle + m.workerList.Styles.FilterPrompt = SubtitleStyle + m.workerList.Styles.FilterCursor = SelectedStyle + + // Initialize detail viewports + m.deviceDetail = viewport.New(0, 0) + m.workerDetail = viewport.New(0, 0) + m.metricsView = viewport.New(0, 0) + + // Initialize SHM dialog + m.shmDialog = NewShmDialogModel() + + return m +} + +func (m *Model) Init() tea.Cmd { + return tea.Batch( + m.updateData(), + tick(), + ) +} + +func (m *Model) updateData() tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second) + defer cancel() + + // Get devices + devices, err := m.client.ListDevices(ctx) + if err != nil { + devices = []*api.DeviceInfo{} + } + + // Get workers + workerDetails, err := 
m.client.ListWorkers(ctx) + if err != nil { + workerDetails = []*api.WorkerAllocation{} + } + + workers := make([]*api.WorkerInfo, 0, len(workerDetails)) + for _, worker := range workerDetails { + if worker == nil { + continue + } + workers = append(workers, worker.WorkerInfo) + } + + // Get GPU metrics - for now, we'll need to add a metrics endpoint + // For now, return empty metrics + metrics := make(map[string]*api.GPUUsageMetrics) + + // Get worker metrics + workerMetrics, err := m.client.GetWorkerMetrics(ctx) + if err != nil { + workerMetrics = make(map[string]map[string]map[string]*api.WorkerMetrics) + } + + return updateDataMsg{ + devices: devices, + workers: workers, + metrics: metrics, + workerMetrics: workerMetrics, + } + } +} + +func tick() tea.Cmd { + return tea.Tick(2*time.Second, func(t time.Time) tea.Msg { + return tickMsg(t) + }) +} + +//nolint:gocyclo // Complex state machine with many message types and view transitions +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.resizeViews() + if m.shmDialog != nil { + m.shmDialog.width = msg.Width + m.shmDialog.height = msg.Height + } + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c": + return m, tea.Quit + case "1": + m.currentView = viewDevices + return m, nil + case "2": + m.currentView = viewWorkers + return m, nil + case "3": + m.currentView = viewMetrics + return m, nil + case "esc": + // Close SHM dialog if visible + if m.shmDialog != nil && m.shmDialog.IsVisible() { + m.shmDialog.Hide() + return m, nil + } + if m.currentView == viewDeviceDetail || m.currentView == viewWorkerDetail { + if m.currentView == viewDeviceDetail { + m.currentView = viewDevices + } else { + m.currentView = viewWorkers + } + return m, nil + } + case "enter": + switch m.currentView { + case viewDevices: + if selectedItem := m.deviceList.SelectedItem(); selectedItem != nil { + item := selectedItem.(deviceItem) + m.selectedDeviceUUID = item.uuid + m.currentView = viewDeviceDetail + // Initialize history if needed + if m.deviceMetricsHistory[m.selectedDeviceUUID] == nil { + m.initDeviceHistory(m.selectedDeviceUUID) + } + updateDeviceDetail(m.ctx, m.client, &m.deviceDetail, m.selectedDeviceUUID, m.devices, m.metrics, m.deviceMetricsHistory) + return m, nil + } + case viewWorkers: + if selectedItem := m.workerList.SelectedItem(); selectedItem != nil { + item := selectedItem.(*api.WorkerInfo) + m.selectedWorkerUID = item.WorkerUID + m.currentView = viewWorkerDetail + // Initialize history if needed + if m.workerMetricsHistory[m.selectedWorkerUID] == nil { + m.initWorkerHistory(m.selectedWorkerUID) + } + updateWorkerDetail(&m.workerDetail, m.selectedWorkerUID, m.workers, m.workerMetrics, m.workerMetricsHistory) + return m, nil + } + case viewWorkerDetail: + // Check if SHM dialog is visible, if so, close it + if m.shmDialog != nil && m.shmDialog.IsVisible() { + m.shmDialog.Hide() + return m, nil + } + // Otherwise, show SHM dialog if isolation mode is soft + var worker *api.WorkerInfo + for _, w := range m.workers { + if w.WorkerUID == m.selectedWorkerUID { + worker = w + } + } + if worker != nil { + m.shmDialog.Show(worker) + return m, nil + } + } + } + + case tickMsg: + return m, tea.Batch(m.updateData(), tick()) + + case updateDataMsg: + m.devices = msg.devices + m.workers = msg.workers + m.metrics = msg.metrics + m.workerMetrics = msg.workerMetrics + m.lastUpdate = time.Now() + + // 
Update metrics history for charts + m.updateMetricsHistory() + + updateDeviceList(&m.deviceList, m.devices) + + workerItems := make([]list.Item, len(m.workers)) + for i, worker := range m.workers { + workerItems[i] = worker + } + m.workerList.SetItems(workerItems) + switch m.currentView { + case viewDeviceDetail: + updateDeviceDetail(m.ctx, m.client, &m.deviceDetail, m.selectedDeviceUUID, m.devices, m.metrics, m.deviceMetricsHistory) + case viewWorkerDetail: + updateWorkerDetail(&m.workerDetail, m.selectedWorkerUID, m.workers, m.workerMetrics, m.workerMetricsHistory) + case viewMetrics: + updateMetricsView(&m.metricsView, m.devices, m.workers, m.metrics, m.workerMetrics, m.lastUpdate) + } + return m, nil + } + + // Update sub-views + // If SHM dialog is visible, it should handle input first + if m.shmDialog != nil && m.shmDialog.IsVisible() { + var cmd tea.Cmd + _, cmd = m.shmDialog.Update(msg) + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) + } + + switch m.currentView { + case viewDevices: + var cmd tea.Cmd + m.deviceList, cmd = m.deviceList.Update(msg) + cmds = append(cmds, cmd) + case viewWorkers: + var cmd tea.Cmd + m.workerList, cmd = m.workerList.Update(msg) + cmds = append(cmds, cmd) + case viewDeviceDetail: + var cmd tea.Cmd + m.deviceDetail, cmd = m.deviceDetail.Update(msg) + cmds = append(cmds, cmd) + case viewWorkerDetail: + var cmd tea.Cmd + m.workerDetail, cmd = m.workerDetail.Update(msg) + cmds = append(cmds, cmd) + case viewMetrics: + var cmd tea.Cmd + m.metricsView, cmd = m.metricsView.Update(msg) + cmds = append(cmds, cmd) + } + + return m, tea.Batch(cmds...) +} + +func (m *Model) resizeViews() { + headerHeight := 3 + footerHeight := 2 + availableHeight := m.height - headerHeight - footerHeight + + switch m.currentView { + case viewDevices: + m.deviceList.SetWidth(m.width) + m.deviceList.SetHeight(availableHeight) + case viewWorkers: + m.workerList.SetWidth(m.width) + m.workerList.SetHeight(availableHeight) + case viewDeviceDetail, viewWorkerDetail, viewMetrics: + width := m.width + height := availableHeight + m.deviceDetail.Width = width + m.deviceDetail.Height = height + m.workerDetail.Width = width + m.workerDetail.Height = height + m.metricsView.Width = width + m.metricsView.Height = height + + // Update chart dimensions when resizing + chartWidth := width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + if m.currentView == viewDeviceDetail && m.selectedDeviceUUID != "" { + if history := m.deviceMetricsHistory[m.selectedDeviceUUID]; history != nil { + history.MemoryChart.SetDimensions(chartWidth, chartHeight) + history.ComputeChart.SetDimensions(chartWidth, chartHeight) + history.TempChart.SetDimensions(chartWidth, chartHeight) + history.PowerChart.SetDimensions(chartWidth, chartHeight) + } + } else if m.currentView == viewWorkerDetail && m.selectedWorkerUID != "" { + if history := m.workerMetricsHistory[m.selectedWorkerUID]; history != nil { + history.MemoryChart.SetDimensions(chartWidth, chartHeight) + history.ComputeChart.SetDimensions(chartWidth, chartHeight) + } + } + } +} + +func (m *Model) View() string { + if m.width == 0 || m.height == 0 { + return "Initializing..." 
+ } + + var view string + switch m.currentView { + case viewDevices: + view = m.deviceList.View() + case viewWorkers: + view = m.workerList.View() + case viewDeviceDetail: + view = m.deviceDetail.View() + case viewWorkerDetail: + view = m.workerDetail.View() + case viewMetrics: + view = m.metricsView.View() + } + + header := m.renderHeader() + footer := m.renderFooter() + + mainView := lipgloss.JoinVertical(lipgloss.Left, header, view, footer) + + // Render SHM dialog on top if visible + if m.shmDialog != nil && m.shmDialog.IsVisible() { + dialogView := m.shmDialog.View() + // The dialog already handles centering, so we just return it + // It will overlay on top of the main view + return dialogView + } + + return mainView +} + +// initDeviceHistory initializes metrics history for a device +func (m *Model) initDeviceHistory(deviceUUID string) { + chartWidth := m.width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + m.deviceMetricsHistory[deviceUUID] = &DeviceMetricsHistory{ + MemoryChart: NewTimeSeriesChart(chartWidth, chartHeight, "Memory Usage"), + ComputeChart: NewTimeSeriesChart(chartWidth, chartHeight, "Compute Usage"), + TempChart: NewTimeSeriesChart(chartWidth, chartHeight, "Temperature"), + PowerChart: NewTimeSeriesChart(chartWidth, chartHeight, "Power Usage"), + } + + // Set max values + m.deviceMetricsHistory[deviceUUID].MemoryChart.SetMaxValue(100.0) + m.deviceMetricsHistory[deviceUUID].ComputeChart.SetMaxValue(100.0) + m.deviceMetricsHistory[deviceUUID].TempChart.SetMaxValue(100.0) // Will auto-scale + m.deviceMetricsHistory[deviceUUID].PowerChart.SetMaxValue(500.0) // Will auto-scale +} + +// initWorkerHistory initializes metrics history for a worker +func (m *Model) initWorkerHistory(workerUID string) { + chartWidth := m.width - 20 + if chartWidth < 40 { + chartWidth = 40 + } + chartHeight := 8 + + m.workerMetricsHistory[workerUID] = &WorkerMetricsHistory{ + MemoryChart: NewTimeSeriesChart(chartWidth, chartHeight, "Memory Usage"), + ComputeChart: NewTimeSeriesChart(chartWidth, chartHeight, "Compute Usage"), + } + + // Set max values + m.workerMetricsHistory[workerUID].MemoryChart.SetMaxValue(100.0) + m.workerMetricsHistory[workerUID].ComputeChart.SetMaxValue(100.0) +} + +// updateMetricsHistory updates the metrics history with current values +func (m *Model) updateMetricsHistory() { + // Update device metrics history + for deviceUUID, metrics := range m.metrics { + if metrics == nil { + continue + } + + history := m.deviceMetricsHistory[deviceUUID] + if history == nil { + // Only initialize if we're viewing this device + if m.currentView == viewDeviceDetail && m.selectedDeviceUUID == deviceUUID { + m.initDeviceHistory(deviceUUID) + history = m.deviceMetricsHistory[deviceUUID] + } else { + continue + } + } + + history.MemoryChart.AddDataPoint(metrics.MemoryPercentage) + history.ComputeChart.AddDataPoint(metrics.ComputePercentage) + history.TempChart.AddDataPoint(metrics.Temperature) + history.PowerChart.AddDataPoint(float64(metrics.PowerUsage)) + } + + // Update worker metrics history + for _, deviceWorkers := range m.workerMetrics { + for workerUID, workerMetrics := range deviceWorkers { + history := m.workerMetricsHistory[workerUID] + if history == nil { + // Only initialize if we're viewing this worker + if m.currentView == viewWorkerDetail && m.selectedWorkerUID == workerUID { + m.initWorkerHistory(workerUID) + history = m.workerMetricsHistory[workerUID] + } else { + continue + } + } + + // Aggregate metrics for this worker + var totalMemory uint64 
+ var totalCompute float64 + for _, metrics := range workerMetrics { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + } + + // Calculate percentage if we have allocation info + var memPercent float64 + for _, worker := range m.workers { + if worker.WorkerUID == workerUID && worker.Limits.Vram.Value() > 0 { + memPercent = float64(totalMemory) / float64(worker.Limits.Vram.Value()) * 100.0 + break + } + } + + history.MemoryChart.AddDataPoint(memPercent) + history.ComputeChart.AddDataPoint(totalCompute) + } + } +} + +func (m *Model) renderHeader() string { + title := TitleStyle.Render("Tensor Fusion Hypervisor") + tabs := []string{} + tabs = append(tabs, m.renderTab("Devices [1]", m.currentView == viewDevices)) + tabs = append(tabs, m.renderTab("Workers [2]", m.currentView == viewWorkers)) + tabs = append(tabs, m.renderTab("Metrics [3]", m.currentView == viewMetrics)) + tabLine := lipgloss.JoinHorizontal(lipgloss.Left, tabs...) + return lipgloss.JoinVertical(lipgloss.Left, title, tabLine) +} + +func (m *Model) renderTab(text string, active bool) string { + if active { + return SelectedStyle.Render(text) + } + return NormalStyle.Render(text) +} + +func (m *Model) renderFooter() string { + help := "Press 'q' to quit | 'Enter' to view details" + if m.currentView == viewWorkerDetail { + help += " (Enter again for SHM details if soft isolation)" + } + help += " | 'Esc' to go back | '1/2/3' to switch views" + return SubtitleStyle.Render(help) +} diff --git a/internal/hypervisor/tui/shm_dialog.go b/internal/hypervisor/tui/shm_dialog.go new file mode 100644 index 00000000..faa80223 --- /dev/null +++ b/internal/hypervisor/tui/shm_dialog.go @@ -0,0 +1,301 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tui + +import ( + "fmt" + "path/filepath" + "strings" + "time" + + "github.com/NexusGPU/tensor-fusion/internal/constants" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + workerstate "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker/state" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +var ( + shmBasePath = filepath.Join(constants.TFDataPath, constants.SharedMemMountSubPath) +) + +// ShmDialogModel represents the shared memory detail dialog +type ShmDialogModel struct { + viewport viewport.Model + content string + width int + height int + isVisible bool + workerInfo *api.WorkerInfo +} + +// NewShmDialogModel creates a new SHM dialog model +func NewShmDialogModel() *ShmDialogModel { + return &ShmDialogModel{ + viewport: viewport.New(0, 0), + isVisible: false, + } +} + +// Init initializes the dialog +func (m *ShmDialogModel) Init() tea.Cmd { + return nil +} + +// Update updates the dialog +func (m *ShmDialogModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + if !m.isVisible { + return m, nil + } + + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "esc", "q": + m.isVisible = false + return m, nil + } + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.resize() + return m, nil + } + + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + return m, cmd +} + +// View renders the dialog +func (m *ShmDialogModel) View() string { + if !m.isVisible { + return "" + } + + // Calculate dialog dimensions (80% of screen, centered) + dialogWidth := int(float64(m.width) * 0.8) + dialogHeight := int(float64(m.height) * 0.8) + + if dialogWidth < 40 { + dialogWidth = 40 + } + if dialogHeight < 10 { + dialogHeight = 10 + } + + // Create dialog box + box := BorderStyle. + Width(dialogWidth). + Height(dialogHeight). 
+ Render(m.viewport.View()) + + // Center the dialog + return lipgloss.Place( + m.width, + m.height, + lipgloss.Center, + lipgloss.Center, + box, + ) +} + +// Show displays the dialog with SHM details for the given worker +func (m *ShmDialogModel) Show(workerInfo *api.WorkerInfo) { + m.workerInfo = workerInfo + m.isVisible = true + m.resize() + m.updateContent() +} + +// Hide hides the dialog +func (m *ShmDialogModel) Hide() { + m.isVisible = false +} + +// IsVisible returns whether the dialog is visible +func (m *ShmDialogModel) IsVisible() bool { + return m.isVisible +} + +// resize resizes the dialog viewport +func (m *ShmDialogModel) resize() { + if !m.isVisible { + return + } + + dialogWidth := int(float64(m.width) * 0.8) + dialogHeight := int(float64(m.height) * 0.8) + + if dialogWidth < 40 { + dialogWidth = 40 + } + if dialogHeight < 10 { + dialogHeight = 10 + } + + // Account for border + m.viewport.Width = dialogWidth - 2 + m.viewport.Height = dialogHeight - 2 +} + +// updateContent updates the dialog content with SHM details +func (m *ShmDialogModel) updateContent() { + if m.workerInfo == nil { + m.content = "No worker information available" + m.viewport.SetContent(m.content) + return + } + + var content strings.Builder + + // Title + content.WriteString(TitleStyle.Render("Shared Memory Details\n\n")) + + // Construct pod identifier and path + podIdentifier := workerstate.NewPodIdentifier(m.workerInfo.Namespace, m.workerInfo.WorkerName) + podPath := podIdentifier.ToPath(shmBasePath) + shmPath := filepath.Join(podPath, workerstate.ShmPathSuffix) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Pod"), MetricValueStyle.Render(podIdentifier.String()))) + content.WriteString(fmt.Sprintf("%s: %s\n\n", MetricLabelStyle.Render("SHM Path"), MetricValueStyle.Render(shmPath))) + + // Try to open the shared memory handle + handle, err := workerstate.OpenSharedMemoryHandle(podPath) + if err != nil { + content.WriteString(fmt.Sprintf("%s: %s\n\n", MetricLabelStyle.Render("Error"), MetricValueStyle.Render(err.Error()))) + m.content = content.String() + m.viewport.SetContent(m.content) + return + } + defer func() { + _ = handle.Close() + }() + + // Get the state + state := handle.GetState() + if state == nil { + content.WriteString(fmt.Sprintf("%s: %s\n\n", MetricLabelStyle.Render("Error"), MetricValueStyle.Render("Shared memory state is null"))) + m.content = content.String() + m.viewport.SetContent(m.content) + return + } + + // Basic information + deviceCount := state.DeviceCount() + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Device Count"), deviceCount)) + + lastHeartbeat := state.GetLastHeartbeat() + heartbeatTime := time.Unix(int64(lastHeartbeat), 0) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Last Heartbeat"), heartbeatTime.Format(time.RFC3339))) + + // Health check (2 seconds timeout) + isHealthy := state.IsHealthy(2 * time.Second) + healthStatus := "Healthy" + if !isHealthy { + healthStatus = "Unhealthy" + } + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Health Status"), MetricValueStyle.Render(healthStatus))) + + // Version information + version := state.Version() + content.WriteString(fmt.Sprintf("%s: v%d\n\n", MetricLabelStyle.Render("State Version"), version)) + + // Device details based on version + if version == 1 && state.V1 != nil { + // V1 format + for i := 0; i < deviceCount; i++ { + if !state.V1.HasDevice(i) { + continue + } + + device := &state.V1.Devices[i] + if 
!device.IsActive() { + continue + } + + uuid := device.GetUUID() + availableCores := device.DeviceInfo.AvailableCudaCores + totalCores := device.DeviceInfo.TotalCudaCores + memLimit := device.DeviceInfo.MemLimit + podMemoryUsed := device.DeviceInfo.PodMemoryUsed + upLimit := device.DeviceInfo.UpLimit + + content.WriteString(fmt.Sprintf("Device %d:\n", i)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("UUID"), MetricValueStyle.Render(uuid))) + content.WriteString(fmt.Sprintf(" %s: %d / %d\n", MetricLabelStyle.Render("Cores"), availableCores, totalCores)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Limit"), formatBytes(memLimit))) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Used"), formatBytes(podMemoryUsed))) + content.WriteString(fmt.Sprintf(" %s: %d%%\n\n", MetricLabelStyle.Render("Up Limit"), upLimit)) + } + } else if version == 2 && state.V2 != nil { + // V2 format with ERL + for i := 0; i < deviceCount; i++ { + if !state.V2.HasDevice(i) { + continue + } + + device := &state.V2.Devices[i] + if !device.IsActive() { + continue + } + + uuid := device.GetUUID() + totalCores := device.DeviceInfo.TotalCudaCores + memLimit := device.DeviceInfo.MemLimit + podMemoryUsed := device.DeviceInfo.PodMemoryUsed + upLimit := device.DeviceInfo.UpLimit + + // ERL information + erlCurrentTokens := device.DeviceInfo.GetERLCurrentTokens() + erlTokenCapacity := device.DeviceInfo.GetERLTokenCapacity() + erlTokenRefillRate := device.DeviceInfo.GetERLTokenRefillRate() + erlLastTokenUpdate := device.DeviceInfo.GetERLLastTokenUpdate() + + content.WriteString(fmt.Sprintf("Device %d:\n", i)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("UUID"), MetricValueStyle.Render(uuid))) + content.WriteString(fmt.Sprintf(" %s: %d\n", MetricLabelStyle.Render("Total Cores"), totalCores)) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Limit"), formatBytes(memLimit))) + content.WriteString(fmt.Sprintf(" %s: %s\n", MetricLabelStyle.Render("Mem Used"), formatBytes(podMemoryUsed))) + content.WriteString(fmt.Sprintf(" %s: %d%%\n", MetricLabelStyle.Render("Up Limit"), upLimit)) + content.WriteString(fmt.Sprintf(" %s: %.1f / %.1f (rate: %.1f/s, updated: %.0fµs)\n\n", + MetricLabelStyle.Render("ERL Tokens"), + erlCurrentTokens, + erlTokenCapacity, + erlTokenRefillRate, + erlLastTokenUpdate)) + } + } else { + content.WriteString(fmt.Sprintf("Unknown shared memory version: %d\n\n", version)) + } + + // Additional state information + pids := state.GetAllPIDs() + content.WriteString(fmt.Sprintf("%s: %d\n", MetricLabelStyle.Render("Active PIDs Count"), len(pids))) + if len(pids) > 0 { + pidStrs := make([]string, len(pids)) + for i, pid := range pids { + pidStrs[i] = fmt.Sprintf("%d", pid) + } + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Active PIDs"), strings.Join(pidStrs, ", "))) + } + + m.content = content.String() + m.viewport.SetContent(m.content) + m.viewport.GotoTop() +} diff --git a/internal/hypervisor/tui/styles.go b/internal/hypervisor/tui/styles.go new file mode 100644 index 00000000..6fb4c01d --- /dev/null +++ b/internal/hypervisor/tui/styles.go @@ -0,0 +1,33 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "github.com/charmbracelet/lipgloss" +) + +var ( + TitleStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("63")) + SubtitleStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("241")) + BorderStyle = lipgloss.NewStyle().Border(lipgloss.RoundedBorder()).BorderForeground(lipgloss.Color("62")) + SelectedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("212")).Bold(true) + NormalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("250")) + MetricLabelStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("243")).Width(20) + MetricValueStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Bold(true) + ChartBarStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("46")) + ChartEmptyStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("238")) +) diff --git a/internal/hypervisor/tui/utils.go b/internal/hypervisor/tui/utils.go new file mode 100644 index 00000000..dc8722e0 --- /dev/null +++ b/internal/hypervisor/tui/utils.go @@ -0,0 +1,57 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "fmt" + "strings" +) + +// formatBytes formats bytes into human-readable format +func formatBytes(bytes uint64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// renderBarChart renders a bar chart for a percentage value +// This is a simple wrapper that calls the chart implementation +func renderBarChart(percentage float64, width int) string { + if percentage > 100 { + percentage = 100 + } + if percentage < 0 { + percentage = 0 + } + + filled := int(percentage / 100.0 * float64(width)) + empty := width - filled + + var bar strings.Builder + bar.WriteString(ChartBarStyle.Render(strings.Repeat("█", filled))) + bar.WriteString(ChartEmptyStyle.Render(strings.Repeat("░", empty))) + bar.WriteString(fmt.Sprintf(" %.1f%%", percentage)) + + return bar.String() +} diff --git a/internal/hypervisor/tui/worker_view.go b/internal/hypervisor/tui/worker_view.go new file mode 100644 index 00000000..ce8d5275 --- /dev/null +++ b/internal/hypervisor/tui/worker_view.go @@ -0,0 +1,105 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tui + +import ( + "fmt" + "strings" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/viewport" +) + +func newWorkerDelegate() list.DefaultDelegate { + d := list.NewDefaultDelegate() + d.Styles.SelectedTitle = SelectedStyle + d.Styles.SelectedDesc = SelectedStyle + d.Styles.NormalTitle = NormalStyle + d.Styles.NormalDesc = NormalStyle + return d +} + +// updateWorkerDetail updates the worker detail viewport +func updateWorkerDetail( + workerDetail *viewport.Model, + selectedWorkerUID string, + workers []*api.WorkerInfo, + workerMetrics map[string]map[string]map[string]*api.WorkerMetrics, + workerMetricsHistory map[string]*WorkerMetricsHistory, +) { + var worker *api.WorkerInfo + for _, w := range workers { + if w.WorkerUID == selectedWorkerUID { + worker = w + break + } + } + if worker == nil { + workerDetail.SetContent("Worker not found") + return + } + + var content strings.Builder + content.WriteString(TitleStyle.Render("Worker Details\n\n")) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Worker UID"), MetricValueStyle.Render(worker.WorkerUID))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Pod Name"), MetricValueStyle.Render(worker.WorkerName))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Namespace"), MetricValueStyle.Render(worker.Namespace))) + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Device UUIDs"), MetricValueStyle.Render(strings.Join(worker.AllocatedDevices, ", ")))) + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Isolation Mode"), MetricValueStyle.Render(string(worker.IsolationMode)))) + if worker.Limits.Vram.Value() > 0 { + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Limit"), formatBytes(uint64(worker.Limits.Vram.Value())))) + } + if worker.Limits.Tflops.Value() > 0 { + content.WriteString(fmt.Sprintf("%s: %.2f\n", MetricLabelStyle.Render("Compute Limit"), worker.Limits.Tflops.AsApproximateFloat64())) + } + + // Get worker metrics + for _, deviceUUID := range worker.AllocatedDevices { + if deviceWorkerMetrics, exists := workerMetrics[deviceUUID]; exists { + if wm, exists := deviceWorkerMetrics[worker.WorkerUID]; exists { + content.WriteString(TitleStyle.Render("Current Metrics\n\n")) + var totalMemory uint64 + var totalCompute float64 + var totalTflops float64 + + for _, metrics := range wm { + totalMemory += metrics.MemoryBytes + totalCompute += metrics.ComputePercentage + totalTflops += metrics.ComputeTflops + } + + content.WriteString(fmt.Sprintf("%s: %s\n", MetricLabelStyle.Render("Memory Used"), formatBytes(totalMemory))) + content.WriteString(fmt.Sprintf("%s: %.1f%%\n", MetricLabelStyle.Render("Compute Usage"), totalCompute)) + content.WriteString(fmt.Sprintf("%s: %.2f TFLOPS\n\n", MetricLabelStyle.Render("Compute TFLOPS"), totalTflops)) + + // Time-series charts + if history, exists := workerMetricsHistory[deviceUUID]; exists && history != nil { + content.WriteString("\n") + 
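+ // Append the rolling time-series charts for this device: memory history first, then compute.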
content.WriteString(history.MemoryChart.Render()) + content.WriteString("\n") + content.WriteString(history.ComputeChart.Render()) + content.WriteString("\n") + } + } + } + } + + workerDetail.SetContent(content.String()) +} diff --git a/internal/hypervisor/worker/computing/erl.go b/internal/hypervisor/worker/computing/erl.go new file mode 100644 index 00000000..e882c738 --- /dev/null +++ b/internal/hypervisor/worker/computing/erl.go @@ -0,0 +1,352 @@ +package computing + +import ( + "errors" + "fmt" + "math" +) + +var ( + ErrInvalidConfig = errors.New("invalid configuration") +) + +// DeviceBackend defines the interface for device token/quota operations +type DeviceBackend interface { + ReadTokenState(device int) (*TokenState, error) + WriteTokenState(device int, state *TokenState) error + ReadQuota(device int) (*DeviceQuota, error) + WriteRefillRate(device int, refillRate float64) error + WriteCapacity(device int, capacity float64) error + FetchSubTokens(device int, cost float64) (float64, error) + FetchAddTokens(device int, amount float64) (float64, error) +} + +// TokenState represents the current token bucket state +type TokenState struct { + Tokens float64 + LastUpdate float64 +} + +// DeviceQuota represents device quota configuration +type DeviceQuota struct { + Capacity float64 + RefillRate float64 +} + +// DeviceControllerConfig holds configuration for the PID-based device controller +type DeviceControllerConfig struct { + // Target GPU utilization (0.0 to 1.0, e.g., 0.5 = 50%) + TargetUtilization float64 + + // Minimum refill rate (tokens/second) - prevents rate from dropping to zero + RateMin float64 + + // Maximum refill rate (tokens/second) + RateMax float64 + + // PID proportional gain - how aggressively to respond to error + Kp float64 + + // PID integral gain - how quickly to eliminate steady-state error + Ki float64 + + // PID derivative gain - how much to dampen oscillations + Kd float64 + + // Low-pass filter coefficient for smoothing utilization (0.0 to 1.0) + // Higher values = less filtering (more responsive, more noise) + FilterAlpha float64 + + // Burst window in seconds - capacity = refill_rate × burst_window + BurstWindow float64 + + // Minimum capacity (tokens) + CapacityMin float64 + + // Maximum capacity (tokens) - prevents unbounded growth + CapacityMax float64 + + // Minimum time between updates (seconds) + MinDeltaTime float64 + + // Integral decay factor (0.0 to 1.0) for exponential decay of integral term + // Higher values (closer to 1.0) = slower decay, retains more history + // Lower values = faster decay, responds more quickly to changes + // Default 0.95 means ~20 update cycles for integral to decay to ~35.8% of original value + IntegralDecayFactor float64 +} + +// DefaultDeviceControllerConfig returns a default configuration +func DefaultDeviceControllerConfig() DeviceControllerConfig { + return DeviceControllerConfig{ + TargetUtilization: 0.5, + RateMin: 10.0, + RateMax: 100_000.0, + Kp: 0.5, + Ki: 0.1, + Kd: 0.05, + FilterAlpha: 0.3, + BurstWindow: 2.0, + CapacityMin: 100.0, + CapacityMax: 200_000.0, + MinDeltaTime: 0.05, + IntegralDecayFactor: 0.95, + } +} + +// DeviceControllerState is a snapshot of controller state after an update +type DeviceControllerState struct { + TargetUtilization float64 + SmoothedUtilization float64 + CurrentRate float64 + CurrentCapacity float64 + TokenDrainRate float64 +} + +// DeviceController is a PID-based controller that dynamically adjusts token refill rates +type DeviceController struct { + backend DeviceBackend + 
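+ // device is the index of the GPU this controller manages; it is passed to every backend call.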
device int + cfg DeviceControllerConfig + + // PID state + integral float64 + lastError float64 + + // Filtering state + smoothedUtil *float64 + + // Rate tracking + currentRate float64 + + // Drain rate estimation + lastTokenLevel float64 + lastTimestamp *float64 +} + +// NewDeviceController creates a new device controller +func NewDeviceController(backend DeviceBackend, device int, cfg DeviceControllerConfig) (*DeviceController, error) { + // Validate configuration + if cfg.TargetUtilization < 0.0 || cfg.TargetUtilization > 1.0 { + return nil, fmt.Errorf("%w: target_utilization must be in [0, 1]", ErrInvalidConfig) + } + if cfg.RateMin <= 0.0 || cfg.RateMax <= cfg.RateMin { + return nil, fmt.Errorf("%w: rate_max must be greater than rate_min > 0", ErrInvalidConfig) + } + if cfg.FilterAlpha < 0.0 || cfg.FilterAlpha > 1.0 { + return nil, fmt.Errorf("%w: filter_alpha must be in [0, 1]", ErrInvalidConfig) + } + if cfg.IntegralDecayFactor < 0.0 || cfg.IntegralDecayFactor > 1.0 { + return nil, fmt.Errorf("%w: integral_decay_factor must be in [0, 1]", ErrInvalidConfig) + } + + // Initialize with a conservative starting rate + startRate := math.Min(100.0, cfg.RateMax) + startRate = math.Max(startRate, cfg.RateMin) + initialCapacity := math.Max(cfg.CapacityMin, math.Min(cfg.CapacityMax, startRate*cfg.BurstWindow)) + + // Initialize backend + if err := backend.WriteCapacity(device, initialCapacity); err != nil { + return nil, err + } + if err := backend.WriteRefillRate(device, startRate); err != nil { + return nil, err + } + + tokenState, err := backend.ReadTokenState(device) + if err != nil { + return nil, err + } + tokenState.Tokens = initialCapacity + if err := backend.WriteTokenState(device, tokenState); err != nil { + return nil, err + } + + return &DeviceController{ + backend: backend, + device: device, + cfg: cfg, + integral: 0.0, + lastError: 0.0, + smoothedUtil: nil, + currentRate: startRate, + lastTokenLevel: initialCapacity, + lastTimestamp: nil, + }, nil +} + +// State returns the current controller state +func (dc *DeviceController) State() DeviceControllerState { + capacity := math.Max(dc.cfg.CapacityMin, math.Min(dc.cfg.CapacityMax, dc.currentRate*dc.cfg.BurstWindow)) + smoothedUtil := 0.0 + if dc.smoothedUtil != nil { + smoothedUtil = *dc.smoothedUtil + } + return DeviceControllerState{ + TargetUtilization: dc.cfg.TargetUtilization, + SmoothedUtilization: smoothedUtil, + CurrentRate: dc.currentRate, + CurrentCapacity: capacity, + TokenDrainRate: 0.0, // Will be updated during next cycle + } +} + +// Update updates controller with new utilization measurement and explicit delta time +func (dc *DeviceController) Update(utilization float64, deltaTime float64) (*DeviceControllerState, error) { + if deltaTime < dc.cfg.MinDeltaTime { + state := dc.State() + return &state, nil + } + return dc.updateInternal(utilization, deltaTime) +} + +// UpdateWithTimestamp updates controller with timestamp (calculates delta automatically) +func (dc *DeviceController) UpdateWithTimestamp(utilization float64, timestampMicros uint64) (*DeviceControllerState, error) { + seconds := float64(timestampMicros) / 1_000_000.0 + var delta float64 + if dc.lastTimestamp != nil { + rawDelta := seconds - *dc.lastTimestamp + if rawDelta < dc.cfg.MinDeltaTime { + state := dc.State() + return &state, nil + } + delta = rawDelta + } else { + delta = dc.cfg.MinDeltaTime + } + dc.lastTimestamp = &seconds + return dc.updateInternal(utilization, delta) +} + +// updateInternal performs the core update logic +func (dc 
*DeviceController) updateInternal(measuredUtil float64, deltaTime float64) (*DeviceControllerState, error) { + // Clamp measured utilization + measured := math.Max(0.0, math.Min(1.0, measuredUtil)) + + // Step 1: Low-pass filter to smooth NVML noise + smoothed := dc.smoothUtilization(measured) + + // Step 2: Estimate token drain rate + drainRate, err := dc.estimateDrainRate(deltaTime) + if err != nil { + return nil, err + } + + // Step 3: Calculate base rate from drain rate and target + baseRate := dc.calculateBaseRate(smoothed, drainRate) + + // Step 4: Compute PID correction + error := dc.cfg.TargetUtilization - smoothed + correction := dc.computePIDCorrection(error, deltaTime) + + // Step 5: Apply correction to base rate + newRate := math.Max(dc.cfg.RateMin, math.Min(dc.cfg.RateMax, baseRate*(1.0+correction))) + dc.currentRate = newRate + + // Step 6: Calculate capacity (bounded) + newCapacity := math.Max(dc.cfg.CapacityMin, math.Min(dc.cfg.CapacityMax, newRate*dc.cfg.BurstWindow)) + + // Step 7: Refill tokens + refillAmount := newRate * deltaTime + if _, err := dc.backend.FetchAddTokens(dc.device, refillAmount); err != nil { + return nil, err + } + + // Step 8: Update backend (capacity must be updated before clamping) + if err := dc.backend.WriteRefillRate(dc.device, newRate); err != nil { + return nil, err + } + if err := dc.backend.WriteCapacity(dc.device, newCapacity); err != nil { + return nil, err + } + + // Step 9: Clamp tokens to capacity (after capacity update, tokens may exceed new capacity) + // Optimization: only read and write if clamping is needed + state, err := dc.backend.ReadTokenState(dc.device) + if err != nil { + return nil, err + } + if state.Tokens > newCapacity { + state.Tokens = newCapacity + if err := dc.backend.WriteTokenState(dc.device, state); err != nil { + return nil, err + } + } + + return &DeviceControllerState{ + TargetUtilization: dc.cfg.TargetUtilization, + SmoothedUtilization: smoothed, + CurrentRate: newRate, + CurrentCapacity: newCapacity, + TokenDrainRate: drainRate, + }, nil +} + +// smoothUtilization applies exponential moving average to smooth utilization measurements +func (dc *DeviceController) smoothUtilization(measured float64) float64 { + alpha := dc.cfg.FilterAlpha + var smoothed float64 + if dc.smoothedUtil != nil { + smoothed = alpha*measured + (1.0-alpha)**dc.smoothedUtil + } else { + smoothed = measured + } + dc.smoothedUtil = &smoothed + return smoothed +} + +// estimateDrainRate estimates token drain rate from bucket level changes +func (dc *DeviceController) estimateDrainRate(deltaTime float64) (float64, error) { + currentState, err := dc.backend.ReadTokenState(dc.device) + if err != nil { + return 0, err + } + currentTokens := currentState.Tokens + + // Expected tokens = last level + refill during delta_time + expectedTokens := dc.lastTokenLevel + dc.currentRate*deltaTime + + // Actual drain = expected - actual + drainRate := math.Max(0.0, (expectedTokens-currentTokens)/deltaTime) + + dc.lastTokenLevel = currentTokens + return drainRate, nil +} + +// calculateBaseRate calculates base refill rate from current utilization and drain rate +// The idea: if we're at `actual_util` with `drain_rate`, then to reach +// `target_util` we need: `base_rate = drain_rate × (target / actual)` +func (dc *DeviceController) calculateBaseRate(smoothedUtil float64, drainRate float64) float64 { + if smoothedUtil > 0.01 { + // Theoretical base rate to reach target + theoretical := drainRate * (dc.cfg.TargetUtilization / smoothedUtil) + return 
math.Max(dc.cfg.RateMin, math.Min(dc.cfg.RateMax, theoretical)) + } + // Very low utilization - maintain current rate or use minimum + return math.Max(dc.currentRate, dc.cfg.RateMin) +} + +// computePIDCorrection computes PID correction term +// Returns a correction factor in the range [-0.5, 0.5] to apply to base_rate +func (dc *DeviceController) computePIDCorrection(error float64, deltaTime float64) float64 { + // Proportional term + p := dc.cfg.Kp * error + + // Integral term with exponential decay and anti-windup + // Apply decay factor to forget old errors gradually + dc.integral *= dc.cfg.IntegralDecayFactor + // Add new error contribution + dc.integral += error * deltaTime + // Clamp to prevent windup + dc.integral = math.Max(-1.0, math.Min(1.0, dc.integral)) + i := dc.cfg.Ki * dc.integral + + // Derivative term + derivative := (error - dc.lastError) / deltaTime + d := dc.cfg.Kd * derivative + + dc.lastError = error + + // Total correction, clamped to avoid over-reaction + return math.Max(-0.5, math.Min(0.5, p+i+d)) +} diff --git a/internal/hypervisor/worker/computing/erl_test.go b/internal/hypervisor/worker/computing/erl_test.go new file mode 100644 index 00000000..bb7e5978 --- /dev/null +++ b/internal/hypervisor/worker/computing/erl_test.go @@ -0,0 +1,335 @@ +package computing + +import ( + "math" + "sync" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestERL(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "ERL Controller Suite") +} + +var _ = Describe("DeviceController", func() { + var ( + backend *MockBackend + device int + cfg DeviceControllerConfig + ) + + BeforeEach(func() { + device = 0 + cfg = DefaultDeviceControllerConfig() + cfg.RateMax = 50000.0 + cfg.CapacityMax = 100_000.0 + }) + + Describe("Initialization", func() { + It("should initialize correctly with valid config", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.TargetUtilization = 0.7 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + Expect(ctrl).NotTo(BeNil()) + Expect(ctrl.cfg.TargetUtilization).To(Equal(0.7)) + Expect(ctrl.currentRate).To(BeNumerically(">=", ctrl.cfg.RateMin)) + Expect(ctrl.currentRate).To(BeNumerically("<=", ctrl.cfg.RateMax)) + }) + + It("should reject invalid target_utilization", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.TargetUtilization = 1.5 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("target_utilization must be in [0, 1]"))) + }) + + It("should reject invalid rate_min/rate_max", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.RateMin = 100.0 + cfg.RateMax = 50.0 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("rate_max must be greater than rate_min"))) + }) + + It("should reject invalid filter_alpha", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.FilterAlpha = 1.5 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("filter_alpha must be in [0, 1]"))) + }) + + It("should reject invalid integral_decay_factor", func() { + backend = NewMockBackend(0.0, 0.0, 0.0) + cfg.IntegralDecayFactor = 1.5 + + _, err := NewDeviceController(backend, device, cfg) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("integral_decay_factor must be in [0, 1]"))) + }) + }) + + Describe("Rate 
Adjustment", func() { + It("should increase rate when utilization is below target", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.7 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + rateBefore := ctrl.currentRate + + // Utilization 20% when target is 70% -> should increase rate + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateAfter := ctrl.currentRate + Expect(rateAfter).To(BeNumerically(">", rateBefore), "Rate should increase when utilization is below target") + }) + + It("should decrease rate when utilization is above target", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // First establish a higher rate + _, err = ctrl.Update(0.3, 0.1) + Expect(err).NotTo(HaveOccurred()) + _, err = ctrl.Update(0.3, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateBefore := ctrl.currentRate + + // Now push utilization above target + _, err = ctrl.Update(0.95, 0.1) + Expect(err).NotTo(HaveOccurred()) + + rateAfter := ctrl.currentRate + Expect(rateAfter).To(BeNumerically("<", rateBefore), "Rate should decrease when utilization is above target") + }) + + It("should respect rate limits", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + cfg.RateMin = 50.0 + cfg.RateMax = 500.0 + cfg.CapacityMax = 1000.0 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Try to push rate very low + for i := 0; i < 10; i++ { + _, err = ctrl.Update(0.99, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + Expect(ctrl.currentRate).To(BeNumerically(">=", 50.0), "Rate should not go below rate_min") + + // Try to push rate very high + for i := 0; i < 10; i++ { + _, err = ctrl.Update(0.01, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + Expect(ctrl.currentRate).To(BeNumerically("<=", 500.0), "Rate should not exceed rate_max") + }) + }) + + Describe("Utilization Smoothing", func() { + It("should smooth utilization measurements", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + cfg.FilterAlpha = 0.3 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Feed alternating utilization values + _, err = ctrl.Update(0.8, 0.1) + Expect(err).NotTo(HaveOccurred()) + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + + state := ctrl.State() + // Smoothed value should be between the extremes + Expect(state.SmoothedUtilization).To(BeNumerically(">", 0.2)) + Expect(state.SmoothedUtilization).To(BeNumerically("<", 0.8)) + }) + }) + + Describe("Edge Cases", func() { + It("should handle zero utilization", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Feed zero utilization repeatedly + for i := 0; i < 5; i++ { + _, err = ctrl.Update(0.0, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + + // Rate should still be above minimum + Expect(ctrl.currentRate).To(BeNumerically(">=", ctrl.cfg.RateMin), "Rate should never drop below rate_min") + }) + + It("should handle very small delta_time", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + 
rateBefore := ctrl.currentRate + + // Update with delta_time smaller than min_delta_time + _, err = ctrl.Update(0.3, 0.001) + Expect(err).NotTo(HaveOccurred()) + + // Rate should not change + Expect(ctrl.currentRate).To(Equal(rateBefore)) + }) + }) + + Describe("Capacity Scaling", func() { + It("should scale capacity with rate", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + state1 := ctrl.State() + + // Continue to increase rate + for i := 0; i < 5; i++ { + _, err = ctrl.Update(0.2, 0.1) + Expect(err).NotTo(HaveOccurred()) + } + + state2 := ctrl.State() + if state2.CurrentRate > state1.CurrentRate { + Expect(state2.CurrentCapacity).To(BeNumerically(">=", state1.CurrentCapacity), "Capacity should scale with rate") + } + }) + }) + + Describe("Timestamp-based Updates", func() { + It("should handle timestamp-based updates", func() { + backend = NewMockBackend(1000.0, 100.0, 500.0) + cfg.TargetUtilization = 0.5 + + ctrl, err := NewDeviceController(backend, device, cfg) + Expect(err).NotTo(HaveOccurred()) + + // Update with timestamps (in microseconds) + t1 := uint64(1_000_000) // 1 second + t2 := uint64(1_200_000) // 1.2 seconds (0.2s delta) + + _, err = ctrl.UpdateWithTimestamp(0.3, t1) + Expect(err).NotTo(HaveOccurred()) + + _, err = ctrl.UpdateWithTimestamp(0.4, t2) + Expect(err).NotTo(HaveOccurred()) + }) + }) +}) + +// MockBackend is a mock implementation of DeviceBackend for testing +type MockBackend struct { + mu sync.RWMutex + quotaCapacity float64 + quotaRefillRate float64 + tokens float64 + lastUpdate float64 +} + +func NewMockBackend(capacity, refillRate, tokens float64) *MockBackend { + return &MockBackend{ + quotaCapacity: capacity, + quotaRefillRate: refillRate, + tokens: tokens, + lastUpdate: 0, + } +} + +func (m *MockBackend) ReadTokenState(device int) (*TokenState, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return &TokenState{ + Tokens: m.tokens, + LastUpdate: m.lastUpdate, + }, nil +} + +func (m *MockBackend) WriteTokenState(device int, state *TokenState) error { + m.mu.Lock() + defer m.mu.Unlock() + m.tokens = state.Tokens + m.lastUpdate = state.LastUpdate + return nil +} + +func (m *MockBackend) ReadQuota(device int) (*DeviceQuota, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return &DeviceQuota{ + Capacity: m.quotaCapacity, + RefillRate: m.quotaRefillRate, + }, nil +} + +func (m *MockBackend) WriteRefillRate(device int, refillRate float64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.quotaRefillRate = refillRate + return nil +} + +func (m *MockBackend) WriteCapacity(device int, capacity float64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.quotaCapacity = capacity + return nil +} + +func (m *MockBackend) FetchSubTokens(device int, cost float64) (float64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + current := m.tokens + if current < cost { + return current, nil + } + + capacity := m.quotaCapacity + newTokens := math.Max(0.0, math.Min(capacity, current-cost)) + m.tokens = newTokens + return current, nil +} + +func (m *MockBackend) FetchAddTokens(device int, amount float64) (float64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + current := m.tokens + capacity := m.quotaCapacity + newTokens := math.Max(0.0, math.Min(capacity, current+amount)) + m.tokens = newTokens + return current, nil +} diff --git a/internal/hypervisor/worker/computing/qos.go 
b/internal/hypervisor/worker/computing/qos.go new file mode 100644 index 00000000..0bfc86b9 --- /dev/null +++ b/internal/hypervisor/worker/computing/qos.go @@ -0,0 +1,3 @@ +package computing + +// diff --git a/internal/hypervisor/worker/computing/quota_controller.go b/internal/hypervisor/worker/computing/quota_controller.go new file mode 100644 index 00000000..91bb9330 --- /dev/null +++ b/internal/hypervisor/worker/computing/quota_controller.go @@ -0,0 +1,72 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package computing + +import ( + "sync" + + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "k8s.io/klog/v2" +) + +type Controller struct { + deviceController framework.DeviceController + mu sync.RWMutex + running bool + stopCh chan struct{} +} + +func NewQuotaController(deviceController framework.DeviceController) framework.QuotaController { + return &Controller{ + deviceController: deviceController, + stopCh: make(chan struct{}), + } +} + +func (c *Controller) SetQuota(workerUID string) error { + // TODO: Implement quota setting + return nil +} + +func (c *Controller) StartSoftQuotaLimiter() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.running { + return nil + } + c.running = true + // TODO: Start soft quota limiter thread + klog.Info("Soft quota limiter started") + return nil +} + +func (c *Controller) StopSoftQuotaLimiter() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.running { + return nil + } + close(c.stopCh) + c.running = false + klog.Info("Soft quota limiter stopped") + return nil +} + +func (c *Controller) GetWorkerQuotaStatus(workerUID string) error { + // TODO: Implement quota status retrieval + return nil +} diff --git a/internal/hypervisor/worker/controller.go b/internal/hypervisor/worker/controller.go new file mode 100644 index 00000000..e8068c0a --- /dev/null +++ b/internal/hypervisor/worker/controller.go @@ -0,0 +1,185 @@ +package worker + +import ( + "maps" + "sync" + + tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/api" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework" + "github.com/NexusGPU/tensor-fusion/internal/hypervisor/worker/computing" + "github.com/samber/lo" + "k8s.io/klog/v2" +) + +type WorkerController struct { + mode api.IsolationMode + backend framework.Backend + + deviceController framework.DeviceController + quotaController framework.QuotaController + + mu sync.RWMutex + workers map[string]*api.WorkerInfo + workerAllocations map[string]*api.WorkerAllocation +} + +func NewWorkerController( + deviceController framework.DeviceController, mode api.IsolationMode, backend framework.Backend) framework.WorkerController { + quotaController := computing.NewQuotaController(deviceController) + return &WorkerController{ + deviceController: deviceController, + mode: mode, + backend: backend, + quotaController: quotaController, + + workers: make(map[string]*api.WorkerInfo, 32), + workerAllocations: make(map[string]*api.WorkerAllocation, 32), + } +} + +func 
(w *WorkerController) Start() error { + // Register worker update handler + handler := framework.WorkerChangeHandler{ + OnAdd: func(worker *api.WorkerInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.workers[worker.WorkerUID] = worker + }, + OnRemove: func(worker *api.WorkerInfo) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.workers, worker.WorkerUID) + }, + OnUpdate: func(oldWorker, newWorker *api.WorkerInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.workers[newWorker.WorkerUID] = newWorker + }, + } + + err := w.backend.RegisterWorkerUpdateHandler(handler) + if err != nil { + return err + } + + // Start soft quota limiter + if w.mode == tfv1.IsolationModeSoft { + if err := w.quotaController.StartSoftQuotaLimiter(); err != nil { + klog.Fatalf("Failed to start soft quota limiter: %v", err) + } + klog.Info("Soft quota limiter started") + } + + // Start backend after all handlers are registered + err = w.backend.Start() + if err != nil { + return err + } + klog.Info("Worker backend started") + return nil +} + +func (w *WorkerController) Stop() error { + _ = w.backend.Stop() + _ = w.quotaController.StopSoftQuotaLimiter() + return nil +} + +// AllocateWorker implements framework.WorkerController +func (w *WorkerController) AllocateWorkerDevices(request *api.WorkerInfo) (*api.WorkerAllocation, error) { + // Validate devices exist + w.mu.Lock() + defer w.mu.Unlock() + + deviceInfos := make([]*api.DeviceInfo, 0, len(request.AllocatedDevices)) + + // partitioned mode, call split device + isPartitioned := request.IsolationMode == tfv1.IsolationModePartitioned && request.PartitionTemplateID != "" + + for _, deviceUUID := range request.AllocatedDevices { + if device, exists := w.deviceController.GetDevice(deviceUUID); exists { + if isPartitioned { + deviceInfo, err := w.deviceController.SplitDevice(deviceUUID, request.PartitionTemplateID) + if err != nil { + return nil, err + } + deviceInfos = append(deviceInfos, deviceInfo) + } else { + deviceInfos = append(deviceInfos, device) + } + } + } + + mounts, err := w.deviceController.GetVendorMountLibs() + if err != nil { + klog.Errorf("failed to get vendor mount libs for worker allocation of %s: %v,", request.WorkerUID, err) + return nil, err + } + + envs := make(map[string]string, 8) + devices := make(map[string]*api.DeviceSpec, 8) + for _, deviceInfo := range deviceInfos { + maps.Copy(envs, deviceInfo.DeviceEnv) + for devNode, guestPath := range deviceInfo.DeviceNode { + if _, exists := devices[devNode]; exists { + continue + } + devices[devNode] = &api.DeviceSpec{ + HostPath: devNode, + GuestPath: guestPath, + Permissions: "rwm", + } + } + } + + allocation := &api.WorkerAllocation{ + WorkerInfo: request, + DeviceInfos: deviceInfos, + Envs: envs, + Mounts: mounts, + Devices: lo.Values(devices), + } + + w.workerAllocations[request.WorkerUID] = allocation + for _, deviceUUID := range request.AllocatedDevices { + w.deviceController.AddDeviceAllocation(deviceUUID, allocation) + } + return allocation, nil +} + +func (w *WorkerController) DeallocateWorker(workerUID string) error { + w.mu.Lock() + defer w.mu.Unlock() + allocation, exists := w.workerAllocations[workerUID] + if !exists { + klog.Errorf("worker allocation not found for worker, can not deallocate worker %s", workerUID) + return nil + } + delete(w.workerAllocations, workerUID) + for _, deviceUUID := range allocation.WorkerInfo.AllocatedDevices { + w.deviceController.RemoveDeviceAllocation(deviceUUID, allocation) + } + return nil +} + +func (w *WorkerController) ListWorkers() ([]*api.WorkerInfo, 
error) { + w.mu.RLock() + defer w.mu.RUnlock() + return lo.Values(w.workers), nil +} + +func (w *WorkerController) GetWorkerAllocation(workerUID string) (*api.WorkerAllocation, bool) { + w.mu.RLock() + defer w.mu.RUnlock() + allocation, exists := w.workerAllocations[workerUID] + return allocation, exists +} + +func (w *WorkerController) GetWorkerMetrics() (map[string]map[string]map[string]*api.WorkerMetrics, error) { + // TODO: implement this + // Get all allocations to know which workers exist + // find process and then get metrics by host processes + // w.deviceController.GetProcessMetrics() + return nil, nil +} diff --git a/internal/hypervisor/worker/state/ctx_migration.go b/internal/hypervisor/worker/state/ctx_migration.go new file mode 100644 index 00000000..4df0094f --- /dev/null +++ b/internal/hypervisor/worker/state/ctx_migration.go @@ -0,0 +1 @@ +package worker diff --git a/internal/hypervisor/worker/state/soft_limiter_shm.go b/internal/hypervisor/worker/state/soft_limiter_shm.go new file mode 100644 index 00000000..c548006b --- /dev/null +++ b/internal/hypervisor/worker/state/soft_limiter_shm.go @@ -0,0 +1,937 @@ +package worker + +import ( + "fmt" + "math" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +// Constants +const ( + MaxProcesses = 2048 + MaxDevices = 16 + MaxUUIDLen = 64 + ShmPathSuffix = "shm" +) + +// RefCountError represents errors in reference count operations +type RefCountError struct { + Type string +} + +func (e *RefCountError) Error() string { + return fmt.Sprintf("ref count error: %s", e.Type) +} + +var ( + ErrRefCountUnderflow = &RefCountError{Type: "underflow"} +) + +// PodIdentifier contains namespace and name +type PodIdentifier struct { + Namespace string + Name string +} + +// NewPodIdentifier creates a new PodIdentifier +func NewPodIdentifier(namespace, name string) *PodIdentifier { + return &PodIdentifier{ + Namespace: namespace, + Name: name, + } +} + +// ToPath returns the path for this pod identifier +func (p *PodIdentifier) ToPath(basePath string) string { + return filepath.Join(basePath, p.Namespace, p.Name) +} + +// FromShmFilePath parses a PodIdentifier from a full shared memory path +// Path format: {base_path}/{namespace}/{name}/shm +func FromShmFilePath(path string) (*PodIdentifier, error) { + path = filepath.Clean(path) + components := strings.Split(path, string(filepath.Separator)) + + // Filter out empty components (from leading/trailing separators) + var filtered []string + for _, comp := range components { + if comp != "" { + filtered = append(filtered, comp) + } + } + components = filtered + + // Need at least: namespace, name, and "shm" (3 components minimum) + if len(components) < 3 { + return nil, fmt.Errorf("invalid path format: %s (need at least namespace/name/shm)", path) + } + + // Extract the last 3 components: {namespace}/{name}/shm + compLen := len(components) + + // Verify the last component is "shm" + if components[compLen-1] != ShmPathSuffix { + return nil, fmt.Errorf("invalid path format: %s (last component must be 'shm')", path) + } + + namespace := components[compLen-3] + name := components[compLen-2] + + // Validate namespace and name are not empty + if namespace == "" || name == "" { + return nil, fmt.Errorf("invalid path format: %s (namespace and name must be non-empty)", path) + } + + return NewPodIdentifier(namespace, name), nil +} + +// String returns the string representation +func (p *PodIdentifier) String() string { + return fmt.Sprintf("%s/%s", p.Namespace, 
p.Name) +} + +// CleanupEmptyParentDirectories removes empty parent directories after removing a file +func CleanupEmptyParentDirectories(filePath string, stopAtPath *string) error { + parentDir := filepath.Dir(filePath) + + // Skip if we've reached the stop path + if stopAtPath != nil && parentDir == *stopAtPath { + return nil + } + + // Try to remove the immediate parent directory if it's empty + entries, err := os.ReadDir(parentDir) + if err != nil { + return err + } + + if len(entries) == 0 { + if err := os.Remove(parentDir); err != nil { + return err + } + + // Recursively try to remove parent directories if they're also empty + return CleanupEmptyParentDirectories(parentDir, stopAtPath) + } + + return nil +} + +// SharedDeviceInfoV1 is the legacy device state (without ERL) +type SharedDeviceInfoV1 struct { + AvailableCudaCores int32 + UpLimit uint32 + MemLimit uint64 + TotalCudaCores uint32 + PodMemoryUsed uint64 +} + +// SharedDeviceInfoV2 is the V2 device state with ERL support +type SharedDeviceInfoV2 struct { + UpLimit uint32 + MemLimit uint64 + TotalCudaCores uint32 + PodMemoryUsed uint64 + + // ERL (Elastic Rate Limiting) - PID-controlled token bucket + ERLTokenRefillRate uint64 // f64 stored as bits + ERLTokenCapacity uint64 // f64 stored as bits + ERLCurrentTokens uint64 // f64 stored as bits + ERLLastTokenUpdate uint64 // f64 stored as bits +} + +// SharedDeviceInfo is a type alias for backward compatibility +type SharedDeviceInfo = SharedDeviceInfoV2 + +// NewSharedDeviceInfoV1 creates a new V1 device info +func NewSharedDeviceInfoV1(totalCudaCores, upLimit uint32, memLimit uint64) *SharedDeviceInfoV1 { + return &SharedDeviceInfoV1{ + AvailableCudaCores: 0, + UpLimit: upLimit, + MemLimit: memLimit, + TotalCudaCores: totalCudaCores, + PodMemoryUsed: 0, + } +} + +// NewSharedDeviceInfoV2 creates a new V2 device info +func NewSharedDeviceInfoV2(totalCudaCores, upLimit uint32, memLimit uint64) *SharedDeviceInfoV2 { + return &SharedDeviceInfoV2{ + UpLimit: upLimit, + MemLimit: memLimit, + TotalCudaCores: totalCudaCores, + PodMemoryUsed: 0, + ERLTokenRefillRate: math.Float64bits(10.0), // Default 10 tokens/sec + ERLTokenCapacity: math.Float64bits(100.0), + ERLCurrentTokens: math.Float64bits(100.0), + ERLLastTokenUpdate: math.Float64bits(0.0), + } +} + +// DeviceEntryV1 is the legacy device entry +type DeviceEntryV1 struct { + UUID [MaxUUIDLen]byte + DeviceInfo SharedDeviceInfoV1 + IsActiveField uint32 + //nolint:unused // Padding field for memory alignment in shared memory structures + _padding [4]byte +} + +// DeviceEntryV2 is the V2 device entry with ERL +type DeviceEntryV2 struct { + UUID [MaxUUIDLen]byte + DeviceInfo SharedDeviceInfoV2 + IsActiveField uint32 +} + +// DeviceEntry is a type alias for backward compatibility +type DeviceEntry = DeviceEntryV2 + +// NewDeviceEntryV1 creates a new V1 device entry +func NewDeviceEntryV1() *DeviceEntryV1 { + return &DeviceEntryV1{ + DeviceInfo: *NewSharedDeviceInfoV1(0, 0, 0), + } +} + +// NewDeviceEntryV2 creates a new V2 device entry +func NewDeviceEntryV2() *DeviceEntryV2 { + return &DeviceEntryV2{ + DeviceInfo: *NewSharedDeviceInfoV2(0, 0, 0), + } +} + +// SetUUID sets the device UUID +func (d *DeviceEntryV1) SetUUID(uuid string) { + copyLen := len(uuid) + if copyLen > MaxUUIDLen-1 { + copyLen = MaxUUIDLen - 1 + } + + // Clear the UUID array + for i := range d.UUID { + d.UUID[i] = 0 + } + + // Copy the new UUID + copy(d.UUID[:], uuid[:copyLen]) +} + +// GetUUID gets the device UUID as a string +func (d *DeviceEntryV1) GetUUID() 
string { + nullPos := MaxUUIDLen - 1 + for i, b := range d.UUID { + if b == 0 { + nullPos = i + break + } + } + return string(d.UUID[:nullPos]) +} + +// IsActive checks if this entry is active +func (d *DeviceEntryV1) IsActive() bool { + return atomic.LoadUint32(&d.IsActiveField) != 0 +} + +// SetActive sets the active status +func (d *DeviceEntryV1) SetActive(active bool) { + var val uint32 + if active { + val = 1 + } + atomic.StoreUint32(&d.IsActiveField, val) +} + +// SetUUID sets the device UUID +func (d *DeviceEntryV2) SetUUID(uuid string) { + copyLen := len(uuid) + if copyLen > MaxUUIDLen-1 { + copyLen = MaxUUIDLen - 1 + } + + // Clear the UUID array + for i := range d.UUID { + d.UUID[i] = 0 + } + + // Copy the new UUID + copy(d.UUID[:], uuid[:copyLen]) +} + +// GetUUID gets the device UUID as a string +func (d *DeviceEntryV2) GetUUID() string { + nullPos := MaxUUIDLen - 1 + for i, b := range d.UUID { + if b == 0 { + nullPos = i + break + } + } + return string(d.UUID[:nullPos]) +} + +// IsActive checks if this entry is active +func (d *DeviceEntryV2) IsActive() bool { + return atomic.LoadUint32(&d.IsActiveField) != 0 +} + +// SetActive sets the active status +func (d *DeviceEntryV2) SetActive(active bool) { + var val uint32 + if active { + val = 1 + } + atomic.StoreUint32(&d.IsActiveField, val) +} + +// DeviceConfig contains device configuration information +type DeviceConfig struct { + DeviceIdx uint32 + DeviceUUID string + UpLimit uint32 + MemLimit uint64 + SMCount uint32 + MaxThreadPerSM uint32 + TotalCudaCores uint32 +} + +// SharedDeviceStateV1 is the V1 shared device state +type SharedDeviceStateV1 struct { + Devices [MaxDevices]DeviceEntryV1 + DeviceCountField uint32 + LastHeartbeat uint64 + PIDs *ShmMutex[*PIDSet] +} + +// SharedDeviceStateV2 is the V2 shared device state with ERL +type SharedDeviceStateV2 struct { + Devices [MaxDevices]DeviceEntryV2 + DeviceCountField uint32 + LastHeartbeat uint64 + PIDs *ShmMutex[*PIDSet] +} + +// SharedDeviceState is a versioned enum for compatibility +type SharedDeviceState struct { + V1 *SharedDeviceStateV1 + V2 *SharedDeviceStateV2 +} + +// Version returns the version number +func (s *SharedDeviceState) Version() uint32 { + if s.V1 != nil { + return 1 + } + return 2 +} + +// HasERL checks if this state uses ERL features +func (s *SharedDeviceState) HasERL() bool { + return s.V2 != nil +} + +// NewSharedDeviceStateV1 creates a new V1 state +func NewSharedDeviceStateV1(configs []DeviceConfig) (*SharedDeviceStateV1, error) { + now := uint64(time.Now().Unix()) + + state := &SharedDeviceStateV1{ + DeviceCountField: uint32(len(configs)), + LastHeartbeat: now, + PIDs: NewShmMutex(NewPIDSet()), + } + + for _, config := range configs { + deviceIdx := int(config.DeviceIdx) + if deviceIdx >= MaxDevices { + return nil, fmt.Errorf("device index %d exceeds maximum devices %d", deviceIdx, MaxDevices) + } + + entry := &state.Devices[deviceIdx] + entry.SetUUID(config.DeviceUUID) + entry.DeviceInfo.TotalCudaCores = config.TotalCudaCores + entry.DeviceInfo.AvailableCudaCores = int32(config.TotalCudaCores) + entry.DeviceInfo.UpLimit = config.UpLimit + entry.DeviceInfo.MemLimit = config.MemLimit + entry.SetActive(true) + } + + return state, nil +} + +// NewSharedDeviceStateV2 creates a new V2 state +func NewSharedDeviceStateV2(configs []DeviceConfig) (*SharedDeviceStateV2, error) { + now := uint64(time.Now().Unix()) + + state := &SharedDeviceStateV2{ + DeviceCountField: uint32(len(configs)), + LastHeartbeat: now, + PIDs: NewShmMutex(NewPIDSet()), + } + + 
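+ // Fill each configured device slot; the ERL token-bucket fields are seeded with conservative defaults below.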
for _, config := range configs { + deviceIdx := int(config.DeviceIdx) + if deviceIdx >= MaxDevices { + return nil, fmt.Errorf("device index %d exceeds maximum devices %d", deviceIdx, MaxDevices) + } + + entry := &state.Devices[deviceIdx] + entry.SetUUID(config.DeviceUUID) + entry.DeviceInfo.TotalCudaCores = config.TotalCudaCores + entry.DeviceInfo.UpLimit = config.UpLimit + entry.DeviceInfo.MemLimit = config.MemLimit + + // Initialize ERL fields with defaults + entry.DeviceInfo.ERLTokenCapacity = math.Float64bits(100.0) + entry.DeviceInfo.ERLTokenRefillRate = math.Float64bits(10.0) + entry.DeviceInfo.ERLCurrentTokens = math.Float64bits(100.0) + entry.DeviceInfo.ERLLastTokenUpdate = math.Float64bits(float64(now)) + + entry.SetActive(true) + } + + return state, nil +} + +// NewSharedDeviceState creates a new SharedDeviceState (defaults to V2) +func NewSharedDeviceState(configs []DeviceConfig) (*SharedDeviceState, error) { + v2, err := NewSharedDeviceStateV2(configs) + if err != nil { + return nil, err + } + return &SharedDeviceState{V2: v2}, nil +} + +// HasDevice checks if a device exists at the given index +func (s *SharedDeviceStateV1) HasDevice(index int) bool { + return index < MaxDevices && s.Devices[index].IsActive() +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceStateV1) DeviceCount() int { + return int(atomic.LoadUint32(&s.DeviceCountField)) +} + +// UpdateHeartbeat updates the heartbeat timestamp +func (s *SharedDeviceStateV1) UpdateHeartbeat(timestamp uint64) { + atomic.StoreUint64(&s.LastHeartbeat, timestamp) +} + +// GetLastHeartbeat returns the last heartbeat timestamp +func (s *SharedDeviceStateV1) GetLastHeartbeat() uint64 { + return atomic.LoadUint64(&s.LastHeartbeat) +} + +// IsHealthy checks if the shared memory is healthy based on heartbeat +func (s *SharedDeviceStateV1) IsHealthy(timeout time.Duration) bool { + now := uint64(time.Now().Unix()) + lastHeartbeat := s.GetLastHeartbeat() + + if lastHeartbeat == 0 { + return false + } + + if lastHeartbeat > now { + return false + } + + return now-lastHeartbeat <= uint64(timeout.Seconds()) +} + +// AddPID adds a PID to the set +func (s *SharedDeviceStateV1) AddPID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.InsertIfAbsent(pid) +} + +// RemovePID removes a PID from the set +func (s *SharedDeviceStateV1) RemovePID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.RemoveValue(pid) +} + +// GetAllPIDs returns all PIDs currently stored +func (s *SharedDeviceStateV1) GetAllPIDs() []int { + s.PIDs.Lock() + defer s.PIDs.Unlock() + return s.PIDs.Value.Values() +} + +// CleanupOrphanedLocks cleans up any orphaned locks +func (s *SharedDeviceStateV1) CleanupOrphanedLocks() { + s.PIDs.CleanupOrphanedLock() +} + +// HasDevice checks if a device exists at the given index +func (s *SharedDeviceStateV2) HasDevice(index int) bool { + return index < MaxDevices && s.Devices[index].IsActive() +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceStateV2) DeviceCount() int { + return int(atomic.LoadUint32(&s.DeviceCountField)) +} + +// UpdateHeartbeat updates the heartbeat timestamp +func (s *SharedDeviceStateV2) UpdateHeartbeat(timestamp uint64) { + atomic.StoreUint64(&s.LastHeartbeat, timestamp) +} + +// GetLastHeartbeat returns the last heartbeat timestamp +func (s *SharedDeviceStateV2) GetLastHeartbeat() uint64 { + return atomic.LoadUint64(&s.LastHeartbeat) +} + +// IsHealthy checks if the shared memory is healthy based on heartbeat +func (s *SharedDeviceStateV2) 
IsHealthy(timeout time.Duration) bool { + now := uint64(time.Now().Unix()) + lastHeartbeat := s.GetLastHeartbeat() + + if lastHeartbeat == 0 { + return false + } + + if lastHeartbeat > now { + return false + } + + return now-lastHeartbeat <= uint64(timeout.Seconds()) +} + +// AddPID adds a PID to the set +func (s *SharedDeviceStateV2) AddPID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.InsertIfAbsent(pid) +} + +// RemovePID removes a PID from the set +func (s *SharedDeviceStateV2) RemovePID(pid int) { + s.PIDs.Lock() + defer s.PIDs.Unlock() + s.PIDs.Value.RemoveValue(pid) +} + +// GetAllPIDs returns all PIDs currently stored +func (s *SharedDeviceStateV2) GetAllPIDs() []int { + s.PIDs.Lock() + defer s.PIDs.Unlock() + return s.PIDs.Value.Values() +} + +// CleanupOrphanedLocks cleans up any orphaned locks +func (s *SharedDeviceStateV2) CleanupOrphanedLocks() { + s.PIDs.CleanupOrphanedLock() +} + +// Helper methods for SharedDeviceState that delegate to the appropriate version + +// HasDevice checks if a device exists +func (s *SharedDeviceState) HasDevice(index int) bool { + if s.V1 != nil { + return s.V1.HasDevice(index) + } + return s.V2.HasDevice(index) +} + +// DeviceCount returns the number of devices +func (s *SharedDeviceState) DeviceCount() int { + if s.V1 != nil { + return s.V1.DeviceCount() + } + return s.V2.DeviceCount() +} + +// UpdateHeartbeat updates the heartbeat +func (s *SharedDeviceState) UpdateHeartbeat(timestamp uint64) { + if s.V1 != nil { + s.V1.UpdateHeartbeat(timestamp) + } else { + s.V2.UpdateHeartbeat(timestamp) + } +} + +// GetLastHeartbeat returns the last heartbeat +func (s *SharedDeviceState) GetLastHeartbeat() uint64 { + if s.V1 != nil { + return s.V1.GetLastHeartbeat() + } + return s.V2.GetLastHeartbeat() +} + +// IsHealthy checks if healthy +func (s *SharedDeviceState) IsHealthy(timeout time.Duration) bool { + if s.V1 != nil { + return s.V1.IsHealthy(timeout) + } + return s.V2.IsHealthy(timeout) +} + +// AddPID adds a PID +func (s *SharedDeviceState) AddPID(pid int) { + if s.V1 != nil { + s.V1.AddPID(pid) + } else { + s.V2.AddPID(pid) + } +} + +// RemovePID removes a PID +func (s *SharedDeviceState) RemovePID(pid int) { + if s.V1 != nil { + s.V1.RemovePID(pid) + } else { + s.V2.RemovePID(pid) + } +} + +// GetAllPIDs returns all PIDs +func (s *SharedDeviceState) GetAllPIDs() []int { + if s.V1 != nil { + return s.V1.GetAllPIDs() + } + return s.V2.GetAllPIDs() +} + +// CleanupOrphanedLocks cleans up orphaned locks +func (s *SharedDeviceState) CleanupOrphanedLocks() { + if s.V1 != nil { + s.V1.CleanupOrphanedLocks() + } else { + s.V2.CleanupOrphanedLocks() + } +} + +// SetPodMemoryUsed sets pod memory used for a device +func (s *SharedDeviceState) SetPodMemoryUsed(index int, memory uint64) bool { + if s.V1 != nil { + if index >= MaxDevices || !s.V1.Devices[index].IsActive() { + return false + } + atomic.StoreUint64(&s.V1.Devices[index].DeviceInfo.PodMemoryUsed, memory) + return true + } + if index >= MaxDevices || !s.V2.Devices[index].IsActive() { + return false + } + atomic.StoreUint64(&s.V2.Devices[index].DeviceInfo.PodMemoryUsed, memory) + return true +} + +// ERL token bucket operations for SharedDeviceInfoV2 + +// GetERLTokenCapacity returns the token capacity +func (d *SharedDeviceInfoV2) GetERLTokenCapacity() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLTokenCapacity)) +} + +// SetERLTokenCapacity sets the token capacity +func (d *SharedDeviceInfoV2) SetERLTokenCapacity(capacity float64) { + 
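+ // The capacity is a float64 persisted as raw bits so it can be updated atomically alongside the other ERL fields.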
atomic.StoreUint64(&d.ERLTokenCapacity, math.Float64bits(capacity)) +} + +// GetERLTokenRefillRate returns the refill rate +func (d *SharedDeviceInfoV2) GetERLTokenRefillRate() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLTokenRefillRate)) +} + +// SetERLTokenRefillRate sets the refill rate +func (d *SharedDeviceInfoV2) SetERLTokenRefillRate(rate float64) { + atomic.StoreUint64(&d.ERLTokenRefillRate, math.Float64bits(rate)) +} + +// GetERLCurrentTokens returns the current tokens +func (d *SharedDeviceInfoV2) GetERLCurrentTokens() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLCurrentTokens)) +} + +// SetERLCurrentTokens sets the current tokens +func (d *SharedDeviceInfoV2) SetERLCurrentTokens(tokens float64) { + atomic.StoreUint64(&d.ERLCurrentTokens, math.Float64bits(tokens)) +} + +// GetERLLastTokenUpdate returns the last token update timestamp +func (d *SharedDeviceInfoV2) GetERLLastTokenUpdate() float64 { + return math.Float64frombits(atomic.LoadUint64(&d.ERLLastTokenUpdate)) +} + +// SetERLLastTokenUpdate sets the last token update timestamp +func (d *SharedDeviceInfoV2) SetERLLastTokenUpdate(timestamp float64) { + atomic.StoreUint64(&d.ERLLastTokenUpdate, math.Float64bits(timestamp)) +} + +// LoadERLTokenState loads the token state atomically +func (d *SharedDeviceInfoV2) LoadERLTokenState() (float64, float64) { + return d.GetERLCurrentTokens(), d.GetERLLastTokenUpdate() +} + +// StoreERLTokenState stores the token state atomically +func (d *SharedDeviceInfoV2) StoreERLTokenState(tokens, timestamp float64) { + d.SetERLCurrentTokens(tokens) + d.SetERLLastTokenUpdate(timestamp) +} + +// LoadERLQuota loads the quota configuration +func (d *SharedDeviceInfoV2) LoadERLQuota() (float64, float64) { + return d.GetERLTokenCapacity(), d.GetERLTokenRefillRate() +} + +// FetchSubERLTokens atomically subtracts tokens and returns the value before subtraction +func (d *SharedDeviceInfoV2) FetchSubERLTokens(cost float64) float64 { + for { + currentBits := atomic.LoadUint64(&d.ERLCurrentTokens) + current := math.Float64frombits(currentBits) + + if current < cost { + return current + } + + newValue := math.Max(0.0, current-cost) + newBits := math.Float64bits(newValue) + + if atomic.CompareAndSwapUint64(&d.ERLCurrentTokens, currentBits, newBits) { + return current + } + } +} + +// FetchAddERLTokens atomically adds tokens (capped at capacity) and returns the value before addition +func (d *SharedDeviceInfoV2) FetchAddERLTokens(amount float64) float64 { + capacity := d.GetERLTokenCapacity() + + for { + currentBits := atomic.LoadUint64(&d.ERLCurrentTokens) + current := math.Float64frombits(currentBits) + + newValue := math.Max(0.0, math.Min(capacity, current+amount)) + newBits := math.Float64bits(newValue) + + if atomic.CompareAndSwapUint64(&d.ERLCurrentTokens, currentBits, newBits) { + return current + } + } +} + +// PIDSet is a set of process IDs with a fixed capacity +type PIDSet struct { + values []int + mu sync.Mutex //nolint:unused // Used via ShmMutex wrapper +} + +// NewPIDSet creates a new PID set +func NewPIDSet() *PIDSet { + return &PIDSet{ + values: make([]int, 0, MaxProcesses), + } +} + +// InsertIfAbsent inserts a value if it's not already present +func (s *PIDSet) InsertIfAbsent(pid int) bool { + for _, v := range s.values { + if v == pid { + return false + } + } + if len(s.values) >= MaxProcesses { + return false + } + s.values = append(s.values, pid) + return true +} + +// RemoveValue removes a value from the set +func (s *PIDSet) RemoveValue(pid int) 
bool { + for i, v := range s.values { + if v == pid { + s.values = append(s.values[:i], s.values[i+1:]...) + return true + } + } + return false +} + +// Values returns all values in the set +func (s *PIDSet) Values() []int { + result := make([]int, len(s.values)) + copy(result, s.values) + return result +} + +// ShmMutex is a shared memory mutex wrapper +type ShmMutex[T any] struct { + mu sync.Mutex + Value T +} + +// NewShmMutex creates a new shared memory mutex +func NewShmMutex[T any](value T) *ShmMutex[T] { + return &ShmMutex[T]{ + Value: value, + } +} + +// Lock locks the mutex +func (m *ShmMutex[T]) Lock() { + m.mu.Lock() +} + +// Unlock unlocks the mutex +func (m *ShmMutex[T]) Unlock() { + m.mu.Unlock() +} + +// CleanupOrphanedLock cleans up orphaned locks (placeholder for now) +func (m *ShmMutex[T]) CleanupOrphanedLock() { + // In a real implementation, this would check for dead processes + // and release their locks. For now, it's a no-op. +} + +// SharedMemoryHandle manages a shared memory mapping +type SharedMemoryHandle struct { + path string + data []byte + state *SharedDeviceState + file *os.File + fileSize int64 +} + +// CreateSharedMemoryHandle creates a new shared memory handle +func CreateSharedMemoryHandle(podPath string, configs []DeviceConfig) (*SharedMemoryHandle, error) { + shmPath := filepath.Join(podPath, ShmPathSuffix) + + // Create directory if it doesn't exist + if err := os.MkdirAll(podPath, 0755); err != nil { + return nil, fmt.Errorf("failed to create directory: %w", err) + } + + // Calculate size needed for SharedDeviceStateV2 + stateSize := int(unsafe.Sizeof(SharedDeviceStateV2{})) + + // Create or open the file + file, err := os.OpenFile(shmPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return nil, fmt.Errorf("failed to create file: %w", err) + } + + // Truncate to the required size + if err := file.Truncate(int64(stateSize)); err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to truncate file: %w", err) + } + + // Memory map the file + data, err := syscall.Mmap(int(file.Fd()), 0, stateSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to mmap: %w", err) + } + + // Initialize the state + state, err := NewSharedDeviceStateV2(configs) + if err != nil { + _ = syscall.Munmap(data) + _ = file.Close() + return nil, err + } + + // Copy the state to the mapped memory + stateBytes := (*[1 << 30]byte)(unsafe.Pointer(state))[:stateSize:stateSize] + copy(data, stateBytes) + + // Get a pointer to the mapped state + mappedState := (*SharedDeviceStateV2)(unsafe.Pointer(&data[0])) + + // Initialize the PIDs mutex in the mapped memory + // Note: This is a simplified version - in a real implementation, + // you'd need to properly initialize the mutex for shared memory + mappedState.PIDs = NewShmMutex(NewPIDSet()) + + return &SharedMemoryHandle{ + path: shmPath, + data: data, + state: &SharedDeviceState{V2: mappedState}, + file: file, + fileSize: int64(stateSize), + }, nil +} + +// OpenSharedMemoryHandle opens an existing shared memory handle +func OpenSharedMemoryHandle(podPath string) (*SharedMemoryHandle, error) { + shmPath := filepath.Join(podPath, ShmPathSuffix) + + // Open the file + file, err := os.OpenFile(shmPath, os.O_RDWR, 0666) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + + // Get file size + stat, err := file.Stat() + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to stat file: %w", err) + } 
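+ // Map the file at its current size; the layout is assumed to match the V2 state written by CreateSharedMemoryHandle.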
+ + fileSize := stat.Size() + + // Memory map the file + data, err := syscall.Mmap(int(file.Fd()), 0, int(fileSize), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to mmap: %w", err) + } + + // Get a pointer to the mapped state (assume V2 for now) + mappedState := (*SharedDeviceStateV2)(unsafe.Pointer(&data[0])) + + return &SharedMemoryHandle{ + path: shmPath, + data: data, + state: &SharedDeviceState{V2: mappedState}, + file: file, + fileSize: fileSize, + }, nil +} + +// GetState returns the shared device state +func (h *SharedMemoryHandle) GetState() *SharedDeviceState { + return h.state +} + +// Close closes the shared memory handle +func (h *SharedMemoryHandle) Close() error { + if h.data != nil { + _ = syscall.Munmap(h.data) + h.data = nil + } + if h.file != nil { + _ = h.file.Close() + h.file = nil + } + return nil +} + +// Cleanup removes the shared memory file and cleans up empty directories +func (h *SharedMemoryHandle) Cleanup(stopAtPath *string) error { + if err := h.Close(); err != nil { + return err + } + + if err := os.Remove(h.path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove file: %w", err) + } + + if stopAtPath != nil { + return CleanupEmptyParentDirectories(h.path, stopAtPath) + } + return CleanupEmptyParentDirectories(h.path, nil) +} diff --git a/internal/hypervisor/worker/state/soft_limiter_shm_test.go b/internal/hypervisor/worker/state/soft_limiter_shm_test.go new file mode 100644 index 00000000..41d67ff7 --- /dev/null +++ b/internal/hypervisor/worker/state/soft_limiter_shm_test.go @@ -0,0 +1,648 @@ +package worker + +import ( + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testShmBasePath = "/tmp/test_shm" + testDeviceIdx = uint32(0) + testTotalCores = uint32(1024) + testUpLimit = uint32(80) + testMemLimit = uint64(1024 * 1024 * 1024) // 1GB +) + +func createTestConfigs() []DeviceConfig { + return []DeviceConfig{ + { + DeviceIdx: testDeviceIdx, + DeviceUUID: "test-device-uuid", + UpLimit: testUpLimit, + MemLimit: testMemLimit, + TotalCudaCores: testTotalCores, + SMCount: 10, + MaxThreadPerSM: 1024, + }, + } +} + +func TestDeviceEntryBasicOperations(t *testing.T) { + entry := NewDeviceEntryV2() + + // Test UUID operations + entry.SetUUID("test-uuid-123") + assert.Equal(t, "test-uuid-123", entry.GetUUID()) + + // Test active status + assert.False(t, entry.IsActive()) + entry.SetActive(true) + assert.True(t, entry.IsActive()) + entry.SetActive(false) + assert.False(t, entry.IsActive()) + + // Test very long UUID handling + longUUID := strings.Repeat("a", MaxUUIDLen+10) + entry.SetUUID(longUUID) + storedUUID := entry.GetUUID() + assert.Less(t, len(storedUUID), MaxUUIDLen) + assert.Contains(t, storedUUID, "a") +} + +func TestSharedDeviceStateCreationAndBasicOps(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + // Test initial state (V2 by default) + assert.Equal(t, uint32(2), state.Version()) + assert.Equal(t, 1, state.DeviceCount()) + + // Test that heartbeat is initialized to current time (should be non-zero and recent) + heartbeat := state.GetLastHeartbeat() + assert.Greater(t, heartbeat, uint64(0)) + now := uint64(time.Now().Unix()) + assert.Less(t, now-heartbeat, uint64(2)) // Should be within 2 seconds + + // Should be healthy since heartbeat was just 
set + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test device exists by index + deviceIdx := int(configs[0].DeviceIdx) + assert.True(t, state.HasDevice(deviceIdx)) +} + +func TestSharedDeviceStateHeartbeatFunctionality(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Test initial healthy state (heartbeat is initialized to current time) + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test setting heartbeat to a specific time + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) + + // Test old heartbeat (should be unhealthy) + state.UpdateHeartbeat(now - 60) + assert.False(t, state.IsHealthy(30*time.Second)) +} + +func TestSharedDeviceInfoAtomicOperations(t *testing.T) { + // Test V1 device info (has available_cores) + deviceInfoV1 := NewSharedDeviceInfoV1(testTotalCores, testUpLimit, testMemLimit) + + // Test available cores operations (V1 only) + deviceInfoV1.AvailableCudaCores = 512 + assert.Equal(t, int32(512), deviceInfoV1.AvailableCudaCores) + + deviceInfoV1.AvailableCudaCores = 600 + assert.Equal(t, int32(600), deviceInfoV1.AvailableCudaCores) + + // Test negative values + deviceInfoV1.AvailableCudaCores = -50 + assert.Equal(t, int32(-50), deviceInfoV1.AvailableCudaCores) + + // Test other fields + deviceInfoV1.UpLimit = 90 + assert.Equal(t, uint32(90), deviceInfoV1.UpLimit) + + deviceInfoV1.MemLimit = 2 * 1024 * 1024 * 1024 + assert.Equal(t, uint64(2*1024*1024*1024), deviceInfoV1.MemLimit) + + // Test V2 device info (has ERL fields) + deviceInfoV2 := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + // Test ERL fields - refill rate is now the control parameter + deviceInfoV2.SetERLTokenRefillRate(15.0) + assert.Equal(t, 15.0, deviceInfoV2.GetERLTokenRefillRate()) + + deviceInfoV2.SetERLTokenCapacity(100.0) + assert.Equal(t, 100.0, deviceInfoV2.GetERLTokenCapacity()) + + deviceInfoV2.PodMemoryUsed = 512 * 1024 * 1024 + assert.Equal(t, uint64(512*1024*1024), deviceInfoV2.PodMemoryUsed) +} + +func TestERLTokenBucketPreservesTokensWhenInsufficient(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + + deviceInfo.SetERLCurrentTokens(1.5) + before := deviceInfo.FetchSubERLTokens(2.0) + assert.Equal(t, 1.5, before) + assert.Equal(t, 1.5, deviceInfo.GetERLCurrentTokens()) + + deviceInfo.SetERLCurrentTokens(5.0) + beforeSuccess := deviceInfo.FetchSubERLTokens(2.0) + assert.Equal(t, 5.0, beforeSuccess) + assert.Equal(t, 3.0, deviceInfo.GetERLCurrentTokens()) +} + +func TestSharedMemoryHandleCreateAndOpen(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("handle_create_open", "test") + + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(podPath) + }() + + // Create shared memory + handle1, err := CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + defer func() { + _ = handle1.Close() + }() + + state1 := handle1.GetState() + assert.Equal(t, uint32(2), state1.Version()) + assert.Equal(t, 1, state1.DeviceCount()) + + // Verify shared memory file exists after creation + assert.True(t, fileExists(filepath.Join(podPath, ShmPathSuffix))) + + // Open existing shared memory + handle2, err := OpenSharedMemoryHandle(podPath) + require.NoError(t, err) + defer func() { + _ = handle2.Close() + }() + + state2 := handle2.GetState() + assert.Equal(t, uint32(2), state2.Version()) + 
assert.Equal(t, 1, state2.DeviceCount()) + + // Verify they access the same memory + deviceIdx := int(configs[0].DeviceIdx) + state1.SetPodMemoryUsed(deviceIdx, 42) + memory := state2.GetPodMemoryUsed(deviceIdx) + assert.Equal(t, uint64(42), memory) +} + +func TestSharedMemoryHandleErrorHandling(t *testing.T) { + _, err := OpenSharedMemoryHandle("non_existent_memory") + assert.Error(t, err) +} + +func TestConcurrentDeviceAccess(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("concurrent_access", "test") + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(podPath) + }() + + handle, err := CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + defer func() { + _ = handle.Close() + }() + + deviceIdx := int(configs[0].DeviceIdx) + var wg sync.WaitGroup + numGoroutines := 5 + iterations := 20 + + // Spawn multiple goroutines doing concurrent access + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + state := handle.GetState() + + for j := 0; j < iterations; j++ { + value := uint64(id*iterations + j) + state.SetPodMemoryUsed(deviceIdx, value) + + time.Sleep(time.Millisecond) + + readValue := state.GetPodMemoryUsed(deviceIdx) + // Value should be valid (set by some goroutine) + assert.GreaterOrEqual(t, readValue, uint64(0)) + assert.Less(t, readValue, uint64(100)) + } + }(i) + } + + wg.Wait() +} + +func TestDeviceIterationMethods(t *testing.T) { + // Create multiple device configurations + configs := []DeviceConfig{ + { + DeviceIdx: 0, + DeviceUUID: "device-0", + UpLimit: 80, + MemLimit: 1024 * 1024 * 1024, + TotalCudaCores: 1024, + SMCount: 10, + MaxThreadPerSM: 1024, + }, + { + DeviceIdx: 2, + DeviceUUID: "device-2", + UpLimit: 70, + MemLimit: 2 * 1024 * 1024 * 1024, + TotalCudaCores: 2048, + SMCount: 20, + MaxThreadPerSM: 1024, + }, + } + + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + // Test iterating over active devices + activeCount := 0 + for i := 0; i < MaxDevices; i++ { + if state.HasDevice(i) { + activeCount++ + } + } + assert.Equal(t, 2, activeCount) + + // Check that indices match the device_idx from configs + assert.True(t, state.HasDevice(0)) + assert.True(t, state.HasDevice(2)) + + // Test deactivating a device and checking + if state.V2 != nil { + state.V2.Devices[2].SetActive(false) + assert.False(t, state.HasDevice(2)) + assert.True(t, state.HasDevice(0)) + } +} + +func TestPIDSetDeduplicatesOnAdd(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Add the same pid multiple times + state.AddPID(1234) + state.AddPID(1234) + state.AddPID(1234) + + pids := state.GetAllPIDs() + assert.Equal(t, 1, len(pids), "should contain only one PID after duplicate adds") + if len(pids) > 0 { + assert.Equal(t, 1234, pids[0]) + } +} + +func TestPIDRemoveByValueWorks(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + state.AddPID(111) + state.AddPID(222) + state.AddPID(333) + + state.RemovePID(222) + + pids := state.GetAllPIDs() + assert.Equal(t, 2, len(pids), "should remove the specified PID") + assert.Contains(t, pids, 111) + assert.Contains(t, pids, 333) + assert.NotContains(t, pids, 222) +} + +func TestPIDSetCapacityAndDuplicateBehavior(t *testing.T) { + state, err := NewSharedDeviceState([]DeviceConfig{}) + require.NoError(t, err) + + // Fill to capacity with unique PIDs + for pid := 0; pid < MaxProcesses; pid++ { + state.AddPID(pid) + } + + pids 
:= state.GetAllPIDs() + assert.Equal(t, MaxProcesses, len(pids), "should reach max capacity with unique PIDs") + + // Adding an existing PID should not change the count + state.AddPID(0) + pidsAfterDup := state.GetAllPIDs() + assert.Equal(t, MaxProcesses, len(pidsAfterDup), "should remain at capacity when inserting duplicate") +} + +func TestCleanupEmptyParentDirectories(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create a file in the pod directory + testFile := filepath.Join(podDir, ShmPathSuffix) + err = os.WriteFile(testFile, []byte("test data"), 0644) + require.NoError(t, err) + + // Verify structure exists + assert.True(t, fileExists(testFile)) + assert.True(t, fileExists(podDir)) + assert.True(t, fileExists(namespaceDir)) + + // Remove the file + err = os.Remove(testFile) + require.NoError(t, err) + + // Test cleanup without stop_at_path (should remove all empty dirs) + err = CleanupEmptyParentDirectories(testFile, nil) + assert.NoError(t, err) + + // Pod directory should be removed + assert.False(t, fileExists(podDir)) + // Namespace directory should be removed + assert.False(t, fileExists(namespaceDir)) +} + +func TestCleanupEmptyParentDirectoriesWithStopAtPath(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create a file in the pod directory + testFile := filepath.Join(podDir, ShmPathSuffix) + err = os.WriteFile(testFile, []byte("test data"), 0644) + require.NoError(t, err) + + // Remove the file + err = os.Remove(testFile) + require.NoError(t, err) + + // Test cleanup with stop_at_path set to base_path + stopAtPath := tempDir + err = CleanupEmptyParentDirectories(testFile, &stopAtPath) + assert.NoError(t, err) + + // Pod directory should be removed + assert.False(t, fileExists(podDir)) + // Namespace directory should be removed + assert.False(t, fileExists(namespaceDir)) + // Base directory should remain (it's the stop_at_path) + assert.True(t, fileExists(tempDir)) +} + +func TestCleanupEmptyParentDirectoriesStopsAtNonEmptyDir(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "test_cleanup_*") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tempDir) + }() + + // Create nested directory structure: base/namespace/podname/ + namespaceDir := filepath.Join(tempDir, "test-namespace") + podDir := filepath.Join(namespaceDir, "test-pod") + err = os.MkdirAll(podDir, 0755) + require.NoError(t, err) + + // Create two files in the pod directory + testFile1 := filepath.Join(podDir, ShmPathSuffix) + testFile2 := filepath.Join(podDir, "other_file") + err = os.WriteFile(testFile1, []byte("test data"), 0644) + require.NoError(t, err) + err = os.WriteFile(testFile2, []byte("other data"), 0644) + require.NoError(t, err) + + // Remove only one file + err = os.Remove(testFile1) + 
require.NoError(t, err) + + // Test cleanup - should not remove pod directory since it's not empty + stopAtPath := tempDir + err = CleanupEmptyParentDirectories(testFile1, &stopAtPath) + assert.NoError(t, err) + + // Pod directory should still exist (not empty) + assert.True(t, fileExists(podDir)) + assert.True(t, fileExists(namespaceDir)) + assert.True(t, fileExists(testFile2)) +} + +func TestPodIdentifierFromShmFilePath(t *testing.T) { + tests := []struct { + name string + path string + expectError bool + expectedNS string + expectedName string + }{ + { + name: "valid path", + path: "/base/namespace/podname/shm", + expectError: false, + expectedNS: "namespace", + expectedName: "podname", + }, + { + name: "invalid path - too short", + path: "/base/shm", + expectError: true, + }, + { + name: "invalid path - only two components", + path: "/namespace/shm", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pid, err := FromShmFilePath(tt.path) + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, pid) + } else { + assert.NoError(t, err) + assert.NotNil(t, pid) + assert.Equal(t, tt.expectedNS, pid.Namespace) + assert.Equal(t, tt.expectedName, pid.Name) + } + }) + } +} + +func TestPodIdentifierToPath(t *testing.T) { + pid := NewPodIdentifier("test-namespace", "test-pod") + path := pid.ToPath("/base") + expected := filepath.Join("/base", "test-namespace", "test-pod") + assert.Equal(t, expected, path) +} + +func TestSharedDeviceStateSetPodMemoryUsed(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceState(configs) + require.NoError(t, err) + + deviceIdx := int(configs[0].DeviceIdx) + + // Test setting memory + success := state.SetPodMemoryUsed(deviceIdx, 1024*1024*1024) + assert.True(t, success) + + // Test setting memory for non-existent device + success = state.SetPodMemoryUsed(999, 1024) + assert.False(t, success) +} + +func TestERLTokenOperations(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + + // Test initial values + assert.Equal(t, 10.0, deviceInfo.GetERLTokenRefillRate()) + assert.Equal(t, 100.0, deviceInfo.GetERLTokenCapacity()) + assert.Equal(t, 100.0, deviceInfo.GetERLCurrentTokens()) + + // Test setting values + deviceInfo.SetERLTokenRefillRate(50.0) + deviceInfo.SetERLTokenCapacity(200.0) + deviceInfo.SetERLCurrentTokens(150.0) + + assert.Equal(t, 50.0, deviceInfo.GetERLTokenRefillRate()) + assert.Equal(t, 200.0, deviceInfo.GetERLTokenCapacity()) + assert.Equal(t, 150.0, deviceInfo.GetERLCurrentTokens()) + + // Test LoadERLTokenState + tokens, timestamp := deviceInfo.LoadERLTokenState() + assert.Equal(t, 150.0, tokens) + assert.Equal(t, 0.0, timestamp) // Initial timestamp is 0.0 + + // Test StoreERLTokenState + deviceInfo.StoreERLTokenState(175.0, 12345.0) + tokens, timestamp = deviceInfo.LoadERLTokenState() + assert.Equal(t, 175.0, tokens) + assert.Equal(t, 12345.0, timestamp) + + // Test LoadERLQuota + capacity, rate := deviceInfo.LoadERLQuota() + assert.Equal(t, 200.0, capacity) + assert.Equal(t, 50.0, rate) +} + +func TestFetchAddERLTokens(t *testing.T) { + deviceInfo := NewSharedDeviceInfoV2(testTotalCores, testUpLimit, testMemLimit) + deviceInfo.SetERLTokenCapacity(100.0) + deviceInfo.SetERLCurrentTokens(50.0) + + // Add tokens + before := deviceInfo.FetchAddERLTokens(30.0) + assert.Equal(t, 50.0, before) + assert.Equal(t, 80.0, deviceInfo.GetERLCurrentTokens()) + + // Add tokens that would exceed capacity + before = 
deviceInfo.FetchAddERLTokens(50.0) + assert.Equal(t, 80.0, before) + assert.Equal(t, 100.0, deviceInfo.GetERLCurrentTokens()) // Capped at capacity +} + +func TestSharedDeviceStateV1Operations(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceStateV1(configs) + require.NoError(t, err) + + assert.Equal(t, 1, state.DeviceCount()) + assert.True(t, state.HasDevice(0)) + assert.False(t, state.HasDevice(1)) + + // Test heartbeat + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) +} + +func TestSharedDeviceStateV2Operations(t *testing.T) { + configs := createTestConfigs() + state, err := NewSharedDeviceStateV2(configs) + require.NoError(t, err) + + assert.Equal(t, 1, state.DeviceCount()) + assert.True(t, state.HasDevice(0)) + assert.False(t, state.HasDevice(1)) + + // Test heartbeat + now := uint64(time.Now().Unix()) + state.UpdateHeartbeat(now) + assert.Equal(t, now, state.GetLastHeartbeat()) + assert.True(t, state.IsHealthy(30*time.Second)) +} + +func TestDeviceEntryV1Operations(t *testing.T) { + entry := NewDeviceEntryV1() + + entry.SetUUID("v1-uuid-test") + assert.Equal(t, "v1-uuid-test", entry.GetUUID()) + + assert.False(t, entry.IsActive()) + entry.SetActive(true) + assert.True(t, entry.IsActive()) +} + +func TestSharedMemoryHandleCleanup(t *testing.T) { + configs := createTestConfigs() + identifier := NewPodIdentifier("cleanup_test", "test") + podPath := identifier.ToPath(testShmBasePath) + defer func() { + _ = os.RemoveAll(testShmBasePath) + }() + + handle, err := CreateSharedMemoryHandle(podPath, configs) + require.NoError(t, err) + + shmPath := filepath.Join(podPath, ShmPathSuffix) + assert.True(t, fileExists(shmPath)) + + // Cleanup + stopAtPath := testShmBasePath + err = handle.Cleanup(&stopAtPath) + assert.NoError(t, err) + + // File should be removed + assert.False(t, fileExists(shmPath)) +} + +// Helper function to check if file exists +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +// Helper function to get pod memory used (needed for tests) +func (s *SharedDeviceState) GetPodMemoryUsed(index int) uint64 { + if s.V1 != nil { + if index >= MaxDevices || !s.V1.Devices[index].IsActive() { + return 0 + } + return atomic.LoadUint64(&s.V1.Devices[index].DeviceInfo.PodMemoryUsed) + } + if index >= MaxDevices || !s.V2.Devices[index].IsActive() { + return 0 + } + return atomic.LoadUint64(&s.V2.Devices[index].DeviceInfo.PodMemoryUsed) +} diff --git a/internal/hypervisor/worker/vram/vram_trap.go b/internal/hypervisor/worker/vram/vram_trap.go new file mode 100644 index 00000000..15728f5d --- /dev/null +++ b/internal/hypervisor/worker/vram/vram_trap.go @@ -0,0 +1,3 @@ +package worker + +// diff --git a/internal/indexallocator/indexallocator.go b/internal/indexallocator/indexallocator.go index d839589e..055c23c5 100644 --- a/internal/indexallocator/indexallocator.go +++ b/internal/indexallocator/indexallocator.go @@ -2,38 +2,52 @@ package indexallocator import ( "context" + "encoding/json" "fmt" + "math" + "sync" "sync/atomic" + "time" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/utils" v1 "k8s.io/api/core/v1" "k8s.io/client-go/util/retry" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/controller-runtime/pkg/client" - 
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/manager" ) -const ( - IndexRangeStart = 1 - IndexRangeEnd = 512 -) - -// IndexAllocator manages allocation of 1-512 temporary indices for Pod-to-DevicePlugin communication -// Uses a simple atomic counter that increments from 1 to 512, then wraps around to 1 -// No bitmap tracking needed - index reuse is acceptable after 512 cycles +// IndexAllocator manages allocation of 1-128 temporary indices for Pod-to-DevicePlugin communication +// Uses a simple atomic counter that increments from 1 to 128, then wraps around to 1 +// No bitmap tracking needed - index reuse is acceptable after 128 cycles +// The availability check will be at PostBind stage, detected by pod index annotation on Node level type IndexAllocator struct { IsLeader bool + Client client.Client // Atomic counter for index allocation (1-512, wraps around) - currentIndex int64 + currentIndex int64 + ctx context.Context + storeMutex sync.RWMutex + initializedCh chan struct{} - Client client.Client + // in use index from 0x01 -> 0xf8, indicates the pod using this index + // When pod completed CDI and started or pending image pulling, should be removed from the queue + nodeIndexQueue map[string]map[int]types.NamespacedName - ctx context.Context + podIndexMap map[types.NamespacedName]indexIdentifier + + asyncCheckingMap map[types.NamespacedName]struct{} +} + +type indexIdentifier struct { + nodeName string + index int } func NewIndexAllocator(ctx context.Context, client client.Client) (*IndexAllocator, error) { @@ -42,10 +56,15 @@ func NewIndexAllocator(ctx context.Context, client client.Client) (*IndexAllocat } allocator := &IndexAllocator{ - Client: client, - IsLeader: false, - currentIndex: 0, // Will start from 1 on first assignment - ctx: ctx, + Client: client, + IsLeader: false, + currentIndex: 0, // Will start from 1 on first assignment + ctx: ctx, + initializedCh: make(chan struct{}), + + nodeIndexQueue: make(map[string]map[int]types.NamespacedName, 128), + + podIndexMap: make(map[types.NamespacedName]indexIdentifier, 128), } return allocator, nil @@ -56,34 +75,15 @@ func (s *IndexAllocator) SetupWithManager(ctx context.Context, mgr manager.Manag _ = mgr.Add(manager.RunnableFunc(func(ctx context.Context) error { <-mgr.Elected() s.IsLeader = true - leaderInfo := &v1.ConfigMap{ - ObjectMeta: metav1.ObjectMeta{ - Name: constants.LeaderInfoConfigMapName, - Namespace: utils.CurrentNamespace(), - }, - } - err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { - _, err := controllerutil.CreateOrUpdate(ctx, s.Client, leaderInfo, func() error { - leaderInfo.Data = map[string]string{ - constants.LeaderInfoConfigMapLeaderIPKey: utils.CurrentIP(), - } - return nil - }) - return err - }) - if err != nil { - log.FromContext(ctx).Error(err, "Failed to update leader IP info in ConfigMap") - } - readyCh <- struct{}{} return nil })) return readyCh } -// AssignIndex assigns a temporary index (1-512) for Pod-to-DevicePlugin communication +// AssignIndex assigns a temporary index (1-128) for Pod-to-DevicePlugin communication // Uses atomic increment to ensure thread-safe assignment -// Index wraps around from 512 to 1 (simple modulo operation) +// Index wraps around from 128 to 1 (simple modulo operation) func (s *IndexAllocator) AssignIndex(podName string) (int, error) { if !s.IsLeader { log.FromContext(s.ctx).Error(nil, "only leader can assign index", "podName", podName) @@ -91,7 +91,212 @@ func 
(s *IndexAllocator) AssignIndex(podName string) (int, error) {
 	}
 	// Atomic increment and wrap around
 	next := atomic.AddInt64(&s.currentIndex, 1)
-	index := int((next-1)%IndexRangeEnd) + IndexRangeStart
+	index := int((next-1)%(constants.IndexModLength*constants.IndexKeyLength)) + 1
 	log.FromContext(s.ctx).Info("assigned index successfully", "podName", podName, "index", index)
 	return index, nil
 }
+
+// ReconcileLockState maintains the in-memory state of the node-level index assignment and release queue
+func (s *IndexAllocator) ReconcileLockState(pod *v1.Pod) {
+	if pod.Labels[constants.LabelComponent] != constants.ComponentWorker {
+		return
+	}
+	// Check whether this is a TF indexed Pod by its container resource limits.
+	// If it is indexed but the Pod index annotation is not set yet, check the phase:
+	// a pending Pod should get its index assigned on a later check
+	if pod.Spec.NodeName == "" {
+		return
+	}
+
+	index, err := utils.ParsePodIndexResourceClaim(pod)
+	if err != nil {
+		log.FromContext(s.ctx).Error(err, "not a TF indexed Pod, skip reconciling lock state", "pod", pod.Name)
+		return
+	}
+	_, indexAllocated := pod.Annotations[constants.PodIndexAnnotation]
+
+	// Only pending pods can occupy the node level index
+	if utils.IsPodPending(pod) {
+		s.storeMutex.Lock()
+		indexQueue := s.nodeIndexQueue[pod.Spec.NodeName]
+		if indexQueue == nil {
+			indexQueue = make(map[int]types.NamespacedName)
+			s.nodeIndexQueue[pod.Spec.NodeName] = indexQueue
+		}
+
+		// If the scheduler just restarted and the entry is missing in memory,
+		// repopulate the index queue and the pod index map
+		if indexAllocated {
+			// occupy the index if missing (when scheduler restarted)
+			if _, exists := indexQueue[index]; !exists {
+				podMeta := types.NamespacedName{
+					Namespace: pod.Namespace,
+					Name: pod.Name,
+				}
+				indexQueue[index] = podMeta
+				s.podIndexMap[podMeta] = indexIdentifier{
+					nodeName: pod.Spec.NodeName,
+					index: index,
+				}
+			}
+			s.storeMutex.Unlock()
+			return
+		}
+
+		if podMeta, exists := indexQueue[index]; exists {
+			// Already occupied: if it is a different Pod, more than one pending pod claims
+			// the same index on this node and the state cannot be reconciled
+			if podMeta.Namespace != pod.Namespace || podMeta.Name != pod.Name {
+				log.FromContext(s.ctx).Error(fmt.Errorf("pod index conflict"), "cannot reconcile index lock, more than one pending pod occupies the same index", "pod", pod.Name, "index", index)
+			}
+			// Unlock on both the conflict path and the same-Pod path so the mutex is never left held
+			s.storeMutex.Unlock()
+			return
+		} else {
+			// a new Pod occupies the index, add it to the index queue
+			indexQueue[index] = types.NamespacedName{
+				Namespace: pod.Namespace,
+				Name: pod.Name,
+			}
+			s.podIndexMap[types.NamespacedName{
+				Namespace: pod.Namespace,
+				Name: pod.Name,
+			}] = indexIdentifier{
+				nodeName: pod.Spec.NodeName,
+				index: index,
+			}
+			s.storeMutex.Unlock()
+			// Brand new pending pod, start the async checking loop that assigns the index annotation
+			s.AsyncCheckNodeIndexAvailableAndAssign(pod, index)
+		}
+	} else if utils.IsPodRunning(pod) {
+		s.RemoveNodeIndexQueueForPod(types.NamespacedName{
+			Namespace: pod.Namespace,
+			Name: pod.Name,
+		})
+	}
+}
+
+func (s *IndexAllocator) RemoveNodeIndexQueueForPod(namespacedName types.NamespacedName) {
+	s.storeMutex.Lock()
+	defer s.storeMutex.Unlock()
+
+	indexIdentifier, exists := s.podIndexMap[namespacedName]
+	if !exists {
+		return
+	}
+	if indexQueue, exists := s.nodeIndexQueue[indexIdentifier.nodeName]; exists {
+		if val, exists := indexQueue[indexIdentifier.index]; exists {
+			if val.Namespace == namespacedName.Namespace && val.Name == namespacedName.Name {
+				delete(indexQueue, indexIdentifier.index)
+				log.FromContext(s.ctx).Info("Removed pod from node index queue after pod running/stopped/deleted", "pod", namespacedName, "index", indexIdentifier.index)
+			}
+			delete(s.podIndexMap, namespacedName)
+		}
+	}
+}
+
+func (s *IndexAllocator) CheckNodeIndexAndTryOccupy(pod *v1.Pod, index int) bool {
+	<-s.initializedCh
+	nodeName := pod.Spec.NodeName
+	if nodeName == "" {
+		// should not happen, unscheduled pod
+		return false
+	}
+	s.storeMutex.RLock()
+	indexQueue := s.nodeIndexQueue[nodeName]
+	if len(indexQueue) == 0 {
+		s.storeMutex.RUnlock()
+		return false
+	}
+	_, exists := indexQueue[index]
+	s.storeMutex.RUnlock()
+	// Occupy index for node
+	if !exists {
+		s.storeMutex.Lock()
+		indexQueue[index] = types.NamespacedName{
+			Namespace: pod.Namespace,
+			Name: pod.Name,
+		}
+		s.storeMutex.Unlock()
+		return true
+	}
+	return false
+}
+
+func (s *IndexAllocator) SetReady() {
+	close(s.initializedCh)
+}
+
+func (s *IndexAllocator) AsyncCheckNodeIndexAvailableAndAssign(pod *v1.Pod, index int) {
+	s.storeMutex.Lock()
+	defer s.storeMutex.Unlock()
+	podMeta := types.NamespacedName{
+		Namespace: pod.Namespace,
+		Name: pod.Name,
+	}
+	if _, exists := s.asyncCheckingMap[podMeta]; exists {
+		// already started checking loop, skip
+		return
+	}
+	s.asyncCheckingMap[podMeta] = struct{}{}
+
+	go func() {
+		defer func() {
+			s.storeMutex.Lock()
+			delete(s.asyncCheckingMap, types.NamespacedName{
+				Namespace: pod.Namespace,
+				Name: pod.Name,
+			})
+			s.storeMutex.Unlock()
+		}()
+
+		// Infinite backoff retry until the index becomes available and the reconcile loop has started
+		_ = retry.OnError(wait.Backoff{
+			Duration: 3 * time.Second,
+			Factor: 1.4,
+			Jitter: 0.1,
+			Steps: math.MaxInt32,
+			Cap: 60 * time.Minute,
+		}, func(err error) bool {
+			return true
+		}, func() error {
+			pod := &v1.Pod{}
+			// fetch the latest Pod by its namespaced name (the freshly created object has an empty key)
+			if err := s.Client.Get(s.ctx, podMeta, pod); err != nil {
+				if errors.IsNotFound(err) {
+					// pod is deleted, stop retrying
+					return nil
+				}
+				return err
+			}
+			if utils.IsPodStopped(pod) {
+				return nil
+			}
+			// Skip if the index is already assigned or the Pod has no annotations
+			if pod.Annotations == nil || pod.Annotations[constants.PodIndexAnnotation] != "" {
+				if utils.IsPodRunning(pod) {
+					log.FromContext(s.ctx).Info("[WARNING] pod is running without index allocation, hypervisor may not be working",
+						"pod", pod.Name, "node", pod.Spec.NodeName)
+					return nil
+				}
+				// else do nothing, may be caused by duplicated reconciling
+			}
+
+			if !s.CheckNodeIndexAndTryOccupy(pod, index) {
+				return fmt.Errorf("index is not available")
+			}
+			// Index available, patch the annotation to transition the Pod from Pending to DeviceAllocating in the hypervisor.
+			// A JSON patch must be an array of operations, and annotation values must be strings.
+			patchOps := []map[string]any{
+				{
+					"op": "add",
+					"path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.PodIndexAnnotation),
+					"value": fmt.Sprintf("%d", index),
+				},
+			}
+			patchBytes, err := json.Marshal(patchOps)
+			if err != nil {
+				return err
+			}
+			err = s.Client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patchBytes))
+			if err != nil {
+				log.FromContext(s.ctx).Error(err, "failed to patch pod index annotation", "pod", pod.Name, "index", index)
+				return err
+			}
+			return nil
+		})
+	}()
+}
diff --git a/internal/metrics/connect.go b/internal/metrics/connect.go
index 1e931422..3b64ec85 100644
--- a/internal/metrics/connect.go
+++ b/internal/metrics/connect.go
@@ -153,7 +153,7 @@ func (t *TimeSeriesDB) SetTableTTL(ttl string) error {
 func (t *TimeSeriesDB) FindRecentNodeMetrics() ([]NodeResourceMetrics, error) {
 	var monitors []NodeResourceMetrics
-	err := t.DB.Find(&monitors, map[string]interface{}{
+	err := t.DB.Find(&monitors, map[string]any{
 		"ts": gorm.Expr("now() - interval 1 hour"),
 	}).Error
 	return monitors, err
diff --git 
a/internal/metrics/encoder.go b/internal/metrics/encoder.go index a78fa50c..892e36bc 100644 --- a/internal/metrics/encoder.go +++ b/internal/metrics/encoder.go @@ -37,6 +37,9 @@ type MultiProtocolEncoder struct { } func NewEncoder(encoderType string) Encoder { + if encoderType == "" { + encoderType = config.MetricsFormatInflux + } encoderEnum, exists := stringToEncoderType[encoderType] if !exists { // Default to influx for unknown types diff --git a/internal/metrics/encoders/otel.go b/internal/metrics/encoders/otel.go index e372ef3c..cd596a20 100644 --- a/internal/metrics/encoders/otel.go +++ b/internal/metrics/encoders/otel.go @@ -51,11 +51,11 @@ type OtelStrategy struct { // otelMetric represents a single OTLP metric point with all its associated data type otelMetric struct { - name string // Metric name - attributes []attribute.KeyValue // OpenTelemetry attributes (tags) - value interface{} // Primary metric value - timestamp time.Time // Metric timestamp - fields map[string]interface{} // All field values + name string // Metric name + attributes []attribute.KeyValue // OpenTelemetry attributes (tags) + value any // Primary metric value + timestamp time.Time // Metric timestamp + fields map[string]any // All field values } // NewOtelStrategy creates a new optimized OTEL strategy with pre-allocated slices @@ -72,7 +72,7 @@ func (s *OtelStrategy) StartLine(measurement string) { s.currentMetric = &otelMetric{ name: measurement, attributes: make([]attribute.KeyValue, 0, defaultAttributeCapacity), - fields: make(map[string]interface{}, defaultFieldCapacity), + fields: make(map[string]any, defaultFieldCapacity), } } @@ -205,7 +205,7 @@ func (s *OtelStrategy) writeAttribute(attr attribute.KeyValue) { } // writeTimestampsAndValue writes timestamp fields and the metric value -func (s *OtelStrategy) writeTimestampsAndValue(timestamp time.Time, value interface{}) { +func (s *OtelStrategy) writeTimestampsAndValue(timestamp time.Time, value any) { timestampNanos := strconv.FormatInt(timestamp.UnixNano(), 10) s.buffer.WriteString(timestampStart) s.buffer.WriteString(timestampNanos) @@ -218,7 +218,7 @@ func (s *OtelStrategy) writeTimestampsAndValue(timestamp time.Time, value interf // writeFieldMetricJSON writes a field metric in OTLP JSON format. // Field metrics have names suffixed with the field key (e.g., "cpu_usage_percent"). -func (s *OtelStrategy) writeFieldMetricJSON(metric *otelMetric, fieldKey string, fieldValue interface{}) { +func (s *OtelStrategy) writeFieldMetricJSON(metric *otelMetric, fieldKey string, fieldValue any) { // Write metric name with field suffix s.buffer.WriteString(metricStart) s.buffer.WriteString(metric.name) @@ -237,7 +237,7 @@ func (s *OtelStrategy) writeFieldMetricJSON(metric *otelMetric, fieldKey string, // writeValueJSON writes a value in the appropriate OTLP format. // Integer types use "asInt" field, floating point types use "asDouble" field. 
-func (s *OtelStrategy) writeValueJSON(value interface{}) { +func (s *OtelStrategy) writeValueJSON(value any) { switch v := value.(type) { // Integer types - all use "asInt" with string values in OTLP case int: diff --git a/internal/portallocator/portallocator.go b/internal/portallocator/portallocator.go index 1a050eee..4899af4e 100644 --- a/internal/portallocator/portallocator.go +++ b/internal/portallocator/portallocator.go @@ -15,10 +15,8 @@ import ( "k8s.io/client-go/util/retry" "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/manager" ) @@ -115,25 +113,6 @@ func (s *PortAllocator) SetupWithManager(ctx context.Context, mgr manager.Manage _ = mgr.Add(manager.RunnableFunc(func(ctx context.Context) error { <-mgr.Elected() s.IsLeader = true - leaderInfo := &v1.ConfigMap{ - ObjectMeta: metav1.ObjectMeta{ - Name: constants.LeaderInfoConfigMapName, - Namespace: utils.CurrentNamespace(), - }, - } - err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { - _, err := controllerutil.CreateOrUpdate(ctx, s.Client, leaderInfo, func() error { - leaderInfo.Data = map[string]string{ - constants.LeaderInfoConfigMapLeaderIPKey: utils.CurrentIP(), - } - return nil - }) - return err - }) - if err != nil { - log.FromContext(ctx).Error(err, "Failed to update leader IP info in ConfigMap") - } - s.storeMutexNode.Lock() s.storeMutexCluster.Lock() defer s.storeMutexNode.Unlock() diff --git a/internal/scheduler/expander/handler.go b/internal/scheduler/expander/handler.go index 26da438a..3d3e4a6a 100644 --- a/internal/scheduler/expander/handler.go +++ b/internal/scheduler/expander/handler.go @@ -125,7 +125,7 @@ func (e *NodeExpander) GetNodeScalerInfo() any { defer e.mu.RUnlock() inFlightNodeClaimSnapshot := make(map[string]any) - e.inFlightNodeClaims.Range(func(key, value interface{}) bool { + e.inFlightNodeClaims.Range(func(key, value any) bool { inFlightNodeClaimSnapshot[key.(string)] = value return true }) @@ -155,15 +155,15 @@ func (e *NodeExpander) ProcessExpansion(ctx context.Context, pod *corev1.Pod) er gpuNodesPassedOtherFilters, err := e.simulateSchedulingWithoutGPU(ctx, pod) if err != nil { e.eventRecorder.Eventf(pod, corev1.EventTypeNormal, "NodeExpansionCheck", - "can not schedule on any nodes even without GPU constraints, manual check required. error: %w", err) - e.logger.Info("Pod schedulable but no GPU nodes available, manual check required", + "can not schedule on any nodes even without GPU constraints, karpenter should take over expansion. 
error: %w", err) + e.logger.Info("Pod schedulable but no GPU nodes available, karpenter should take over expansion", "namespace", pod.Namespace, "pod", pod.Name, "error", err) return nil } if len(gpuNodesPassedOtherFilters) == 0 { e.eventRecorder.Eventf(pod, corev1.EventTypeNormal, "NodeExpansionCheck", - "can not schedule on any nodes, manual check required, 0 fit nodes") - e.logger.Info("Pod schedulable but no GPU nodes available, manual check required", + "can not schedule on any nodes even without GPU constraints, karpenter should take over expansion, 0 fit nodes") + e.logger.Info("Pod schedulable but no GPU nodes available, karpenter should take over expansion", "namespace", pod.Namespace, "pod", pod.Name) return nil } @@ -417,7 +417,7 @@ func (e *NodeExpander) checkGPUFitWithInflightNodes(pod *corev1.Pod, potentialGp // Get allocation request e.mu.RLock() defer e.mu.RUnlock() - allocRequest, _, err := e.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(e.ctx, pod) if err != nil { return nil, false, true, false } @@ -468,7 +468,7 @@ func (e *NodeExpander) checkGPUFitWithInflightNodes(pod *corev1.Pod, potentialGp } func (e *NodeExpander) checkGPUFitForNewNode(pod *corev1.Pod, gpus []*tfv1.GPU) bool { - allocRequest, _, err := e.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(e.ctx, pod) if err != nil { return false } diff --git a/internal/scheduler/expander/handler_test.go b/internal/scheduler/expander/handler_test.go index 42edd70a..6f5bcc53 100644 --- a/internal/scheduler/expander/handler_test.go +++ b/internal/scheduler/expander/handler_test.go @@ -56,7 +56,7 @@ func (suite *NodeExpanderTestSuite) SetupSuite() { Expect(suite.k8sClient.Create(ctx, ns)).To(Succeed()) // Setup proper allocator for testing - suite.allocator = gpuallocator.NewGpuAllocator(ctx, suite.k8sClient, time.Second) + suite.allocator = gpuallocator.NewGpuAllocator(ctx, nil, suite.k8sClient, time.Second) err := suite.allocator.InitGPUAndQuotaStore() if err != nil { // For test environments, we can ignore some initialization errors diff --git a/internal/scheduler/gpuresources/gpuresources.go b/internal/scheduler/gpuresources/gpuresources.go index c3759fad..09309198 100644 --- a/internal/scheduler/gpuresources/gpuresources.go +++ b/internal/scheduler/gpuresources/gpuresources.go @@ -8,11 +8,13 @@ import ( "strconv" "strings" "sync" + "time" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" + "github.com/NexusGPU/tensor-fusion/internal/indexallocator" "github.com/NexusGPU/tensor-fusion/internal/metrics" "github.com/NexusGPU/tensor-fusion/internal/quota" "github.com/NexusGPU/tensor-fusion/internal/utils" @@ -23,6 +25,8 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/retry" "k8s.io/klog/v2" fwk "k8s.io/kube-scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework" @@ -42,12 +46,13 @@ var _ framework.PostBindPlugin = &GPUFit{} var _ framework.EnqueueExtensions = &GPUFit{} type GPUFit struct { - logger *klog.Logger - fh framework.Handle - client client.Client - allocator *gpuallocator.GpuAllocator - ctx context.Context - cfg *config.GPUFitConfig + logger *klog.Logger + fh framework.Handle + client client.Client + allocator 
*gpuallocator.GpuAllocator + indexAllocator *indexallocator.IndexAllocator + ctx context.Context + cfg *config.GPUFitConfig } type GPUSchedulingStateData struct { @@ -80,7 +85,7 @@ func (p *GPUSchedulingStateData) Clone() fwk.StateData { type PluginFactoryFunc func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) -func NewWithDeps(allocator *gpuallocator.GpuAllocator, client client.Client) PluginFactoryFunc { +func NewWithDeps(allocator *gpuallocator.GpuAllocator, indexAllocator *indexallocator.IndexAllocator, client client.Client) PluginFactoryFunc { return func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { target := &config.GPUFitConfig{} if unknown, ok := obj.(*runtime.Unknown); ok { @@ -91,12 +96,13 @@ func NewWithDeps(allocator *gpuallocator.GpuAllocator, client client.Client) Plu lh := klog.FromContext(ctx).WithValues("plugin", Name) lh.Info("Creating new GPUFit plugin") c := &GPUFit{ - logger: &lh, - fh: handle, - cfg: target, - allocator: allocator, - ctx: ctx, - client: client, + logger: &lh, + fh: handle, + cfg: target, + allocator: allocator, + indexAllocator: indexAllocator, + ctx: ctx, + client: client, } lh.Info("Created new GPUFit plugin", "plugin", c) @@ -128,7 +134,7 @@ func (s *GPUFit) PreFilter(ctx context.Context, state fwk.CycleState, pod *v1.Po // Handle tensor-fusion mode scheduling s.logger.Info("checking GPU node resources for pod", "pod", pod.Name) - allocRequest, reason, err := s.allocator.ComposeAllocationRequest(pod) + allocRequest, reason, err := utils.ComposeAllocationRequest(s.ctx, pod) if err != nil { return nil, fwk.NewStatus(fwk.Error, reason) } @@ -162,6 +168,29 @@ func (s *GPUFit) PreFilter(ctx context.Context, state fwk.CycleState, pod *v1.Po } } + // For partitioned mode, match partition template if not already specified + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID == "" { + matchedGPU, partitionMatch, err := s.allocator.GetMatchedPartition(allocRequest, filteredGPUs) + if err != nil { + metrics.SetSchedulerMetrics(allocRequest.PoolName, false) + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "PartitionTemplateMatchFailed", + "match partition template", "Failed to match partition template: "+err.Error()) + s.logger.Error(err, "failed to match partition template", "pod", pod.Name) + return nil, fwk.NewStatus(fwk.Unschedulable, fmt.Sprintf("no suitable partition template: %v", err)) + } + + // Set partition template ID in alloc request + allocRequest.PartitionTemplateID = partitionMatch.TemplateID + s.logger.Info("Matched partition template in PreFilter", + "pod", pod.Name, + "gpu", matchedGPU.Name, + "template", allocRequest.PartitionTemplateID, + "score", partitionMatch.Score) + + // Update state with the updated alloc request + state.Write(CycleStateAllocateRequest, allocRequest) + } + validNodesValidGPUs := lo.GroupBy(filteredGPUs, func(gpu *tfv1.GPU) string { return gpu.Status.NodeSelector[constants.KubernetesHostNameLabel] }) @@ -424,22 +453,15 @@ func (s *GPUFit) Reserve(ctx context.Context, state fwk.CycleState, pod *v1.Pod, } // reserve GPU resources inside memory and asynchronously update GPU custom resource + allocReq := allocRequest.(*tfv1.AllocRequest) _, err = s.allocator.Bind( schedulingResult.FinalGPUs, - allocRequest.(*tfv1.AllocRequest), + allocReq, ) if err != nil { return fwk.NewStatus(fwk.Error, err.Error()) } - // Index is already assigned in webhook stage, scheduler cannot 
modify Pod - // Just verify that index annotation exists for logging - if pod.Annotations != nil { - if indexStr, exists := pod.Annotations[constants.PodIndexAnnotation]; exists && indexStr != "" { - s.logger.V(5).Info("Pod index already assigned in webhook", "pod", pod.Name, "index", indexStr) - } - } - return fwk.NewStatus(fwk.Success, "") } @@ -477,19 +499,86 @@ func (s *GPUFit) PostBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod gpuIDs := strings.Join(gpuSchedulingResult.(*GPUSchedulingStateData).FinalGPUs, ",") s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) - // Patch GPU device IDs annotation - patch := []byte(`[{ - "op": "add", - "path": "/metadata/annotations/` + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation) + `", - "value": "` + gpuIDs + `"}]`) - err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)) + index, err := utils.ParsePodIndexResourceClaim(pod) if err != nil { - s.logger.Error(err, "failed to patch gpu device ids", "pod", pod.Name) - s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed", - "Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs) + s.logger.Error(err, "failed to parse pod index annotation", "pod", pod.Name) + return + } + + indexAvailable := s.indexAllocator.CheckNodeIndexAndTryOccupy(pod, index) + + // Build patch operations + patchOps := []map[string]any{ + { + "op": "add", + "path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation), + "value": gpuIDs, + }, + } + if indexAvailable { + patchOps = append(patchOps, map[string]any{ + "op": "add", + "path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.PodIndexAnnotation), + "value": index, + }) } else { - s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", - "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + s.logger.Info("Index is not available on node, spawn a goroutine to patch it asynchronously", "pod", pod.Name, "node", nodeName, "index", index) + // spawn a goroutine to patch + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "PodIndexAllocationPending", "Pod index allocation pending", + fmt.Sprintf("Index %d will be patched into pod after released by other pod on the same node: %s", index, nodeName)) + s.indexAllocator.AsyncCheckNodeIndexAvailableAndAssign(pod, index) + } + + // Add partition template ID annotation if in partitioned mode + allocRequestRaw, err := state.Read(CycleStateAllocateRequest) + if err == nil { + allocRequest := allocRequestRaw.(*tfv1.AllocRequest) + if allocRequest.Isolation == tfv1.IsolationModePartitioned && allocRequest.PartitionTemplateID != "" { + patchOps = append(patchOps, map[string]any{ + "op": "add", + "path": "/metadata/annotations/" + utils.EscapeJSONPointer(constants.PartitionTemplateIDAnnotation), + "value": allocRequest.PartitionTemplateID, + }) + s.logger.Info("Adding partition template ID annotation", "pod", pod.Name, "templateID", allocRequest.PartitionTemplateID) + } + } + + // Convert patch operations to JSON + patchBytes, err := json.Marshal(patchOps) + if err != nil { + s.logger.Error(err, "failed to marshal patch operations", "pod", pod.Name) + return + } + + // Patch pod annotations with retry + err = retry.OnError(wait.Backoff{ + Duration: 1 * time.Second, + Factor: 2, + Jitter: 0.1, + Steps: 3, + }, func(err error) bool { + return true + }, func() error { + err = 
s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patchBytes)) + if err != nil { + s.logger.Error(err, "failed to patch pod annotations", "pod", pod.Name) + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed", + "Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs) + } else { + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", + "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + } + return nil + }) + if err != nil { + if indexAvailable { + s.indexAllocator.RemoveNodeIndexQueueForPod(types.NamespacedName{ + Namespace: pod.Namespace, + Name: pod.Name, + }) + } + s.logger.Error(err, "failed to patch pod annotations in post binding stage", "pod", pod.Name) + return } } @@ -509,8 +598,8 @@ func (s *GPUFit) EventsToRegister(_ context.Context) ([]fwk.ClusterEventWithHint }, nil } -// convertToGPU converts an interface{} to *tfv1.GPU, handling both typed and unstructured objects -func convertToGPU(obj interface{}) (*tfv1.GPU, error) { +// convertToGPU converts an any to *tfv1.GPU, handling both typed and unstructured objects +func convertToGPU(obj any) (*tfv1.GPU, error) { if obj == nil { return nil, nil } @@ -531,7 +620,7 @@ func convertToGPU(obj interface{}) (*tfv1.GPU, error) { return nil, fmt.Errorf("cannot convert %T to *tfv1.GPU", obj) } -func (s *GPUFit) queueingHint(logger klog.Logger, pod *v1.Pod, oldObj, newObj interface{}) (fwk.QueueingHint, error) { +func (s *GPUFit) queueingHint(logger klog.Logger, pod *v1.Pod, oldObj, newObj any) (fwk.QueueingHint, error) { // Only process TensorFusion worker pods if !utils.IsTensorFusionWorker(pod) { return fwk.QueueSkip, nil @@ -573,7 +662,7 @@ func (s *GPUFit) queueingHint(logger klog.Logger, pod *v1.Pod, oldObj, newObj in } // Compose allocation request for the pod passed in by scheduler framework - allocRequest, _, err := s.allocator.ComposeAllocationRequest(pod) + allocRequest, _, err := utils.ComposeAllocationRequest(s.ctx, pod) if err != nil { logger.V(5).Info("Failed to compose allocation request for pod, skip", "pod", klog.KObj(pod), "error", err) diff --git a/internal/scheduler/gpuresources/gpuresources_test.go b/internal/scheduler/gpuresources/gpuresources_test.go index 5707a640..33b50d7c 100644 --- a/internal/scheduler/gpuresources/gpuresources_test.go +++ b/internal/scheduler/gpuresources/gpuresources_test.go @@ -34,6 +34,7 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" + "github.com/NexusGPU/tensor-fusion/internal/indexallocator" "github.com/NexusGPU/tensor-fusion/internal/utils" internalcache "k8s.io/kubernetes/pkg/scheduler/backend/cache" internalqueue "k8s.io/kubernetes/pkg/scheduler/backend/queue" @@ -41,12 +42,13 @@ import ( type GPUResourcesSuite struct { suite.Suite - client client.Client - fwk framework.Framework - allocator *gpuallocator.GpuAllocator - plugin *GPUFit - ctx context.Context - cancel context.CancelFunc + client client.Client + fwk framework.Framework + allocator *gpuallocator.GpuAllocator + indexAllocator *indexallocator.IndexAllocator + plugin *GPUFit + ctx context.Context + cancel context.CancelFunc } func (s *GPUResourcesSuite) SetupTest() { @@ -169,7 +171,7 @@ func (s *GPUResourcesSuite) SetupTest() { Status: tfv1.GPUStatus{ Phase: tfv1.TensorFusionGPUPhaseRunning, NodeSelector: map[string]string{constants.KubernetesHostNameLabel: "node-c"}, - UsedBy: 
tfv1.UsedByNvidiaDevicePlugin, + UsedBy: "nvidia-device-plugin", Capacity: &tfv1.Resource{ Tflops: resource.MustParse("2000"), Vram: resource.MustParse("40Gi"), @@ -257,13 +259,13 @@ func (s *GPUResourcesSuite) SetupTest() { s.NoError(err) s.fwk = fwk - s.allocator = gpuallocator.NewGpuAllocator(s.ctx, s.client, time.Second) + s.allocator = gpuallocator.NewGpuAllocator(s.ctx, nil, s.client, time.Second) err = s.allocator.InitGPUAndQuotaStore() s.NoError(err) s.allocator.ReconcileAllocationState() s.allocator.SetAllocatorReady() - pluginFactory := NewWithDeps(s.allocator, s.client) + pluginFactory := NewWithDeps(s.allocator, s.indexAllocator, s.client) pluginConfig := &runtime.Unknown{ Raw: []byte(`{ "maxWorkerPerNode": 3, @@ -597,7 +599,7 @@ func (s *GPUResourcesSuite) makePod(name string, annotations map[string]string) func (s *GPUResourcesSuite) TestNewWithDeps() { log.FromContext(s.ctx).Info("Running TestNewWithDeps") - pluginFactory := NewWithDeps(s.allocator, s.client) + pluginFactory := NewWithDeps(s.allocator, s.indexAllocator, s.client) s.NotNil(pluginFactory) // Test with valid config diff --git a/internal/utils/compose.go b/internal/utils/compose.go index 5ca775a2..16855da4 100644 --- a/internal/utils/compose.go +++ b/internal/utils/compose.go @@ -135,6 +135,10 @@ func AddOrOverrideTFClientMissingAnnotationsBeforePatch(pod *v1.Pod, tfInfo Tens // add inject container annotation for client Pod, in case user doesn't specify it pod.Annotations[constants.InjectContainerAnnotation] = strings.Join(tfInfo.ContainerNames, ",") pod.Annotations[constants.IsolationModeAnnotation] = string(tfInfo.Profile.Isolation) + // add partition template ID if in partitioned mode + if tfInfo.Profile.Isolation == tfv1.IsolationModePartitioned && tfInfo.Profile.PartitionTemplateID != "" { + pod.Annotations[constants.PartitionTemplateIDAnnotation] = tfInfo.Profile.PartitionTemplateID + } } func AppendTFWorkerLabelsAndAnnotationsAfterTemplate( @@ -196,6 +200,10 @@ func AppendTFWorkerLabelsAndAnnotationsAfterTemplate( }), ",") } annotations[constants.IsolationModeAnnotation] = string(workload.Spec.Isolation) + // add partition template ID if in partitioned mode + if workload.Spec.Isolation == tfv1.IsolationModePartitioned && workload.Spec.PartitionTemplateID != "" { + annotations[constants.PartitionTemplateIDAnnotation] = workload.Spec.PartitionTemplateID + } return labels, annotations } @@ -449,7 +457,7 @@ func configureFeatures4InjectLib(isLocalGPU bool, disabledFeatures string) []v1. 
return envList } -func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, pool *tfv1.GPUPool) { +func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, pool *tfv1.GPUPool, compatibleWithNvidiaContainerToolkit bool) { // Hypervisor needs to read /proc to map pod with processID spec.HostPID = true spec.TerminationGracePeriodSeconds = constants.GracefulPeriodSeconds @@ -534,7 +542,7 @@ func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, poo }, }) - composeHypervisorInitContainer(spec, pool) + composeHypervisorInitContainer(spec, pool, compatibleWithNvidiaContainerToolkit) composeHypervisorContainer(spec, pool, enableVector) if enableVector { @@ -542,11 +550,11 @@ func AddTFHypervisorConfAfterTemplate(ctx context.Context, spec *v1.PodSpec, poo } } -func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool) { +func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool, compatibleWithNvidiaContainerToolkit bool) { spec.InitContainers = append(spec.InitContainers, v1.Container{ Name: "init-shm", Image: pool.Spec.ComponentConfig.Hypervisor.Image, - Command: []string{"hypervisor", "mount-shm"}, + Command: []string{constants.ComponentHypervisor, constants.MountShmSubcommand}, SecurityContext: &v1.SecurityContext{ Privileged: ptr.To(true), }, @@ -559,6 +567,49 @@ func composeHypervisorInitContainer(spec *v1.PodSpec, pool *tfv1.GPUPool) { }, }, }) + + // Add initContainer to wait for NVIDIA Container Toolkit toolkit-ready validation + if compatibleWithNvidiaContainerToolkit { + initContainerImage := pool.Spec.ComponentConfig.Hypervisor.Image + if initContainerImage == "" { + // Use the same image as the main container if not specified + if len(spec.Containers) > 0 { + initContainerImage = spec.Containers[0].Image + } + } + + initContainer := v1.Container{ + Name: "toolkit-validation", + Image: initContainerImage, + Command: []string{"sh", "-c"}, + Args: []string{ + "until [ -f /run/nvidia/validations/toolkit-ready ]; do echo waiting for nvidia container stack to be setup; sleep 5; done", + }, + SecurityContext: &v1.SecurityContext{ + Privileged: ptr.To(true), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "run-nvidia-validations", + MountPath: "/run/nvidia/validations", + MountPropagation: ptr.To(v1.MountPropagationHostToContainer), + }, + }, + } + + spec.InitContainers = append(spec.InitContainers, initContainer) + + // Add volume for NVIDIA validations + spec.Volumes = append(spec.Volumes, v1.Volume{ + Name: "run-nvidia-validations", + VolumeSource: v1.VolumeSource{ + HostPath: &v1.HostPathVolumeSource{ + Path: "/run/nvidia/validations", + Type: ptr.To(v1.HostPathDirectoryOrCreate), + }, + }, + }) + } } func composeHypervisorContainer(spec *v1.PodSpec, pool *tfv1.GPUPool, enableVector bool) { diff --git a/internal/utils/config.go b/internal/utils/config.go index 23256dc2..ed8bd192 100644 --- a/internal/utils/config.go +++ b/internal/utils/config.go @@ -127,6 +127,67 @@ func GetEnvOrDefault(key, defaultValue string) string { return defaultValue } +// PodWorkerInfo contains extracted worker information from pod annotations +type PodWorkerInfo struct { + DeviceUUIDs []string + IsolationMode string + MemoryLimitBytes uint64 + ComputeLimitUnits uint32 + TemplateID string +} + +// ExtractPodWorkerInfo extracts worker information from pod annotations +// This is a common utility function used by both GpuAllocator and PodCacheManager +func ExtractPodWorkerInfo(pod *corev1.Pod) PodWorkerInfo { + info := 
PodWorkerInfo{} + + // Extract GPU device IDs + if gpuIDsStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { + ids := strings.Split(gpuIDsStr, ",") + info.DeviceUUIDs = make([]string, 0, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id != "" { + info.DeviceUUIDs = append(info.DeviceUUIDs, id) + } + } + } + + // Extract isolation mode + if isolationMode, exists := pod.Annotations[constants.IsolationModeAnnotation]; exists { + info.IsolationMode = isolationMode + } else { + info.IsolationMode = string(tfv1.IsolationModeSoft) // default + } + + // Extract memory limit (VRAM) + if vramLimit, exists := pod.Annotations[constants.VRAMLimitAnnotation]; exists { + if qty, err := resource.ParseQuantity(vramLimit); err == nil { + info.MemoryLimitBytes = uint64(qty.Value()) + } + } + + // Extract compute limit (compute percent) + if computeLimit, exists := pod.Annotations[constants.ComputeLimitAnnotation]; exists { + if qty, err := resource.ParseQuantity(computeLimit); err == nil { + // Convert to percentage units (e.g., "50" -> 50, "100" -> 100) + percent := qty.AsApproximateFloat64() + info.ComputeLimitUnits = uint32(percent) + } + } + + // Extract template ID (for partitioned mode) + // First check PartitionTemplateIDAnnotation (set by scheduler) + if templateID, exists := pod.Annotations[constants.PartitionTemplateIDAnnotation]; exists { + info.TemplateID = templateID + } else if templateID, exists := pod.Annotations[constants.WorkloadProfileAnnotation]; exists { + // Fallback to WorkloadProfileAnnotation + info.TemplateID = templateID + } + + return info +} + func GetGPUResource(pod *corev1.Pod, isRequest bool) (tfv1.Resource, error) { tflopsKey := constants.TFLOPSRequestAnnotation vramKey := constants.VRAMRequestAnnotation @@ -222,3 +283,16 @@ func GetLeaderIP(client client.Client) string { } return leaderInfo.Data[constants.LeaderInfoConfigMapLeaderIPKey] } + +// only for local development, won't set KUBECONFIG env var in none local environments +func NormalizeKubeConfigEnv() { + cfgPath := os.Getenv("KUBECONFIG") + if cfgPath != "" && strings.HasPrefix(cfgPath, "~") { + home, err := os.UserHomeDir() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + _ = os.Setenv("KUBECONFIG", strings.Replace(cfgPath, "~", home, 1)) + } +} diff --git a/internal/utils/reconcile.go b/internal/utils/reconcile.go index ce2138a6..c9c3d319 100644 --- a/internal/utils/reconcile.go +++ b/internal/utils/reconcile.go @@ -166,6 +166,14 @@ func IsPodStopped(pod *corev1.Pod) bool { return pod.Status.Phase == corev1.PodFailed || pod.Status.Phase == corev1.PodSucceeded } +func IsPodRunning(pod *corev1.Pod) bool { + return pod.Status.Phase == corev1.PodRunning +} + +func IsPodPending(pod *corev1.Pod) bool { + return pod.Status.Phase == corev1.PodPending && pod.DeletionTimestamp.IsZero() +} + func ExtractPoolNameFromNodeLabel(node *tfv1.GPUNode) string { var poolName string for labelKey := range node.Labels { @@ -245,9 +253,12 @@ func IsDesignatedNodePod(pod *corev1.Pod) bool { func GetInitialGPUNodeSelector() []string { selector := os.Getenv("INITIAL_GPU_NODE_LABEL_SELECTOR") if selector == "" { - selector = constants.InitialGPUNodeSelector + return nil } selectors := strings.Split(selector, "=") + if len(selectors) != 2 { + return nil + } return selectors } @@ -265,3 +276,21 @@ func containsGPUResources(res corev1.ResourceList) bool { } return false } + +// AppendEnvVarsIfNotExists appends environment variables to the slice only if they don't already exist (by name). 
+// It returns the updated slice with new env vars appended. +func AppendEnvVarsIfNotExists(envVars []corev1.EnvVar, newEnvVars ...corev1.EnvVar) []corev1.EnvVar { + existingNames := make(map[string]bool) + for _, env := range envVars { + existingNames[env.Name] = true + } + + for _, newEnv := range newEnvVars { + if !existingNames[newEnv.Name] { + envVars = append(envVars, newEnv) + existingNames[newEnv.Name] = true + } + } + + return envVars +} diff --git a/internal/utils/resource.go b/internal/utils/resource.go index b78f579e..c9b2ffc3 100644 --- a/internal/utils/resource.go +++ b/internal/utils/resource.go @@ -1,6 +1,7 @@ package utils import ( + context "context" "fmt" "math" "slices" @@ -10,10 +11,14 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/samber/lo" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/log" ) +const MaxGPUCounterPerAllocation = 128 + func GPUResourcesFromAnnotations(annotations map[string]string) (*tfv1.Resources, error) { result := tfv1.Resources{} resInfo := []struct { @@ -73,3 +78,97 @@ func ParseIndicesAnnotation(gpuIndicesStr string) ([]int32, bool) { }) return gpuIndices, false } + +func ComposeAllocationRequest(ctx context.Context, pod *corev1.Pod) (*tfv1.AllocRequest, string, error) { + // allow Pods with no requests/limits to use TensorFusion, Pod webhook will ensure at least one request/limit is set + gpuRequestResource, err := GetGPUResource(pod, true) + if err != nil { + log.FromContext(ctx).Error(err, "Invalid gpu request annotation", "pod", pod.Name, "namespace", pod.Namespace) + } + gpuLimitResource, err := GetGPUResource(pod, false) + if err != nil { + log.FromContext(ctx).Error(err, "Invalid gpu limit annotation", "pod", pod.Name, "namespace", pod.Namespace) + } + + count := 1 + if gpuCountStr, exists := pod.Annotations[constants.GpuCountAnnotation]; exists { + count, err = strconv.Atoi(gpuCountStr) + if err != nil { + return &tfv1.AllocRequest{}, "invalid gpu count annotation", err + } + } + if count > MaxGPUCounterPerAllocation { + return &tfv1.AllocRequest{}, "gpu count annotation is too large", nil + } + + qosLevel := tfv1.QoSLevel(pod.Annotations[constants.QoSLevelAnnotation]) + if qosLevel == "" { + qosLevel = tfv1.QoSMedium + } + + gpuVendor := pod.Annotations[constants.GpuVendorAnnotation] + + gpuIndices, hasError := ParseIndicesAnnotation(pod.Annotations[constants.GpuIndicesAnnotation]) + if hasError { + return &tfv1.AllocRequest{}, "invalid gpu-indices annotation", + fmt.Errorf("can not parse gpu indices annotation") + } + + // Read isolation mode + isolationMode := tfv1.IsolationModeType(pod.Annotations[constants.IsolationModeAnnotation]) + if isolationMode == "" { + isolationMode = tfv1.IsolationModeSoft + } + + allocRequest := tfv1.AllocRequest{ + PoolName: pod.Annotations[constants.GpuPoolKey], + Request: gpuRequestResource, + Limit: gpuLimitResource, + + Count: uint(count), + GPUModel: pod.Annotations[constants.GPUModelAnnotation], + GPUIndices: gpuIndices, + GPUVendor: gpuVendor, + Isolation: isolationMode, + WorkloadNameNamespace: tfv1.NameNamespace{ + Name: pod.Labels[constants.WorkloadKey], + Namespace: pod.Namespace, + }, + PodMeta: pod.ObjectMeta, + QoS: qosLevel, + } + + // Read partition template ID annotation if in partitioned mode + if allocRequest.Isolation == tfv1.IsolationModePartitioned { + if partitionTemplateID, ok := 
pod.Annotations[constants.PartitionTemplateIDAnnotation]; ok && partitionTemplateID != "" { + allocRequest.PartitionTemplateID = partitionTemplateID + } + } + + // for already allocated workers, set the GPU device IDs for further scaling and retrieval + if gpuIdStr, exists := pod.Annotations[constants.GPUDeviceIDsAnnotation]; exists { + gpuIds := strings.SplitSeq(gpuIdStr, ",") + allocRequest.GPUNames = slices.Collect(gpuIds) + } + + return &allocRequest, "", nil +} + +func ParsePodIndexResourceClaim(pod *corev1.Pod) (int, error) { + for _, container := range pod.Spec.Containers { + for indexKey, indexValue := range container.Resources.Limits { + if strings.HasPrefix(string(indexKey), constants.PodIndexAnnotation+constants.PodIndexDelimiter) { + indexStr := strings.Split(string(indexKey), constants.PodIndexDelimiter)[1] + indexInt, err := strconv.ParseInt(indexStr, 16, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse tensor fusion index of Pod resource limits: %v", err) + } + if indexInt < 0 || indexInt >= constants.IndexKeyLength { + return 0, fmt.Errorf("tensor fusion index of Pod resource limits out of range: %d", indexInt) + } + return int(indexValue.Value()) + int(indexInt)*constants.IndexModLength, nil + } + } + } + return 0, fmt.Errorf("tensor fusion index of Pod resource limits is missing in any container") +} diff --git a/internal/version/version.go b/internal/version/version.go index 25cc9213..5080cb18 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -6,6 +6,7 @@ import ( "time" ) +// set by GO_LDFLAGS in release.yaml var ( BuildVersion string ) diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index fe18e7fe..a82b3428 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -310,7 +310,7 @@ func (m *TensorFusionPodMutator) createOrUpdateWorkload( } func (m *TensorFusionPodMutator) patchTFClient( - _ctx context.Context, + ctx context.Context, pod *corev1.Pod, pool *tfv1.GPUPool, isLocalGPU bool, @@ -329,10 +329,12 @@ func (m *TensorFusionPodMutator) patchTFClient( // Assign index once per pod (before processing containers) // Index must be assigned in webhook stage since scheduler cannot modify Pod - // This is a special index resource (1-512), not a real device resource + // This is a special index resource (1-32), not a real device resource // Index is assigned in ascending order (1, 2, 3, ...) 
via distributed lock (leader election) - // index := m.assignDeviceAllocationIndex(ctx, pod) - // log.FromContext(ctx).Info("assigned device allocation index successfully", "index", index, "pod", pod.Name) + index := m.assignDeviceAllocationIndex(ctx, pod) + + // clean annotation if exists, must be assigned by scheduler to ensure lock of certain index on one node + delete(pod.Annotations, constants.PodIndexAnnotation) for _, containerIndex := range containerIndices { container := &pod.Spec.Containers[containerIndex] @@ -362,17 +364,14 @@ func (m *TensorFusionPodMutator) patchTFClient( // Inject tensor-fusion.ai/index resource for Device Plugin communication // This is a special index resource (not a real device), used for Pod-to-DevicePlugin communication - if container.Resources.Requests == nil { - container.Resources.Requests = make(corev1.ResourceList) - } if container.Resources.Limits == nil { container.Resources.Limits = make(corev1.ResourceList) } - // Limit is set to actual index value (1-512) for Device Plugin to match Pod + // Limit is set to actual index value (1-128) for Device Plugin to match Pod // ResourceFit of dummy device already ignored in TF scheduler - // indexQuantity := resource.MustParse(strconv.Itoa(index)) - // TODO: workaround to avoid kubelet resource check error - container.Resources.Limits[constants.PodIndexAnnotation] = resource.MustParse("1") + indexQuantity := resource.MustParse(strconv.Itoa((index % constants.IndexModLength) + 1)) + indexKey := fmt.Sprintf("%s%s%x", constants.PodIndexAnnotation, constants.PodIndexDelimiter, index/constants.IndexModLength) + container.Resources.Limits[corev1.ResourceName(indexKey)] = indexQuantity if !isLocalGPU { addConnectionForRemoteFixedReplicaVirtualGPU(pod, container, clientConfig) @@ -448,14 +447,6 @@ func (m *TensorFusionPodMutator) assignDeviceAllocationIndex(ctx context.Context // No allocator available, use 0 as fallback index = 0 } - - // Set annotation for matching in Device Plugin - if pod.Annotations == nil { - pod.Annotations = make(map[string]string) - } - if index > 0 { - pod.Annotations[constants.PodIndexAnnotation] = strconv.Itoa(index) - } return index } diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index 0066b442..c4adf622 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -106,6 +106,13 @@ func ParseTensorFusionInfo( workloadProfile.Spec.Isolation = tfv1.IsolationModeSoft } + // Read partition template ID annotation if in partitioned mode + if workloadProfile.Spec.Isolation == tfv1.IsolationModePartitioned { + if partitionTemplateID, ok := pod.Annotations[constants.PartitionTemplateIDAnnotation]; ok && partitionTemplateID != "" { + workloadProfile.Spec.PartitionTemplateID = partitionTemplateID + } + } + workerPodTemplate, ok := pod.Annotations[constants.WorkerPodTemplateAnnotation] if ok && workerPodTemplate != "" { if workloadProfile.Spec.IsLocalGPU { diff --git a/provider/Makefile b/provider/Makefile new file mode 100644 index 00000000..c1ad8680 --- /dev/null +++ b/provider/Makefile @@ -0,0 +1,89 @@ +# Makefile for building accelerator libraries +# Supports both stub and vendor-specific implementations (NVIDIA, Ascend, etc.) 
+ +CC ?= gcc +CFLAGS ?= -Wall -Wextra -std=c11 -fPIC -O2 +LDFLAGS ?= -shared + +# Directories +PROVIDER_DIR := $(shell pwd) +STUB_DIR := $(PROVIDER_DIR)/stub +ASCEND_DIR := $(PROVIDER_DIR)/ascend +BUILD_DIR := $(PROVIDER_DIR)/build +TEST_DIR := $(PROVIDER_DIR)/test + +# Output libraries +STUB_LIB := $(BUILD_DIR)/libaccelerator_stub.so +ASCEND_LIB := $(BUILD_DIR)/libaccelerator_ascend.so + +# Source files +STUB_SRC := $(STUB_DIR)/accelerator.c +ASCEND_SRC := $(ASCEND_DIR)/accelerator.c + +# Object files +STUB_OBJ := $(BUILD_DIR)/accelerator_stub.o +ASCEND_OBJ := $(BUILD_DIR)/accelerator_ascend.o + +# Test executables +TEST_BIN := $(BUILD_DIR)/test_accelerator + +.PHONY: all clean stub ascend test install + +all: stub + +# Build stub implementation +stub: $(STUB_LIB) + +$(STUB_LIB): $(STUB_OBJ) | $(BUILD_DIR) + $(CC) $(LDFLAGS) -o $@ $< + +$(STUB_OBJ): $(STUB_SRC) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) -c -o $@ $< + +# Build Ascend implementation (requires Ascend CANN SDK) +ascend: $(ASCEND_LIB) + +$(ASCEND_LIB): $(ASCEND_OBJ) | $(BUILD_DIR) + $(CC) $(LDFLAGS) -o $@ $< $(ASCEND_LDFLAGS) + +$(ASCEND_OBJ): $(ASCEND_SRC) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) $(ASCEND_CFLAGS) -c -o $@ $< + +# Build test executable +test: $(TEST_BIN) + +$(TEST_BIN): $(TEST_DIR)/test_accelerator.c $(STUB_LIB) | $(BUILD_DIR) + $(CC) $(CFLAGS) -I$(PROVIDER_DIR) -o $@ $(TEST_DIR)/test_accelerator.c -L$(BUILD_DIR) -laccelerator_stub -Wl,-rpath,$(BUILD_DIR) + +# Run tests +test-run: test + LD_LIBRARY_PATH=$(BUILD_DIR):$$LD_LIBRARY_PATH $(TEST_BIN) + +# Create build directory +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +# Clean build artifacts +clean: + rm -rf $(BUILD_DIR) + +# Install libraries to system path (optional) +install: $(STUB_LIB) + install -d /usr/local/lib/tensor-fusion + install -m 755 $(STUB_LIB) /usr/local/lib/tensor-fusion/ + install -d /usr/local/include/tensor-fusion + install -m 644 $(PROVIDER_DIR)/accelerator.h /usr/local/include/tensor-fusion/ + install -m 644 $(PROVIDER_DIR)/limiter.h /usr/local/include/tensor-fusion/ + +# Help target +help: + @echo "Available targets:" + @echo " all - Build stub implementation (default)" + @echo " stub - Build stub accelerator library" + @echo " ascend - Build Ascend accelerator library (requires CANN SDK)" + @echo " test - Build test executable" + @echo " test-run - Build and run tests" + @echo " clean - Remove build artifacts" + @echo " install - Install libraries to system path" + @echo " help - Show this help message" + diff --git a/provider/README.md b/provider/README.md new file mode 100644 index 00000000..d6a7ffb5 --- /dev/null +++ b/provider/README.md @@ -0,0 +1,129 @@ +# Accelerator Provider Interface + +This directory contains the abstract ABI (Application Binary Interface) for vGPU vendor accelerator libraries. 
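+Below is a minimal caller sketch (not part of the library) illustrating the caller-allocated-buffer convention used by the DeviceInfo and hard-isolation APIs declared in `accelerator.h`. The buffer sizes, the `example-worker` ID, and the 4 GiB limit are illustrative assumptions; the build command assumes the stub library produced by `make stub`.
+
+```c
+/* Hypothetical example, not shipped with the provider: enumerate devices via the
+ * stub library and apply a one-time hard memory limit where supported.
+ * Assumed build: gcc -std=c11 example.c -I. -Lbuild -laccelerator_stub -Wl,-rpath,build -o example
+ */
+#include <stdio.h>
+#include "accelerator.h"
+
+int main(void) {
+    size_t count = 0;
+    if (GetDeviceCount(&count) != RESULT_SUCCESS) {
+        fprintf(stderr, "GetDeviceCount failed\n");
+        return 1;
+    }
+    printf("found %zu device(s)\n", count);
+
+    /* Output buffers are allocated by the caller; the library only fills them in. */
+    ExtendedDeviceInfo devices[16];
+    size_t returned = 0;
+    if (GetAllDevices(devices, 16, &returned) != RESULT_SUCCESS) {
+        fprintf(stderr, "GetAllDevices failed\n");
+        return 1;
+    }
+
+    for (size_t i = 0; i < returned; i++) {
+        printf("%s (%s): %.1f TFLOPS, hard isolation: %s\n",
+               devices[i].basic.uuid,
+               devices[i].basic.model,
+               devices[i].basic.maxTflops,
+               devices[i].capabilities.supportsHardIsolation ? "yes" : "no");
+    }
+
+    /* Hard isolation is a one-time limit applied at worker start
+     * ("example-worker" and the 4 GiB value are illustrative). */
+    if (returned > 0 && devices[0].capabilities.supportsHardIsolation) {
+        if (SetMemHardLimit("example-worker", devices[0].basic.uuid,
+                            4ULL * 1024 * 1024 * 1024) != RESULT_SUCCESS) {
+            fprintf(stderr, "SetMemHardLimit failed\n");
+        }
+    }
+    return 0;
+}
+```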
+ +## Overview + +The accelerator interface abstracts vGPU vendor-specific implementations into a unified API, supporting four isolation modes: + +- **Shared Mode**: Oversubscription, high elasticity, no resource control (equivalent to NVIDIA timeslicing) +- **Soft Mode**: Oversubscription, high elasticity, time-sharing resource control via hooks and limiter +- **Hard Mode**: No oversubscription, medium elasticity, space-sharing via one-time resource limits +- **Partitioned Mode**: No oversubscription, low elasticity, hardware/driver-level partitioning (e.g., MIG) + +## Structure + +``` +provider/ +├── accelerator.h # Main interface definition +├── limiter.h # Limiter.so API (not vendor-implemented) +├── Makefile # Build scripts +├── stub/ +│ └── accelerator.c # Stub implementation for testing +├── ascend/ +│ └── accelerator.c # Huawei Ascend implementation +└── test/ + └── test_accelerator.c # Test suite +``` + +## Building + +### Build Stub Implementation + +```bash +cd provider +make stub +``` + +### Build Ascend Implementation + +```bash +cd provider +make ascend +``` + +### Run Tests + +```bash +cd provider +make test-run +``` + +## Interface Categories + +### 1. DeviceInfo APIs + +- `getDeviceInfo()`: Get device information (capabilities, basic info, NUMA, etc.) +- `getPartitionTemplates()`: Get hardware partition templates (e.g., MIG) +- `getDeviceTopology()`: Get device topology (NVLink, IB NIC, etc.) + +### 2. Virtualization APIs + +#### Partitioned Isolation +- `assignPartition()`: Assign hardware partition (returns partitionOverhead) +- `removePartition()`: Remove partition + +#### Hard Isolation +- `setMemHardLimit()`: Set hard memory limit (one-time) +- `setComputeUnitHardLimit()`: Set hard compute limit (one-time) + +#### Snapshot/Migration +- `snapshot()`: Snapshot device state for processes +- `resume()`: Resume device state for processes + +### 3. Metrics APIs + +- `getProcessComputeUtilization()`: Get compute utilization per process +- `getProcessMemoryUtilization()`: Get memory utilization per process +- `getDeviceMetrics()`: Get basic device metrics (power, PCIe, SM active, TC usage) +- `getExtendedDeviceMetrics()`: Get extended metrics (NVLink bandwidth, etc.) + +## Vendor Implementations + +### Stub Implementation + +The stub implementation (`stub/accelerator.c`) provides a reference implementation for testing and development. + +### Ascend Implementation + +The Ascend implementation (`ascend/accelerator.c`) provides support for Huawei Ascend accelerators: + +- Supports Soft and Hard isolation modes +- Does not support hardware partitioning (MIG-like features) +- Uses HCCS (Huawei Cache Coherent System) for device interconnects +- Typical device: Ascend 910 with 32GB memory, 2 AI cores, 320 TFLOPS (FP16) + +## Usage in Hypervisor + +The hypervisor uses the accelerator library via CGO bindings: + +```go +import "github.com/NexusGPU/tensor-fusion/internal/hypervisor/device" + +mgr, err := device.NewManager("path/to/libaccelerator.so", 30*time.Second) +``` + +See `internal/hypervisor/device/` for the Go bindings and device manager implementation. + +## Testing + +All tests pass successfully: + +```bash +$ make test-run +======================================== +Accelerator Library Test Suite +======================================== +Total tests: 47 +Passed: 47 +Failed: 0 +All tests passed! 
✓ +``` + +## Notes + +- All struct parameters are carefully designed with key attributes +- Memory management: Use provided cleanup functions to free allocated memory +- Thread safety: Vendor implementations should be thread-safe +- Error handling: All APIs return Result enum for error handling + diff --git a/provider/accelerator.h b/provider/accelerator.h new file mode 100644 index 00000000..dbe25b54 --- /dev/null +++ b/provider/accelerator.h @@ -0,0 +1,433 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACCELERATOR_H +#define ACCELERATOR_H + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================ +// Common Types +// ============================================================================ + +typedef enum { + RESULT_SUCCESS = 0, + RESULT_ERROR_INVALID_PARAM = 1, + RESULT_ERROR_NOT_FOUND = 2, + RESULT_ERROR_NOT_SUPPORTED = 3, + RESULT_ERROR_RESOURCE_EXHAUSTED = 4, + RESULT_ERROR_OPERATION_FAILED = 5, + RESULT_ERROR_INTERNAL = 6 +} Result; + +typedef enum { + ISOLATION_MODE_SHARED = 0, // Timeslicing, no resource control + ISOLATION_MODE_SOFT = 1, // Hook-based, token-based limiting + ISOLATION_MODE_HARD = 2, // One-time resource limits + ISOLATION_MODE_PARTITIONED = 3 // Hardware/driver-level partitioning (MIG) +} IsolationMode; + +// ============================================================================ +// DeviceInfo Types +// ============================================================================ + +// Device capabilities +typedef struct { + bool supportsPartitioning; // e.g., MIG support + bool supportsSoftIsolation; // Hook-based isolation support + bool supportsHardIsolation; // One-time limit support + bool supportsSnapshot; // Process snapshot/resume support + bool supportsMetrics; // Metrics collection support + uint32_t maxPartitions; // Maximum number of partitions + uint32_t maxWorkersPerDevice; // Maximum workers per device +} DeviceCapabilities; + +// Basic device information +typedef struct { + char uuid[64]; // Device UUID + char vendor[32]; // Vendor name (e.g., "NVIDIA", "AMD") + char model[128]; // Model name (e.g., "A100", "H100") + char driverVersion[64]; // Driver version + char firmwareVersion[64]; // Firmware version + int32_t index; // Device index + int32_t numaNode; // NUMA node ID (-1 if not assigned) + uint64_t totalMemoryBytes; // Total memory in bytes + uint64_t totalComputeUnits; // Total compute units (e.g., SMs for NVIDIA) + double maxTflops; // Maximum TFLOPS + uint32_t pcieGen; // PCIe generation + uint32_t pcieWidth; // PCIe width (lanes) +} DeviceBasicInfo; + +// Device properties +typedef struct { + uint32_t clockGraphics; // Graphics clock (MHz) + uint32_t clockSM; // SM clock (MHz) - for NVIDIA + uint32_t clockMem; // Memory clock (MHz) + uint32_t clockAI; // AI core clock (MHz) - for Ascend + uint32_t powerLimit; // Power limit (W) + uint32_t temperatureThreshold; // 
Temperature threshold (C) + bool eccEnabled; // ECC enabled + bool persistenceModeEnabled; // Persistence mode + char computeCapability[16]; // Compute capability (e.g., "8.0", "9.0" for NVIDIA, "Ascend310" for Ascend) + char chipType[32]; // Chip type (e.g., "NVIDIA", "Ascend", "AMD") +} DeviceProperties; + +// Related device information (for topology) +typedef struct { + char deviceUUID[64]; // Related device UUID + char connectionType[32]; // Connection type (e.g., "NVLink", "PCIe", "IB") + uint32_t bandwidthMBps; // Bandwidth in MB/s + uint32_t latencyNs; // Latency in nanoseconds +} RelatedDevice; + +// Extended device information +typedef struct { + DeviceBasicInfo basic; + DeviceProperties props; + RelatedDevice* relatedDevices; // Array of related devices + size_t relatedDeviceCount; // Number of related devices + DeviceCapabilities capabilities; +} ExtendedDeviceInfo; + +// Partition template for hardware partitioning (e.g., MIG) +typedef struct { + char templateId[64]; // Template identifier + char name[128]; // Human-readable name + uint64_t memoryBytes; // Memory allocated to partition + uint64_t computeUnits; // Compute units allocated + double tflops; // TFLOPS for this partition + uint32_t sliceCount; // Number of slices (for MIG) + bool isDefault; // Is this a default template + char description[256]; // Description +} PartitionTemplate; + +// Device topology information +typedef struct { + char deviceUUID[64]; // Device UUID + int32_t numaNode; // NUMA node + RelatedDevice* connections; // Array of connections + size_t connectionCount; // Number of connections +} DeviceTopology; + +// Extended topology (includes NVLink, IB NIC, etc.) +typedef struct { + DeviceTopology* devices; // Array of device topologies + size_t deviceCount; // Number of devices + uint32_t nvlinkBandwidthMBps; // NVLink total bandwidth + uint32_t ibNicCount; // InfiniBand NIC count + char topologyType[32]; // Topology type (e.g., "NVLink", "PCIe") +} ExtendedDeviceTopology; + +// ============================================================================ +// Virtualization Types +// ============================================================================ + +// Partition assignment request +typedef struct { + char templateId[64]; // Template ID to use + char deviceUUID[64]; // Target device UUID + char partitionUUID[64]; // Output: assigned partition UUID +} PartitionAssignment; + +// Worker information for isolation +typedef struct { + char workerId[64]; // Worker identifier + char deviceUUID[64]; // Device UUID + pid_t processId; // Process ID + uint64_t memoryLimitBytes; // Memory limit (for hard isolation) + uint32_t computeUnitLimit; // Compute unit limit (for hard isolation) + IsolationMode isolationMode; // Isolation mode +} WorkerInfo; + +// Process array for snapshot/resume +typedef struct { + pid_t* processIds; // Array of process IDs + size_t processCount; // Number of processes + char deviceUUID[64]; // Device UUID +} ProcessArray; + +// ============================================================================ +// Metrics Types +// ============================================================================ + +// Extra metric key-value pair +typedef struct { + char key[64]; // Metric key name + double value; // Metric value +} ExtraMetric; + +// Compute utilization +typedef struct { + char processId[32]; // Process ID as string + char deviceUUID[64]; // Device UUID + double utilizationPercent; // Utilization percentage (0-100) + uint64_t activeSMs; // Active SMs/Compute Units + 
uint64_t totalSMs; // Total SMs/Compute Units + double tflopsUsed; // TFLOPS currently used +} ComputeUtilization; + +// Memory utilization +typedef struct { + char processId[32]; // Process ID as string + char deviceUUID[64]; // Device UUID + uint64_t usedBytes; // Memory used in bytes + uint64_t reservedBytes; // Memory reserved in bytes + double utilizationPercent; // Utilization percentage (0-100) +} MemoryUtilization; + +// Basic device metrics +typedef struct { + char deviceUUID[64]; // Device UUID + double powerUsageWatts; // Current power usage (W) + double temperatureCelsius; // Temperature (C) + uint64_t pcieRxBytes; // PCIe RX bytes + uint64_t pcieTxBytes; // PCIe TX bytes + uint32_t smActivePercent; // SM active percentage + uint32_t tensorCoreUsagePercent; // Tensor Core usage percentage + uint64_t memoryUsedBytes; // Memory used + uint64_t memoryTotalBytes; // Memory total + ExtraMetric* extraMetrics; // Array of extra metrics (key-value pairs) + size_t extraMetricsCount; // Number of extra metrics +} DeviceMetrics; + +// Extended device metrics (NVLink, etc.) +typedef struct { + char deviceUUID[64]; // Device UUID + uint32_t* nvlinkBandwidthMBps; // NVLink bandwidth per link (MB/s) + size_t nvlinkCount; // Number of NVLink connections + uint64_t* ibNicBandwidthMBps; // IB NIC bandwidth per NIC (MB/s) + size_t ibNicCount; // Number of IB NICs + uint32_t* pcieBandwidthMBps; // PCIe bandwidth per link (MB/s) + size_t pcieLinkCount; // Number of PCIe links +} ExtendedDeviceMetrics; + +// ============================================================================ +// DeviceInfo APIs +// ============================================================================ + +/** + * Get the number of available devices. + * + * @param deviceCount Output parameter for number of devices + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetDeviceCount(size_t* deviceCount); + +/** + * Get all available devices information. + * + * @param devices Output buffer for device information (allocated by caller) + * @param maxCount Maximum number of devices that can fit in the buffer + * @param deviceCount Output parameter for number of devices actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetAllDevices(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount); + +/** + * Get device topology including NVLink, IB NIC, and other interconnects. + * + * @param deviceIndexArray Array of device indices to query + * @param deviceCount Number of devices in array + * @param topology Output parameter for extended topology (allocated by caller) + * @param maxConnectionsPerDevice Maximum number of connections per device in topology buffer + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetDeviceTopology(int32_t* deviceIndexArray, size_t deviceCount, ExtendedDeviceTopology* topology, size_t maxConnectionsPerDevice); + +// ============================================================================ +// Virtualization APIs - Partitioned Isolation +// ============================================================================ + +/** + * Assign a partition to a device using a template (e.g., create MIG instance). + * + * @param assignment Partition assignment request (templateId, deviceUUID) + * Output: partitionUUID and partitionOverheadBytes + * @return true on success, false otherwise + */ +bool AssignPartition(PartitionAssignment* assignment); + +/** + * Remove a partition from a device. 
+ * + * @param templateId Template ID used to create the partition + * @param deviceUUID Device UUID + * @return true on success, false otherwise + */ +bool RemovePartition(const char* templateId, const char* deviceUUID); + +// ============================================================================ +// Virtualization APIs - Hard Isolation +// ============================================================================ + +/** + * Set hard memory limit for a worker (one-time, called at worker start by limiter.so). + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param memoryLimitBytes Memory limit in bytes + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result SetMemHardLimit(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes); + +/** + * Set hard compute unit limit for a worker (one-time, called at worker start). + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param computeUnitLimit Compute unit limit (e.g., percentage 0-100) + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result SetComputeUnitHardLimit(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit); + +// ============================================================================ +// Virtualization APIs - Device Snapshot/Migration +// ============================================================================ + +/** + * Snapshot device state for processes (lock processes, checkpoint state). + * Called from hypervisor for migration. + * + * @param processes Array of processes to snapshot + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Snapshot(ProcessArray* processes); + +/** + * Resume device state for processes (unlock processes, restore state). + * Called from hypervisor after migration. + * + * @param processes Array of processes to resume + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Resume(ProcessArray* processes); + +// ============================================================================ +// Metrics APIs +// ============================================================================ + +/** + * Get compute utilization for all processes on all devices. + * + * @param utilizations Output buffer for compute utilizations (allocated by caller) + * @param maxCount Maximum number of utilizations that can fit in the buffer + * @param utilizationCount Output parameter for number of utilizations actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetProcessComputeUtilization( + ComputeUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +); + +/** + * Get memory utilization for all processes on all devices. + * + * @param utilizations Output buffer for memory utilizations (allocated by caller) + * @param maxCount Maximum number of utilizations that can fit in the buffer + * @param utilizationCount Output parameter for number of utilizations actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetProcessMemoryUtilization( + MemoryUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +); + +/** + * Get basic device metrics (power, PCIe, SM active, TC usage, etc.). 
+ * + * @param deviceUUIDArray Array of device UUIDs + * @param deviceCount Number of devices + * @param metrics Output buffer for device metrics (allocated by caller, size >= deviceCount) + * @param maxExtraMetricsPerDevice Maximum number of extra metrics per device + * @return RESULT_SUCCESS on success, error code otherwise + * + * Note: Caller must allocate extraMetrics arrays for each device metric. + * Each metrics[i].extraMetrics should point to an array of size maxExtraMetricsPerDevice. + * The function will fill in the metrics and set extraMetricsCount for each device. + */ +Result GetDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + DeviceMetrics* metrics, + size_t maxExtraMetricsPerDevice +); + +/** + * Get extended device metrics (NVLink bandwidth, etc.). + * + * @param deviceUUIDArray Array of device UUIDs + * @param deviceCount Number of devices + * @param metrics Output buffer for extended device metrics (allocated by caller, size >= deviceCount) + * @param maxNvlinkPerDevice Maximum number of NVLink connections per device + * @param maxIbNicPerDevice Maximum number of IB NICs per device + * @param maxPciePerDevice Maximum number of PCIe links per device + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetExtendedDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + ExtendedDeviceMetrics* metrics, + size_t maxNvlinkPerDevice, + size_t maxIbNicPerDevice, + size_t maxPciePerDevice +); + + +typedef struct { + char* hostPath; // Host path + char* guestPath; // Guest path +} Mount; +/** + * Get vendor mount libs. + * + * @param mounts Output buffer for vendor mount libs (allocated by caller) + * @param maxCount Maximum number of mounts that can fit in the buffer + * @param mountCount Output parameter for number of mounts actually returned + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result GetVendorMountLibs(Mount* mounts, size_t maxCount, size_t* mountCount); + +// ============================================================================ +// Utility APIs +// ============================================================================ + +/** + * Log a message (for debugging and diagnostics). + * + * @param level Log level (e.g., "DEBUG", "INFO", "WARN", "ERROR") + * @param message Log message + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result Log(const char* level, const char* message); + +#ifdef __cplusplus +} +#endif + +// Include limiter.h after defining Result enum +#include "limiter.h" + +#endif // ACCELERATOR_H + diff --git a/provider/limiter.h b/provider/limiter.h new file mode 100644 index 00000000..681a0ec2 --- /dev/null +++ b/provider/limiter.h @@ -0,0 +1,140 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LIMITER_H +#define LIMITER_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================ +// Limiter Types +// ============================================================================ + +// Memory operation record +typedef struct { + char deviceUUID[64]; // Device UUID + int64_t bytesDiff; // Bytes difference (positive = allocation, negative = deallocation) + bool shouldBlock; // Output: whether this operation should be blocked + uint64_t availableBytes; // Output: available bytes after this operation +} MemoryOpRecord; + +// Compute operation record +typedef struct { + char deviceUUID[64]; // Device UUID + uint64_t computeTokens; // Compute tokens consumed (e.g., SM-cycles) + bool shouldBlock; // Output: whether this operation should be blocked + uint64_t availableTokens; // Output: available tokens after this operation +} ComputeOpRecord; + +// Worker freeze state +typedef struct { + char workerId[64]; // Worker identifier + bool isFrozen; // Current freeze state + uint64_t freezeTimeMs; // Time frozen in milliseconds +} WorkerFreezeState; + +// ============================================================================ +// Limiter APIs (Implemented by limiter.so, NOT by vendor accelerator.so) +// ============================================================================ + +/** + * Check and record memory operations for soft isolation. + * This API is called from hooks in CUDA runtime (via dlsym replacement). + * + * @param processId Process identifier + * @param deviceUUID Device UUID + * @param bytesDiff Bytes difference (positive = allocation, negative = deallocation) + * @param record Output parameter for operation record + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result CheckAndRecordMemoryOps(const char* processId, const char* deviceUUID, int64_t bytesDiff, MemoryOpRecord* record); + +/** + * Check and record compute operations for soft isolation. + * This API is called from hooks in CUDA runtime (via dlsym replacement). + * + * @param processId Process identifier + * @param deviceUUID Device UUID + * @param computeTokens Compute tokens consumed (e.g., SM-cycles) + * @param record Output parameter for operation record + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result CheckAndRecordComputeOps(const char* processId, const char* deviceUUID, uint64_t computeTokens, ComputeOpRecord* record); + +/** + * Freeze a worker process (pause execution when resource limit reached). + * This API is called automatically when resources are exhausted. + * + * @param workerId Worker identifier + * @param state Output parameter for freeze state + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result FreezeWorker(const char* workerId, WorkerFreezeState* state); + +/** + * Resume a worker process (resume execution when resources become available). + * This API is called automatically when resources become available. + * + * @param workerId Worker identifier + * @param state Output parameter for freeze state + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result ResumeWorker(const char* workerId, WorkerFreezeState* state); + +/** + * Auto-freeze hook: called when resource limit is reached. + * This triggers automatic freezing of the worker. 
+ * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param resourceType Resource type ("memory" or "compute") + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AutoFreeze(const char* workerId, const char* deviceUUID, const char* resourceType); + +/** + * Auto-resume hook: called when resources become available. + * This triggers automatic resuming of the worker. + * + * @param workerId Worker identifier + * @param deviceUUID Device UUID + * @param resourceType Resource type ("memory" or "compute") + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AutoResume(const char* workerId, const char* deviceUUID, const char* resourceType); + +/** + * Add a worker process to the limiter tracking. + * This API is called when a process starts using a device. + * + * @param deviceUUID Device UUID + * @param processId Process identifier (as string) + * @return RESULT_SUCCESS on success, error code otherwise + */ +Result AddWorkerProcess(const char* deviceUUID, const char* processId); + +#ifdef __cplusplus +} +#endif + +#endif // LIMITER_H + diff --git a/provider/stub/accelerator.c b/provider/stub/accelerator.c new file mode 100644 index 00000000..af5e76a3 --- /dev/null +++ b/provider/stub/accelerator.c @@ -0,0 +1,598 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Feature test macros for POSIX functions (required on Linux) +#define _POSIX_C_SOURCE 200809L +#define _DEFAULT_SOURCE + +#include "../accelerator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// Global Variables for Limiter Thread +// ============================================================================ + +static const char* g_processId = "stub-process-0"; +static _Atomic uint64_t g_lastComputeCallTimeMs = 0; // Last call time in milliseconds +static pthread_t g_limiterThread; +static volatile int g_threadRunning = 0; + +// ============================================================================ +// Limiter Thread Function +// ============================================================================ + +static void* limiterThreadFunc(void* arg __attribute__((unused))) { + // Get first device UUID for testing + ExtendedDeviceInfo devices[256]; // Stack-allocated buffer + size_t deviceCount = 0; + char deviceUUID[64] = {0}; + + if (GetAllDevices(devices, 256, &deviceCount) != RESULT_SUCCESS || deviceCount == 0) { + return NULL; + } + snprintf(deviceUUID, sizeof(deviceUUID), "%s", devices[0].basic.uuid); + + // Add worker process to limiter tracking + AddWorkerProcess(deviceUUID, g_processId); + + // Call CheckAndRecordMemoryOps once + MemoryOpRecord memRecord; + CheckAndRecordMemoryOps(g_processId, deviceUUID, 0, &memRecord); + + // Call CheckAndRecordComputeOps every second + while (g_threadRunning) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + uint64_t currentTimeMs = (uint64_t)ts.tv_sec * 1000 + (uint64_t)ts.tv_nsec / 1000000; + + ComputeOpRecord computeRecord; + CheckAndRecordComputeOps(g_processId, deviceUUID, 1000, &computeRecord); + + // Update global variable + g_lastComputeCallTimeMs = currentTimeMs; + + // Sleep for 1 second + sleep(1); + } + + return NULL; +} + +// ============================================================================ +// Constructor - Initialize Limiter Thread +// ============================================================================ + +__attribute__((constructor)) +static void initLimiterThread(void) { + g_threadRunning = 1; + if (pthread_create(&g_limiterThread, NULL, limiterThreadFunc, NULL) != 0) { + fprintf(stderr, "Failed to create limiter thread\n"); + return; + } + pthread_detach(g_limiterThread); +} + +// ============================================================================ +// Destructor - Cleanup Limiter Thread +// ============================================================================ + +__attribute__((destructor)) +static void cleanupLimiterThread(void) { + g_threadRunning = 0; + // Thread will exit on next iteration +} + +// ============================================================================ +// Stub Implementation - Limiter APIs +// ============================================================================ + +Result AddWorkerProcess(const char* deviceUUID, const char* processId) { + (void)deviceUUID; // Unused in stub + (void)processId; // Unused in stub + return RESULT_SUCCESS; +} + +Result CheckAndRecordMemoryOps(const char* processId, const char* deviceUUID, int64_t bytesDiff, MemoryOpRecord* record) { + (void)processId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)bytesDiff; // Unused in stub + + if (!record) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always allow, set available bytes to a large value + 
record->shouldBlock = false; + record->availableBytes = 16ULL * 1024 * 1024 * 1024; // 16GB + return RESULT_SUCCESS; +} + +Result CheckAndRecordComputeOps(const char* processId, const char* deviceUUID, uint64_t computeTokens, ComputeOpRecord* record) { + (void)processId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)computeTokens; // Unused in stub + + if (!record) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always allow, set available tokens to a large value + record->shouldBlock = false; + record->availableTokens = 1000000; // Large token pool + return RESULT_SUCCESS; +} + +Result FreezeWorker(const char* workerId, WorkerFreezeState* state) { + (void)workerId; // Unused in stub + if (!state) { + return RESULT_ERROR_INVALID_PARAM; + } + state->isFrozen = false; + state->freezeTimeMs = 0; + return RESULT_SUCCESS; +} + +Result ResumeWorker(const char* workerId, WorkerFreezeState* state) { + (void)workerId; // Unused in stub + if (!state) { + return RESULT_ERROR_INVALID_PARAM; + } + state->isFrozen = false; + state->freezeTimeMs = 0; + return RESULT_SUCCESS; +} + +Result AutoFreeze(const char* workerId, const char* deviceUUID, const char* resourceType) { + (void)workerId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)resourceType; // Unused in stub + return RESULT_SUCCESS; +} + +Result AutoResume(const char* workerId, const char* deviceUUID, const char* resourceType) { + (void)workerId; // Unused in stub + (void)deviceUUID; // Unused in stub + (void)resourceType; // Unused in stub + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - DeviceInfo APIs +// ============================================================================ + +Result GetDeviceCount(size_t* deviceCount) { + if (!deviceCount) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 4 devices + *deviceCount = 4; + return RESULT_SUCCESS; +} + +// Helper function to initialize a single device info +static void initDeviceInfo(ExtendedDeviceInfo* info, int32_t deviceIndex) { + // Initialize basic info + snprintf(info->basic.uuid, sizeof(info->basic.uuid), "stub-device-%d", deviceIndex); + snprintf(info->basic.vendor, sizeof(info->basic.vendor), "STUB"); + snprintf(info->basic.model, sizeof(info->basic.model), "Stub-GPU-Model"); + snprintf(info->basic.driverVersion, sizeof(info->basic.driverVersion), "1.0.0-stub"); + snprintf(info->basic.firmwareVersion, sizeof(info->basic.firmwareVersion), "1.0.0-stub"); + info->basic.index = deviceIndex; + info->basic.numaNode = deviceIndex % 2; // Stub: alternate NUMA nodes + info->basic.totalMemoryBytes = 16ULL * 1024 * 1024 * 1024; // 16GB + info->basic.totalComputeUnits = 108; // Stub: 108 SMs + info->basic.maxTflops = 312.0; // Stub: 312 TFLOPS + info->basic.pcieGen = 4; + info->basic.pcieWidth = 16; + + // Initialize properties + info->props.clockGraphics = 1410; // MHz + info->props.clockSM = 1410; // MHz + info->props.clockMem = 1215; // MHz + info->props.powerLimit = 400; // W + info->props.temperatureThreshold = 83; // C + info->props.eccEnabled = true; + info->props.persistenceModeEnabled = false; + snprintf(info->props.computeCapability, sizeof(info->props.computeCapability), "8.0"); + info->props.clockAI = 0; // Not applicable for stub + snprintf(info->props.chipType, sizeof(info->props.chipType), "STUB"); + + // Initialize capabilities + info->capabilities.supportsPartitioning = true; + info->capabilities.supportsSoftIsolation = true; + 
info->capabilities.supportsHardIsolation = true; + info->capabilities.supportsSnapshot = true; + info->capabilities.supportsMetrics = true; + info->capabilities.maxPartitions = 7; + info->capabilities.maxWorkersPerDevice = 16; + + // Initialize related devices (stub: no related devices) + info->relatedDevices = NULL; + info->relatedDeviceCount = 0; +} + +Result GetAllDevices(ExtendedDeviceInfo* devices, size_t maxCount, size_t* deviceCount) { + if (!devices || !deviceCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 4 devices (but not more than maxCount) + size_t actualCount = 4; + if (actualCount > maxCount) { + actualCount = maxCount; + } + *deviceCount = actualCount; + + // Initialize each device + for (size_t i = 0; i < actualCount; i++) { + initDeviceInfo(&devices[i], (int32_t)i); + } + + return RESULT_SUCCESS; +} + +Result GetPartitionTemplates(int32_t deviceIndex __attribute__((unused)), PartitionTemplate* templates, size_t maxCount, size_t* templateCount) { + if (!templates || !templateCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: return 3 example templates (but not more than maxCount) + size_t actualCount = 3; + if (actualCount > maxCount) { + actualCount = maxCount; + } + *templateCount = actualCount; + + // Template 1: 1/7 slice + if (actualCount > 0) { + PartitionTemplate* t1 = &templates[0]; + snprintf(t1->templateId, sizeof(t1->templateId), "mig-1g.7gb"); + snprintf(t1->name, sizeof(t1->name), "1/7 GPU Slice"); + t1->memoryBytes = 7ULL * 1024 * 1024 * 1024; // 7GB + t1->computeUnits = 14; // 1/7 of 108 SMs + t1->tflops = 312.0 / 7.0; // ~44.6 TFLOPS + t1->sliceCount = 1; + t1->isDefault = false; + snprintf(t1->description, sizeof(t1->description), "1/7 GPU slice with 7GB memory"); + } + + // Template 2: 2/7 slice + if (actualCount > 1) { + PartitionTemplate* t2 = &templates[1]; + snprintf(t2->templateId, sizeof(t2->templateId), "mig-2g.14gb"); + snprintf(t2->name, sizeof(t2->name), "2/7 GPU Slice"); + t2->memoryBytes = 14ULL * 1024 * 1024 * 1024; // 14GB + t2->computeUnits = 28; // 2/7 of 108 SMs + t2->tflops = 312.0 * 2.0 / 7.0; // ~89.1 TFLOPS + t2->sliceCount = 2; + t2->isDefault = true; + snprintf(t2->description, sizeof(t2->description), "2/7 GPU slice with 14GB memory"); + } + + // Template 3: 3/7 slice + if (actualCount > 2) { + PartitionTemplate* t3 = &templates[2]; + snprintf(t3->templateId, sizeof(t3->templateId), "mig-3g.21gb"); + snprintf(t3->name, sizeof(t3->name), "3/7 GPU Slice"); + t3->memoryBytes = 21ULL * 1024 * 1024 * 1024; // 21GB (stub, exceeds total) + t3->computeUnits = 42; // 3/7 of 108 SMs + t3->tflops = 312.0 * 3.0 / 7.0; // ~133.7 TFLOPS + t3->sliceCount = 3; + t3->isDefault = false; + snprintf(t3->description, sizeof(t3->description), "3/7 GPU slice with 21GB memory"); + } + + return RESULT_SUCCESS; +} + +Result GetDeviceTopology(int32_t* deviceIndexArray, size_t deviceCount, ExtendedDeviceTopology* topology, size_t maxConnectionsPerDevice) { + if (!deviceIndexArray || deviceCount == 0 || !topology || maxConnectionsPerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Note: topology->devices must be pre-allocated by caller with size >= deviceCount + // topology->devices[i].connections must be pre-allocated by caller with size >= maxConnectionsPerDevice + if (!topology->devices) { + return RESULT_ERROR_INVALID_PARAM; + } + topology->deviceCount = deviceCount; + + // Initialize each device topology + for (size_t i = 0; i < deviceCount; i++) { + DeviceTopology* dt = 
&topology->devices[i]; + snprintf(dt->deviceUUID, sizeof(dt->deviceUUID), "stub-device-%d", deviceIndexArray[i]); + dt->numaNode = deviceIndexArray[i] % 2; + + // Stub: create connections to other devices + size_t connectionCount = (deviceCount > 1) ? (deviceCount - 1) : 0; + if (connectionCount > maxConnectionsPerDevice) { + connectionCount = maxConnectionsPerDevice; + } + + if (connectionCount > 0 && dt->connections) { + dt->connectionCount = connectionCount; + + size_t connIdx = 0; + for (size_t j = 0; j < deviceCount && connIdx < connectionCount; j++) { + if (j != i) { + RelatedDevice* rd = &dt->connections[connIdx]; + snprintf(rd->deviceUUID, sizeof(rd->deviceUUID), "stub-device-%d", deviceIndexArray[j]); + snprintf(rd->connectionType, sizeof(rd->connectionType), "NVLink"); + rd->bandwidthMBps = 600000; // 600 GB/s (stub) + rd->latencyNs = 100; // 100ns (stub) + connIdx++; + } + } + } else { + dt->connections = NULL; + dt->connectionCount = 0; + } + } + + // Set extended topology info + topology->nvlinkBandwidthMBps = 600000 * deviceCount; // Total bandwidth + topology->ibNicCount = 0; // Stub: no IB NICs + snprintf(topology->topologyType, sizeof(topology->topologyType), "NVLink"); + + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Partitioned Isolation +// ============================================================================ + +bool AssignPartition(PartitionAssignment* assignment) { + if (!assignment || assignment->templateId[0] == '\0' || assignment->deviceUUID[0] == '\0') { + return false; + } + + // Stub: generate a partition UUID + // Limit string lengths to ensure output fits in 64-byte buffer: + // "partition-" (9) + templateId (26) + "-" (1) + deviceUUID (26) + null (1) = 63 bytes + snprintf(assignment->partitionUUID, sizeof(assignment->partitionUUID), + "partition-%.26s-%.26s", assignment->templateId, assignment->deviceUUID); + + return true; +} + +bool RemovePartition(const char* templateId, const char* deviceUUID) { + if (!templateId || !deviceUUID) { + return false; + } + + // Stub: always succeed + return true; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Hard Isolation +// ============================================================================ + +Result SetMemHardLimit(const char* workerId, const char* deviceUUID, uint64_t memoryLimitBytes) { + if (!workerId || !deviceUUID || memoryLimitBytes == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed + return RESULT_SUCCESS; +} + +Result SetComputeUnitHardLimit(const char* workerId, const char* deviceUUID, uint32_t computeUnitLimit) { + if (!workerId || !deviceUUID || computeUnitLimit == 0 || computeUnitLimit > 100) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Virtualization APIs - Device Snapshot/Migration +// ============================================================================ + +Result Snapshot(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: verify processes exist (basic check) + for (size_t i = 0; i < processes->processCount; i++) { + if (kill(processes->processIds[i], 0) != 0) { + // Process doesn't exist or no 
permission + return RESULT_ERROR_NOT_FOUND; + } + } + + // Stub: always succeed (no actual snapshot implementation) + return RESULT_SUCCESS; +} + +Result Resume(ProcessArray* processes) { + if (!processes || !processes->processIds || processes->processCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Stub: always succeed (no actual resume implementation) + return RESULT_SUCCESS; +} + +// ============================================================================ +// Stub Implementation - Metrics APIs +// ============================================================================ + +Result GetProcessComputeUtilization( + ComputeUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // For now, stub implementation returns empty + // The actual implementation should query limiter for all tracked processes + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetProcessMemoryUtilization( + MemoryUtilization* utilizations, + size_t maxCount, + size_t* utilizationCount +) { + if (!utilizations || !utilizationCount || maxCount == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // TODO: Get actual device and process list from limiter + // For now, stub implementation returns empty + // The actual implementation should query limiter for all tracked processes + *utilizationCount = 0; + return RESULT_SUCCESS; +} + +Result GetDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + DeviceMetrics* metrics, + size_t maxExtraMetricsPerDevice +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics || maxExtraMetricsPerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Fill stub data + for (size_t i = 0; i < deviceCount; i++) { + DeviceMetrics* dm = &metrics[i]; + snprintf(dm->deviceUUID, sizeof(dm->deviceUUID), "%s", deviceUUIDArray[i]); + dm->powerUsageWatts = 200.0 + (i * 10.0); // Stub: 200-300W + dm->temperatureCelsius = 45.0 + (i * 5.0); // Stub: 45-50C + dm->pcieRxBytes = 1024ULL * 1024 * 1024 * (i + 1); // Stub: 1-4GB + dm->pcieTxBytes = 512ULL * 1024 * 1024 * (i + 1); // Stub: 0.5-2GB + dm->smActivePercent = 50 + (i * 10); // Stub: 50-90% + dm->tensorCoreUsagePercent = 30 + (i * 5); // Stub: 30-50% + dm->memoryUsedBytes = 8ULL * 1024 * 1024 * 1024; // Stub: 8GB + dm->memoryTotalBytes = 16ULL * 1024 * 1024 * 1024; // Stub: 16GB + + // Fill extra metrics + if (dm->extraMetrics != NULL && maxExtraMetricsPerDevice > 0) { + size_t extraCount = 0; + + // Add some example extra metrics + if (extraCount < maxExtraMetricsPerDevice) { + snprintf(dm->extraMetrics[extraCount].key, sizeof(dm->extraMetrics[extraCount].key), "gpuUtilization"); + dm->extraMetrics[extraCount].value = 75.0 + (i * 5.0); // Stub: 75-95% + extraCount++; + } + + if (extraCount < maxExtraMetricsPerDevice) { + snprintf(dm->extraMetrics[extraCount].key, sizeof(dm->extraMetrics[extraCount].key), "memoryBandwidthMBps"); + dm->extraMetrics[extraCount].value = 800.0 + (i * 50.0); // Stub: 800-1200 MB/s + extraCount++; + } + + if (extraCount < maxExtraMetricsPerDevice) { + snprintf(dm->extraMetrics[extraCount].key, sizeof(dm->extraMetrics[extraCount].key), "encoderUtilization"); + dm->extraMetrics[extraCount].value = 10.0 + (i * 2.0); // Stub: 10-20% + extraCount++; + } + + if (extraCount < maxExtraMetricsPerDevice) { + snprintf(dm->extraMetrics[extraCount].key, sizeof(dm->extraMetrics[extraCount].key), 
"decoderUtilization"); + dm->extraMetrics[extraCount].value = 15.0 + (i * 3.0); // Stub: 15-30% + extraCount++; + } + + dm->extraMetricsCount = extraCount; + } else { + dm->extraMetricsCount = 0; + } + } + + return RESULT_SUCCESS; +} + +Result GetExtendedDeviceMetrics( + const char** deviceUUIDArray, + size_t deviceCount, + ExtendedDeviceMetrics* metrics, + size_t maxNvlinkPerDevice, + size_t maxIbNicPerDevice, + size_t maxPciePerDevice +) { + if (!deviceUUIDArray || deviceCount == 0 || !metrics || + maxNvlinkPerDevice == 0 || maxIbNicPerDevice == 0 || maxPciePerDevice == 0) { + return RESULT_ERROR_INVALID_PARAM; + } + + // Fill stub data + // Note: metrics[i].nvlinkBandwidthMBps, ibNicBandwidthMBps, pcieBandwidthMBps + // must be pre-allocated by caller with appropriate sizes + for (size_t i = 0; i < deviceCount; i++) { + ExtendedDeviceMetrics* edm = &metrics[i]; + snprintf(edm->deviceUUID, sizeof(edm->deviceUUID), "%s", deviceUUIDArray[i]); + + // Stub: 6 NVLink connections per device (but not more than max) + edm->nvlinkCount = 6; + if (edm->nvlinkCount > maxNvlinkPerDevice) { + edm->nvlinkCount = maxNvlinkPerDevice; + } + if (edm->nvlinkBandwidthMBps) { + for (size_t j = 0; j < edm->nvlinkCount; j++) { + edm->nvlinkBandwidthMBps[j] = 500000 + (j * 10000); // Stub: 500-550 GB/s + } + } + + // Stub: 2 IB NICs per device (but not more than max) + edm->ibNicCount = 2; + if (edm->ibNicCount > maxIbNicPerDevice) { + edm->ibNicCount = maxIbNicPerDevice; + } + if (edm->ibNicBandwidthMBps) { + for (size_t j = 0; j < edm->ibNicCount; j++) { + edm->ibNicBandwidthMBps[j] = 200000; // Stub: 200 GB/s per NIC + } + } + + // Stub: 1 PCIe link (but not more than max) + edm->pcieLinkCount = 1; + if (edm->pcieLinkCount > maxPciePerDevice) { + edm->pcieLinkCount = maxPciePerDevice; + } + if (edm->pcieBandwidthMBps && edm->pcieLinkCount > 0) { + edm->pcieBandwidthMBps[0] = 32000; // Stub: 32 GB/s (PCIe 4.0 x16) + } + } + + return RESULT_SUCCESS; +} + +Result GetVendorMountLibs(Mount* mounts, size_t maxCount, size_t* mountCount) { + if (!mounts || maxCount == 0 || !mountCount) { + return RESULT_ERROR_INVALID_PARAM; + } + *mountCount = 0; + return RESULT_SUCCESS; +} diff --git a/provider/test/test_accelerator.c b/provider/test/test_accelerator.c new file mode 100644 index 00000000..6b04e3bc --- /dev/null +++ b/provider/test/test_accelerator.c @@ -0,0 +1,293 @@ +/* + * Copyright 2024. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdio.h> +#include <string.h> +#include <stdbool.h> +#include <stdint.h> +#include "../accelerator.h" + +// Test result tracking +static int tests_run = 0; +static int tests_passed = 0; +static int tests_failed = 0; + +#define TEST_ASSERT(condition, message) \ + do { \ + tests_run++; \ + if (condition) { \ + tests_passed++; \ + printf(" ✓ %s\n", message); \ + } else { \ + tests_failed++; \ + printf(" ✗ %s\n", message); \ + } \ + } while (0) + +// Test getDeviceInfo +void test_getDeviceInfo() { + printf("\n=== Testing getDeviceInfo ===\n"); + + ExtendedDeviceInfo info; + Result result = getDeviceInfo(0, &info); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceInfo returns success"); + TEST_ASSERT(strlen(info.basic.uuid) > 0, "Device UUID is not empty"); + TEST_ASSERT(strlen(info.basic.vendor) > 0, "Vendor is not empty"); + TEST_ASSERT(strlen(info.basic.model) > 0, "Model is not empty"); + TEST_ASSERT(info.basic.totalMemoryBytes > 0, "Total memory > 0"); + TEST_ASSERT(info.basic.totalComputeUnits > 0, "Total compute units > 0"); + TEST_ASSERT(info.basic.maxTflops > 0, "Max TFLOPS > 0"); + TEST_ASSERT(info.capabilities.maxPartitions > 0, "Max partitions > 0"); + + // Test invalid device index + result = getDeviceInfo(-1, &info); + TEST_ASSERT(result != RESULT_SUCCESS, "Invalid device index returns error"); + + // Cleanup + freeExtendedDeviceInfo(&info); +} + +// Test getPartitionTemplates +void test_getPartitionTemplates() { + printf("\n=== Testing getPartitionTemplates ===\n"); + + PartitionTemplate* templates = NULL; + size_t templateCount = 0; + Result result = getPartitionTemplates(0, &templates, &templateCount); + + TEST_ASSERT(result == RESULT_SUCCESS, "getPartitionTemplates returns success"); + TEST_ASSERT(templates != NULL, "Templates array is not NULL"); + TEST_ASSERT(templateCount > 0, "Template count > 0"); + + if (templates && templateCount > 0) { + TEST_ASSERT(strlen(templates[0].templateId) > 0, "First template has ID"); + TEST_ASSERT(strlen(templates[0].name) > 0, "First template has name"); + TEST_ASSERT(templates[0].memoryBytes > 0, "First template has memory"); + TEST_ASSERT(templates[0].computeUnits > 0, "First template has compute units"); + } + + // Cleanup + freePartitionTemplates(templates, templateCount); +} + +// Test getDeviceTopology +void test_getDeviceTopology() { + printf("\n=== Testing getDeviceTopology ===\n"); + + int32_t deviceIndices[] = {0, 1}; + size_t deviceCount = 2; + ExtendedDeviceTopology topology; + + Result result = getDeviceTopology(deviceIndices, deviceCount, &topology); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceTopology returns success"); + TEST_ASSERT(topology.devices != NULL, "Devices array is not NULL"); + TEST_ASSERT(topology.deviceCount == deviceCount, "Device count matches"); + + if (topology.devices && topology.deviceCount > 0) { + TEST_ASSERT(strlen(topology.devices[0].deviceUUID) > 0, "First device has UUID"); + } + + // Cleanup + freeExtendedDeviceTopology(&topology); +} + +// Test assignPartition +void test_assignPartition() { + printf("\n=== Testing assignPartition ===\n"); + + PartitionAssignment assignment; + snprintf(assignment.templateId, sizeof(assignment.templateId), "mig-1g.7gb"); + snprintf(assignment.deviceUUID, sizeof(assignment.deviceUUID), "stub-device-0"); + + bool result = assignPartition(&assignment); + + TEST_ASSERT(result == true, "assignPartition returns true"); + TEST_ASSERT(strlen(assignment.partitionUUID) > 0, "Partition UUID is assigned"); + TEST_ASSERT(assignment.partitionOverheadBytes > 0, "Partition overhead > 0"); + + //
Test invalid input + PartitionAssignment invalid; + invalid.templateId[0] = '\0'; + invalid.deviceUUID[0] = '\0'; + result = assignPartition(&invalid); + TEST_ASSERT(result == false, "Invalid assignment returns false"); +} + +// Test removePartition +void test_removePartition() { + printf("\n=== Testing removePartition ===\n"); + + bool result = removePartition("mig-1g.7gb", "stub-device-0"); + TEST_ASSERT(result == true, "removePartition returns true"); + + result = removePartition(NULL, "stub-device-0"); + TEST_ASSERT(result == false, "NULL templateId returns false"); +} + +// Test setMemHardLimit +void test_setMemHardLimit() { + printf("\n=== Testing setMemHardLimit ===\n"); + + Result result = setMemHardLimit("worker-1", "stub-device-0", 4ULL * 1024 * 1024 * 1024); + TEST_ASSERT(result == RESULT_SUCCESS, "setMemHardLimit returns success"); + + result = setMemHardLimit(NULL, "stub-device-0", 4ULL * 1024 * 1024 * 1024); + TEST_ASSERT(result == RESULT_ERROR_INVALID_PARAM, "NULL workerId returns error"); +} + +// Test setComputeUnitHardLimit +void test_setComputeUnitHardLimit() { + printf("\n=== Testing setComputeUnitHardLimit ===\n"); + + Result result = setComputeUnitHardLimit("worker-1", "stub-device-0", 50); + TEST_ASSERT(result == RESULT_SUCCESS, "setComputeUnitHardLimit returns success"); + + result = setComputeUnitHardLimit("worker-1", "stub-device-0", 150); + TEST_ASSERT(result == RESULT_ERROR_INVALID_PARAM, "Invalid limit > 100 returns error"); +} + +// Test getProcessComputeUtilization +void test_getProcessComputeUtilization() { + printf("\n=== Testing getProcessComputeUtilization ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + const char* processIds[] = {"12345"}; + ComputeUtilization* utilizations = NULL; + size_t utilizationCount = 0; + + Result result = getProcessComputeUtilization( + deviceUUIDs, 1, + processIds, 1, + &utilizations, &utilizationCount + ); + + TEST_ASSERT(result == RESULT_SUCCESS, "getProcessComputeUtilization returns success"); + TEST_ASSERT(utilizations != NULL, "Utilizations array is not NULL"); + TEST_ASSERT(utilizationCount > 0, "Utilization count > 0"); + + if (utilizations && utilizationCount > 0) { + TEST_ASSERT(utilizations[0].utilizationPercent >= 0 && + utilizations[0].utilizationPercent <= 100, + "Utilization percent in valid range"); + } + + freeComputeUtilizations(utilizations, utilizationCount); +} + +// Test getProcessMemoryUtilization +void test_getProcessMemoryUtilization() { + printf("\n=== Testing getProcessMemoryUtilization ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + const char* processIds[] = {"12345"}; + MemoryUtilization* utilizations = NULL; + size_t utilizationCount = 0; + + Result result = getProcessMemoryUtilization( + deviceUUIDs, 1, + processIds, 1, + &utilizations, &utilizationCount + ); + + TEST_ASSERT(result == RESULT_SUCCESS, "getProcessMemoryUtilization returns success"); + TEST_ASSERT(utilizations != NULL, "Utilizations array is not NULL"); + TEST_ASSERT(utilizationCount > 0, "Utilization count > 0"); + + if (utilizations && utilizationCount > 0) { + TEST_ASSERT(utilizations[0].usedBytes > 0, "Used bytes > 0"); + } + + freeMemoryUtilizations(utilizations, utilizationCount); +} + +// Test getDeviceMetrics +void test_getDeviceMetrics() { + printf("\n=== Testing getDeviceMetrics ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + DeviceMetrics* metrics = NULL; + + Result result = getDeviceMetrics(deviceUUIDs, 1, &metrics); + + TEST_ASSERT(result == RESULT_SUCCESS, "getDeviceMetrics 
returns success"); + TEST_ASSERT(metrics != NULL, "Metrics array is not NULL"); + + if (metrics) { + TEST_ASSERT(strlen(metrics[0].deviceUUID) > 0, "Device UUID is not empty"); + TEST_ASSERT(metrics[0].powerUsageWatts >= 0, "Power usage >= 0"); + TEST_ASSERT(metrics[0].temperatureCelsius >= 0, "Temperature >= 0"); + } + + freeDeviceMetrics(metrics, 1); +} + +// Test getExtendedDeviceMetrics +void test_getExtendedDeviceMetrics() { + printf("\n=== Testing getExtendedDeviceMetrics ===\n"); + + const char* deviceUUIDs[] = {"stub-device-0"}; + ExtendedDeviceMetrics* metrics = NULL; + + Result result = getExtendedDeviceMetrics(deviceUUIDs, 1, &metrics); + + TEST_ASSERT(result == RESULT_SUCCESS, "getExtendedDeviceMetrics returns success"); + TEST_ASSERT(metrics != NULL, "Metrics array is not NULL"); + + if (metrics) { + TEST_ASSERT(strlen(metrics[0].deviceUUID) > 0, "Device UUID is not empty"); + TEST_ASSERT(metrics[0].nvlinkCount > 0, "NVLink count > 0"); + } + + freeExtendedDeviceMetrics(metrics, 1); +} + +// Main test runner +int main() { + printf("========================================\n"); + printf("Accelerator Library Test Suite\n"); + printf("========================================\n"); + + test_getDeviceInfo(); + test_getPartitionTemplates(); + test_getDeviceTopology(); + test_assignPartition(); + test_removePartition(); + test_setMemHardLimit(); + test_setComputeUnitHardLimit(); + test_getProcessComputeUtilization(); + test_getProcessMemoryUtilization(); + test_getDeviceMetrics(); + test_getExtendedDeviceMetrics(); + + printf("\n========================================\n"); + printf("Test Summary\n"); + printf("========================================\n"); + printf("Total tests: %d\n", tests_run); + printf("Passed: %d\n", tests_passed); + printf("Failed: %d\n", tests_failed); + printf("========================================\n"); + + if (tests_failed == 0) { + printf("All tests passed! ✓\n"); + return 0; + } else { + printf("Some tests failed! 
✗\n"); + return 1; + } +} + diff --git a/test/sched/preemption_test.go b/test/sched/preemption_test.go index 4e33b6a4..62cfbaa5 100644 --- a/test/sched/preemption_test.go +++ b/test/sched/preemption_test.go @@ -69,7 +69,7 @@ func (pts *PreemptionTestSuite) SetupSuite() { gpuResourceFitOpt := app.WithPlugin( gpuResourceFitPlugin.Name, - gpuResourceFitPlugin.NewWithDeps(fixture.allocator, fixture.client), + gpuResourceFitPlugin.NewWithDeps(fixture.allocator, fixture.indexAllocator, fixture.client), ) gpuTopoOpt := app.WithPlugin( gpuTopoPlugin.Name, diff --git a/test/sched/scheduler_bench_test.go b/test/sched/scheduler_bench_test.go index 4b80fb71..555a6a26 100644 --- a/test/sched/scheduler_bench_test.go +++ b/test/sched/scheduler_bench_test.go @@ -102,7 +102,7 @@ func BenchmarkScheduler(b *testing.B) { gpuResourceFitOpt := app.WithPlugin( gpuResourceFitPlugin.Name, - gpuResourceFitPlugin.NewWithDeps(fixture.allocator, fixture.client), + gpuResourceFitPlugin.NewWithDeps(fixture.allocator, fixture.indexAllocator, fixture.client), ) gpuTopoOpt := app.WithPlugin( gpuTopoPlugin.Name, diff --git a/test/sched/setup.go b/test/sched/setup.go index 5dc80e32..1794a1e4 100644 --- a/test/sched/setup.go +++ b/test/sched/setup.go @@ -11,6 +11,7 @@ import ( tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/gpuallocator" + "github.com/NexusGPU/tensor-fusion/internal/indexallocator" gpuResourceFitPlugin "github.com/NexusGPU/tensor-fusion/internal/scheduler/gpuresources" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" @@ -49,14 +50,15 @@ type BenchmarkConfig struct { // BenchmarkFixture holds pre-initialized benchmark data type BenchmarkFixture struct { - ctx context.Context - cancel context.CancelFunc - plugin *gpuResourceFitPlugin.GPUFit - nodes []*v1.Node - pods []*v1.Pod - allocator *gpuallocator.GpuAllocator - client client.Client - fwk framework.Framework + ctx context.Context + cancel context.CancelFunc + plugin *gpuResourceFitPlugin.GPUFit + nodes []*v1.Node + pods []*v1.Pod + allocator *gpuallocator.GpuAllocator + indexAllocator *indexallocator.IndexAllocator + client client.Client + fwk framework.Framework } // NewBenchmarkFixture creates and initializes a benchmark fixture @@ -94,30 +96,33 @@ func NewBenchmarkFixture( // Setup allocator allocator := setupAllocator(b, ctx, client) - + indexAllocator, err := indexallocator.NewIndexAllocator(ctx, client) + require.NoError(b, err) // Setup framework and plugin if !realAPIServer { - fwk, plugin := setupFrameworkAndPlugin(b, ctx, client, allocator, k8sNativeObjects) + fwk, plugin := setupFrameworkAndPlugin(b, ctx, client, allocator, indexAllocator, k8sNativeObjects) return &BenchmarkFixture{ - ctx: ctx, - cancel: cancel, - plugin: plugin, - nodes: nodes, - pods: pods, - allocator: allocator, - client: client, - fwk: fwk, + ctx: ctx, + cancel: cancel, + plugin: plugin, + nodes: nodes, + pods: pods, + allocator: allocator, + indexAllocator: indexAllocator, + client: client, + fwk: fwk, } } else { return &BenchmarkFixture{ - ctx: ctx, - cancel: cancel, - plugin: nil, - nodes: nodes, - pods: pods, - allocator: allocator, - client: client, - fwk: nil, + ctx: ctx, + cancel: cancel, + plugin: nil, + nodes: nodes, + pods: pods, + allocator: allocator, + indexAllocator: indexAllocator, + client: client, + fwk: nil, } } } @@ -352,7 +357,7 @@ func batchCreateResources( func setupFrameworkAndPlugin( b *testing.B, ctx context.Context, client client.Client, - 
allocator *gpuallocator.GpuAllocator, k8sObjs []runtime.Object, + allocator *gpuallocator.GpuAllocator, indexAllocator *indexallocator.IndexAllocator, k8sObjs []runtime.Object, ) (framework.Framework, *gpuResourceFitPlugin.GPUFit) { // Register plugins including our GPU plugin registeredPlugins := []tf.RegisterPluginFunc{ @@ -374,7 +379,7 @@ func setupFrameworkAndPlugin( require.NoError(b, err) // Create plugin directly - plugin := createPlugin(b, ctx, fwk, allocator, client) + plugin := createPlugin(b, ctx, fwk, allocator, indexAllocator, client) return fwk, plugin } @@ -382,7 +387,7 @@ func setupFrameworkAndPlugin( func setupAllocator( b *testing.B, ctx context.Context, client client.Client, ) *gpuallocator.GpuAllocator { - allocator := gpuallocator.NewGpuAllocator(ctx, client, time.Second) + allocator := gpuallocator.NewGpuAllocator(ctx, nil, client, time.Second) require.NoError(b, allocator.InitGPUAndQuotaStore()) allocator.ReconcileAllocationState() allocator.SetAllocatorReady() @@ -391,9 +396,9 @@ func setupAllocator( func createPlugin( b *testing.B, ctx context.Context, fwk framework.Framework, - allocator *gpuallocator.GpuAllocator, client client.Client, + allocator *gpuallocator.GpuAllocator, indexAllocator *indexallocator.IndexAllocator, client client.Client, ) *gpuResourceFitPlugin.GPUFit { - pluginFactory := gpuResourceFitPlugin.NewWithDeps(allocator, client) + pluginFactory := gpuResourceFitPlugin.NewWithDeps(allocator, indexAllocator, client) pluginConfig := &runtime.Unknown{ Raw: []byte(`{"maxWorkerPerNode": 256, "vramWeight": 0.7, "tflopsWeight": 0.3}`), }
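
For reference, a minimal caller-side sketch of the caller-allocates contract used by the stub provider's GetDeviceMetrics above: the extraMetrics buffer must be supplied by the caller, and maxExtraMetricsPerDevice tells the provider how many entries it may fill. This is only an illustration under stated assumptions; the element type name ExtraMetric, the MAX_EXTRA capacity, and the header path are assumptions, since accelerator.h itself is not part of this diff.

#include <stdio.h>
#include "accelerator.h"  /* provider header; path depends on where this sketch lives */

/* Assumed capacity; must match the maxExtraMetricsPerDevice argument passed below. */
#define MAX_EXTRA 8

int main(void) {
    const char* uuids[] = { "stub-device-0" };
    DeviceMetrics metrics[1] = {0};

    /* Caller-allocated extra-metrics buffer; "ExtraMetric" is an assumed type name
     * for the elements of DeviceMetrics.extraMetrics (key/value pairs in the stub). */
    ExtraMetric extra[MAX_EXTRA];
    metrics[0].extraMetrics = extra;

    if (GetDeviceMetrics(uuids, 1, metrics, MAX_EXTRA) == RESULT_SUCCESS) {
        /* The stub fills deviceUUID, power, utilization and up to MAX_EXTRA extra metrics. */
        printf("%s: %.1f W, %zu extra metrics\n",
               metrics[0].deviceUUID,
               metrics[0].powerUsageWatts,
               (size_t)metrics[0].extraMetricsCount);
    }
    return 0;
}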