diff --git a/charts/hami/templates/scheduler/configmap.yaml b/charts/hami/templates/scheduler/configmap.yaml index 758891463..8e98dd2ff 100644 --- a/charts/hami/templates/scheduler/configmap.yaml +++ b/charts/hami/templates/scheduler/configmap.yaml @@ -82,6 +82,14 @@ data: }, {{- end }} {{- end }} + {{- if .Values.devices.vastai.enabled }} + {{- range .Values.devices.vastai.customresources }} + { + "name": "{{ . }}", + "ignoredByScheduler": true + }, + {{- end }} + {{- end }} { "name": "{{ .Values.resourceName }}", "ignoredByScheduler": true diff --git a/charts/hami/templates/scheduler/configmapnew.yaml b/charts/hami/templates/scheduler/configmapnew.yaml index e2a91d8b5..0e169aa20 100644 --- a/charts/hami/templates/scheduler/configmapnew.yaml +++ b/charts/hami/templates/scheduler/configmapnew.yaml @@ -85,6 +85,12 @@ data: ignoredByScheduler: true {{- end }} {{- end }} + {{- if .Values.devices.vastai.enabled }} + {{- range .Values.devices.vastai.customresources }} + - name: {{ . }} + ignoredByScheduler: true + {{- end }} + {{- end }} {{- range .Values.devices.awsneuron.customresources }} - name: {{ . 
}} ignoredByScheduler: true diff --git a/charts/hami/templates/scheduler/device-configmap.yaml b/charts/hami/templates/scheduler/device-configmap.yaml index 92a5a020d..faea2bfa7 100644 --- a/charts/hami/templates/scheduler/device-configmap.yaml +++ b/charts/hami/templates/scheduler/device-configmap.yaml @@ -283,6 +283,8 @@ data: resourceCoreName: "aws.amazon.com/neuroncore" amd: resourceCountName: "amd.com/gpu" + vastai: + resourceCountName: {{ .Values.vastaiResourceName | quote }} vnpus: - chipName: 910A commonWord: Ascend910A diff --git a/charts/hami/values.yaml b/charts/hami/values.yaml index 55af5e626..575450e48 100644 --- a/charts/hami/values.yaml +++ b/charts/hami/values.yaml @@ -53,6 +53,9 @@ kunlunResourceName: "kunlunxin.com/xpu" kunlunResourceVCountName: "kunlunxin.com/vxpu" kunlunResourceVMemoryName: "kunlunxin.com/vxpu-memory" +#Vastai Parameters +vastaiResourceName: "vastaitech.com/va" + schedulerName: "hami-scheduler" podSecurityPolicy: @@ -440,6 +443,11 @@ devices: enabled: true customresources: - mthreads.com/vgpu + vastai: + enabled: true + customresources: + - vastaitech.com/va + - vastaitech.com/va-die nvidia: gpuCorePolicy: default libCudaLogLevel: 1 diff --git a/pkg/device/vastai/device.go b/pkg/device/vastai/device.go new file mode 100644 index 000000000..82dca309b --- /dev/null +++ b/pkg/device/vastai/device.go @@ -0,0 +1,287 @@ +/* +Copyright 2026 The HAMi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package vastai + +import ( + "errors" + "fmt" + "strings" + + "github.com/Project-HAMi/HAMi/pkg/device" + "github.com/Project-HAMi/HAMi/pkg/device/common" + "github.com/Project-HAMi/HAMi/pkg/util" + "github.com/Project-HAMi/HAMi/pkg/util/nodelock" + + corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" +) + +type VastaiDevices struct { +} + +const ( + HandshakeAnnos = "hami.io/node-handshake-va" + RegisterAnnos = "hami.io/node-va-register" + VastaiDevice = "Vastai" + VastaiCommonWord = "Vastai" + VastaiInUse = "vastaitech.com/use-va" + VastaiNoUse = "vastaitech.com/nouse-va" + VastaiUseUUID = "vastaitech.com/use-gpuuuid" + VastaiNoUseUUID = "vastaitech.com/nouse-gpuuuid" +) + +var ( + VastaiResourceCount string +) + +type VastaiConfig struct { + ResourceCountName string `yaml:"resourceCountName"` +} + +func InitVastaiDevice(config VastaiConfig) *VastaiDevices { + VastaiResourceCount = config.ResourceCountName + commonWord := VastaiCommonWord + _, ok := device.InRequestDevices[commonWord] + if !ok { + device.InRequestDevices[commonWord] = fmt.Sprintf("hami.io/%s-devices-to-allocate", commonWord) + device.SupportDevices[commonWord] = fmt.Sprintf("hami.io/%s-devices-allocated", commonWord) + util.HandshakeAnnos[commonWord] = HandshakeAnnos + } + return &VastaiDevices{} +} + +func (dev *VastaiDevices) CommonWord() string { + return VastaiCommonWord +} + +func (dev *VastaiDevices) GetNodeDevices(n corev1.Node) ([]*device.DeviceInfo, error) { + devEncoded, ok := n.Annotations[RegisterAnnos] + if !ok { + return []*device.DeviceInfo{}, errors.New("annos not found " + RegisterAnnos) + } + nodedevices, err := device.UnMarshalNodeDevices(devEncoded) + if err != nil { + klog.ErrorS(err, "failed to decode node devices", "node", n.Name, "device annotation", devEncoded) + return []*device.DeviceInfo{}, err + } + klog.V(5).InfoS("nodes device information", "node", n.Name, "nodedevices", devEncoded) + for idx := range nodedevices { + nodedevices[idx].DeviceVendor = 
VastaiCommonWord + } + if len(nodedevices) == 0 { + klog.InfoS("no gpu device found", "node", n.Name, "device annotation", devEncoded) + return []*device.DeviceInfo{}, errors.New("no gpu found on node") + } + return nodedevices, nil +} + +func (dev *VastaiDevices) MutateAdmission(ctr *corev1.Container, p *corev1.Pod) (bool, error) { + _, ok := ctr.Resources.Limits[corev1.ResourceName(VastaiResourceCount)] + return ok, nil +} + +func (dev *VastaiDevices) LockNode(n *corev1.Node, p *corev1.Pod) error { + found := false + for _, val := range p.Spec.Containers { + if (dev.GenerateResourceRequests(&val).Nums) > 0 { + found = true + break + } + } + if !found { + return nil + } + return nodelock.LockNode(n.Name, nodelock.NodeLockKey, p) +} + +func (dev *VastaiDevices) ReleaseNodeLock(n *corev1.Node, p *corev1.Pod) error { + found := false + for _, val := range p.Spec.Containers { + if (dev.GenerateResourceRequests(&val).Nums) > 0 { + found = true + break + } + } + if !found { + return nil + } + return nodelock.ReleaseNodeLock(n.Name, nodelock.NodeLockKey, p, false) +} + +func (dev *VastaiDevices) NodeCleanUp(nn string) error { + return util.MarkAnnotationsToDelete(HandshakeAnnos, nn) +} + +func (dev *VastaiDevices) checkType(annos map[string]string, d device.DeviceUsage, n device.ContainerDeviceRequest) (bool, bool, bool) { + if strings.Compare(n.Type, VastaiDevice) == 0 { + return true, true, false + } + return false, false, false +} + +func (dev *VastaiDevices) CheckHealth(devType string, n *corev1.Node) (bool, bool) { + return device.CheckHealth(devType, n) +} + +func (dev *VastaiDevices) GenerateResourceRequests(ctr *corev1.Container) device.ContainerDeviceRequest { + klog.V(5).Info("Start to count vastai devices for container ", ctr.Name) + vastaiResourceCount := corev1.ResourceName(VastaiResourceCount) + v, ok := ctr.Resources.Limits[vastaiResourceCount] + if !ok { + v, ok = ctr.Resources.Requests[vastaiResourceCount] + } + if ok { + if n, ok := v.AsInt64(); ok { + 
klog.Info("Found vastai devices") + memnum := 0 + corenum := int32(0) + mempnum := 100 + + return device.ContainerDeviceRequest{ + Nums: int32(n), + Type: VastaiDevice, + Memreq: int32(memnum), + MemPercentagereq: int32(mempnum), + Coresreq: corenum, + } + } + } + return device.ContainerDeviceRequest{} +} + +func (dev *VastaiDevices) PatchAnnotations(pod *corev1.Pod, annoinput *map[string]string, pd device.PodDevices) map[string]string { + devlist, ok := pd[VastaiDevice] + if ok && len(devlist) > 0 { + deviceStr := device.EncodePodSingleDevice(devlist) + (*annoinput)[device.InRequestDevices[VastaiDevice]] = deviceStr + (*annoinput)[device.SupportDevices[VastaiDevice]] = deviceStr + klog.V(5).Infof("pod add notation key [%s], values is [%s]", device.InRequestDevices[VastaiDevice], deviceStr) + klog.V(5).Infof("pod add notation key [%s], values is [%s]", device.SupportDevices[VastaiDevice], deviceStr) + } + return *annoinput +} + +func (dev *VastaiDevices) ScoreNode(node *corev1.Node, podDevices device.PodSingleDevice, previous []*device.DeviceUsage, policy string) float32 { + return 0 +} + +func (dev *VastaiDevices) AddResourceUsage(pod *corev1.Pod, n *device.DeviceUsage, ctr *device.ContainerDevice) error { + n.Used++ + n.Usedcores += ctr.Usedcores + n.Usedmem += ctr.Usedmem + return nil +} + +func (va *VastaiDevices) Fit(devices []*device.DeviceUsage, request device.ContainerDeviceRequest, pod *corev1.Pod, nodeInfo *device.NodeInfo, allocated *device.PodDevices) (bool, map[string]device.ContainerDevices, string) { + k := request + originReq := k.Nums + prevnuma := -1 + klog.InfoS("Allocating device for container request", "pod", klog.KObj(pod), "card request", k) + tmpDevs := make(map[string]device.ContainerDevices) + reason := make(map[string]int) + for i := range len(devices) { + dev := devices[i] + klog.V(4).InfoS("scoring pod", "pod", klog.KObj(pod), "device", dev.ID, "Memreq", k.Memreq, "MemPercentagereq", k.MemPercentagereq, "Coresreq", k.Coresreq, "Nums", 
k.Nums, "device index", i) + + _, found, numa := va.checkType(pod.GetAnnotations(), *dev, k) + if !found { + reason[common.CardTypeMismatch]++ + klog.V(5).InfoS(common.CardTypeMismatch, "pod", klog.KObj(pod), "device", dev.ID, dev.Type, k.Type) + continue + } + if numa && prevnuma != dev.Numa { + if k.Nums != originReq { + reason[common.NumaNotFit] += len(tmpDevs) + klog.V(5).InfoS(common.NumaNotFit, "pod", klog.KObj(pod), "device", dev.ID, "k.nums", k.Nums, "numa", numa, "prevnuma", prevnuma, "device numa", dev.Numa) + } + k.Nums = originReq + prevnuma = dev.Numa + tmpDevs = make(map[string]device.ContainerDevices) + } + if !device.CheckUUID(pod.GetAnnotations(), dev.ID, VastaiUseUUID, VastaiNoUseUUID, VastaiCommonWord) { + reason[common.CardUUIDMismatch]++ + klog.V(5).InfoS(common.CardUUIDMismatch, "pod", klog.KObj(pod), "device", dev.ID, "current device info is:", *dev) + continue + } + + memreq := int32(0) + if dev.Count <= dev.Used { + reason[common.CardTimeSlicingExhausted]++ + klog.V(5).InfoS(common.CardTimeSlicingExhausted, "pod", klog.KObj(pod), "device", dev.ID, "count", dev.Count, "used", dev.Used) + continue + } + if k.Coresreq > 100 { + klog.ErrorS(nil, "core limit can't exceed 100", "pod", klog.KObj(pod), "device", dev.ID) + k.Coresreq = 100 + } + if k.Memreq > 0 { + memreq = k.Memreq + } + if k.MemPercentagereq != 101 && k.Memreq == 0 { + memreq = dev.Totalmem * k.MemPercentagereq / 100 + } + if dev.Totalmem-dev.Usedmem < memreq { + reason[common.CardInsufficientMemory]++ + klog.V(5).InfoS(common.CardInsufficientMemory, "pod", klog.KObj(pod), "device", dev.ID, "device index", i, "device total memory", dev.Totalmem, "device used memory", dev.Usedmem, "request memory", memreq) + continue + } + if dev.Totalcore-dev.Usedcores < k.Coresreq { + reason[common.CardInsufficientCore]++ + klog.V(5).InfoS(common.CardInsufficientCore, "pod", klog.KObj(pod), "device", dev.ID, "device index", i, "device total core", dev.Totalcore, "device used core", dev.Usedcores, 
"request cores", k.Coresreq) + continue + } + // Coresreq=100 indicates it want this card exclusively + if dev.Totalcore == 100 && k.Coresreq == 100 && dev.Used > 0 { + reason[common.ExclusiveDeviceAllocateConflict]++ + klog.V(5).InfoS(common.ExclusiveDeviceAllocateConflict, "pod", klog.KObj(pod), "device", dev.ID, "device index", i, "used", dev.Used) + continue + } + // You can't allocate core=0 job to an already full GPU + if dev.Totalcore != 0 && dev.Usedcores == dev.Totalcore && k.Coresreq == 0 { + reason[common.CardComputeUnitsExhausted]++ + klog.V(5).InfoS(common.CardComputeUnitsExhausted, "pod", klog.KObj(pod), "device", dev.ID, "device index", i) + continue + } + if k.Nums > 0 { + klog.V(5).InfoS("find fit device", "pod", klog.KObj(pod), "device", dev.ID) + k.Nums-- + tmpDevs[k.Type] = append(tmpDevs[k.Type], device.ContainerDevice{ + Idx: int(dev.Index), + UUID: dev.ID, + Type: k.Type, + Usedmem: memreq, + Usedcores: k.Coresreq, + }) + } + if k.Nums == 0 { + klog.V(4).InfoS("device allocate success", "pod", klog.KObj(pod), "allocate device", tmpDevs) + return true, tmpDevs, "" + } + + } + if len(tmpDevs) > 0 { + reason[common.AllocatedCardsInsufficientRequest] = len(tmpDevs) + klog.V(5).InfoS(common.AllocatedCardsInsufficientRequest, "pod", klog.KObj(pod), "request", originReq, "allocated", len(tmpDevs)) + } + return false, tmpDevs, common.GenReason(reason, len(devices)) +} + +func (dev *VastaiDevices) GetResourceNames() device.ResourceNames { + return device.ResourceNames{ + ResourceCountName: VastaiResourceCount, + } +} diff --git a/pkg/device/vastai/device_test.go b/pkg/device/vastai/device_test.go new file mode 100644 index 000000000..41ca67176 --- /dev/null +++ b/pkg/device/vastai/device_test.go @@ -0,0 +1,840 @@ +/* +Copyright 2026 The HAMi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vastai + +import ( + "context" + "errors" + "fmt" + "testing" + + "gotest.tools/v3/assert" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/klog/v2" + + "github.com/Project-HAMi/HAMi/pkg/device" + "github.com/Project-HAMi/HAMi/pkg/util" + "github.com/Project-HAMi/HAMi/pkg/util/client" + "github.com/Project-HAMi/HAMi/pkg/util/nodelock" +) + +func Test_MutateAdmission(t *testing.T) { + config := VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + } + InitVastaiDevice(config) + tests := []struct { + name string + args struct { + ctr *corev1.Container + p *corev1.Pod + } + want bool + err error + }{ + { + name: "set to resource limits", + args: struct { + ctr *corev1.Container + p *corev1.Pod + }{ + ctr: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + p: &corev1.Pod{}, + }, + want: true, + err: nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result, err := dev.MutateAdmission(test.args.ctr, test.args.p) + if err != test.err { + klog.InfoS("set to resource limits failed") + } + assert.Equal(t, result, test.want) + }) + } +} + +func Test_GetNodeDevices(t *testing.T) { + dev := VastaiDevices{} + tests := []struct { + name string + args corev1.Node + want []*device.DeviceInfo + err error + }{ + { + name: "no annotation", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ 
+ Annotations: map[string]string{}, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("annos not found " + RegisterAnnos), + }, + { + name: "exist vastai device", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": "[{\"id\":\"7-0-batch-0\",\"count\":1,\"type\":\"Vastai\",\"health\":true,\"devicepairscore\":{}}]", + }, + }, + }, + want: []*device.DeviceInfo{ + { + ID: "7-0-batch-0", + Count: int32(1), + Devmem: int32(0), + Devcore: int32(0), + Type: dev.CommonWord(), + Numa: 0, + Health: true, + Index: uint(0), + Mode: "", + DeviceVendor: VastaiCommonWord, + }, + }, + err: nil, + }, + { + name: "no vasta device", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": ":", + }, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("no gpu found on node"), + }, + { + name: "node annotations not decode successfully", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": "", + }, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("node annotations not decode successfully"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, err := dev.GetNodeDevices(test.args) + if err != nil { + klog.Errorf("got %v, want %v", err, test.err) + } + assert.DeepEqual(t, result, test.want) + }) + } +} + +func Test_CheckHealth(t *testing.T) { + tests := []struct { + name string + args struct { + devType string + n *corev1.Node + } + want1 bool + want2 bool + }{ + { + name: "Requesting state expired", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Requesting_2025-01-07 00:00:00", + }, + }, + }, + }, + want1: false, + want2: false, + }, + { + name: "Deleted 
state", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Deleted", + }, + }, + }, + }, + want1: true, + want2: false, + }, + { + name: "Unknown state", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Unknown", + }, + }, + }, + }, + want1: true, + want2: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result1, result2 := dev.CheckHealth(test.args.devType, test.args.n) + assert.Equal(t, result1, test.want1) + assert.Equal(t, result2, test.want2) + }) + } +} + +func Test_checkType(t *testing.T) { + tests := []struct { + name string + args struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + } + want1 bool + want2 bool + want3 bool + }{ + { + name: "the same type", + args: struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + }{ + annos: map[string]string{}, + d: device.DeviceUsage{ + Type: "Vastai", + }, + n: device.ContainerDeviceRequest{ + Type: "Vastai", + }, + }, + want1: true, + want2: true, + want3: false, + }, + { + name: "the different type", + args: struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + }{ + annos: map[string]string{}, + d: device.DeviceUsage{ + Type: "Vastai", + }, + n: device.ContainerDeviceRequest{ + Type: "test", + }, + }, + want1: false, + want2: false, + want3: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result1, result2, result3 := dev.checkType(test.args.annos, test.args.d, test.args.n) + assert.Equal(t, result1, test.want1) + 
assert.Equal(t, result2, test.want2) + assert.Equal(t, result3, test.want3) + }) + } +} + +func Test_PatchAnnotations(t *testing.T) { + tests := []struct { + name string + args struct { + annoinput *map[string]string + pd device.PodDevices + } + want map[string]string + }{ + { + name: "exist device", + args: struct { + annoinput *map[string]string + pd device.PodDevices + }{ + annoinput: &map[string]string{}, + pd: device.PodDevices{ + VastaiDevice: device.PodSingleDevice{ + []device.ContainerDevice{ + { + Idx: 1, + UUID: "test1", + Type: VastaiDevice, + Usedmem: int32(2048), + Usedcores: int32(1), + }, + }, + }, + }, + }, + want: map[string]string{ + device.InRequestDevices[VastaiDevice]: "test1,Vastai,2048,1:;", + device.SupportDevices[VastaiDevice]: "test1,Vastai,2048,1:;", + }, + }, + { + name: "no device", + args: struct { + annoinput *map[string]string + pd device.PodDevices + }{ + annoinput: &map[string]string{}, + pd: device.PodDevices{}, + }, + want: map[string]string{}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result := dev.PatchAnnotations(&corev1.Pod{}, test.args.annoinput, test.args.pd) + assert.DeepEqual(t, result, test.want) + }) + } +} + +func Test_GenerateResourceRequests(t *testing.T) { + tests := []struct { + name string + args *corev1.Container + want device.ContainerDeviceRequest + }{ + { + name: "don't set to limits and request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{}, + Requests: corev1.ResourceList{}, + }, + }, + want: device.ContainerDeviceRequest{}, + }, + { + name: "set to limits and request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: 
VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + { + name: "only set to limits", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + { + name: "only set to request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result := dev.GenerateResourceRequests(test.args) + assert.DeepEqual(t, result, test.want) + }) + } +} + +func TestDevices_LockNode(t *testing.T) { + tests := []struct { + name string + node *corev1.Node + pod *corev1.Pod + hasLock bool + expectError bool + }{ + { + name: "Test with no containers", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{}}, + hasLock: false, + expectError: false, + }, + { + name: "Test with non-zero resource requests", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }}}}}}, + hasLock: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize fake clientset and pre-load test data + client.KubeClient = fake.NewSimpleClientset() + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testNode", + Annotations: map[string]string{"test-annotation-key": "test-annotation-value", 
device.InRequestDevices["DCU"]: "some-value"}, + }, + } + + // Add the node to the fake clientset + _, err := client.KubeClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test node: %v", err) + } + + dev := InitVastaiDevice(VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + }) + err = dev.LockNode(node, tt.pod) + if tt.expectError { + assert.Equal(t, err != nil, true) + } else { + assert.NilError(t, err) + } + node, err = client.KubeClient.CoreV1().Nodes().Get(context.Background(), node.Name, metav1.GetOptions{}) + assert.NilError(t, err) + fmt.Println(node.Annotations) + _, ok := node.Annotations[nodelock.NodeLockKey] + assert.Equal(t, ok, tt.hasLock) + }) + } +} + +func TestDevices_ReleaseNodeLock(t *testing.T) { + tests := []struct { + name string + node *corev1.Node + pod *corev1.Pod + hasLock bool + expectError bool + }{ + { + name: "Test with no containers", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{}}, + hasLock: true, + expectError: false, + }, + { + name: "Test with non-zero resource requests", + node: &corev1.Node{}, + pod: &corev1.Pod{ObjectMeta: metav1.ObjectMeta{ + Name: "nozerorr", + Namespace: "default", + }, Spec: corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }}}}}}, + hasLock: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize fake clientset and pre-load test data + client.KubeClient = fake.NewSimpleClientset() + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testNode", + Annotations: map[string]string{"test-annotation-key": "test-annotation-value", device.InRequestDevices[VastaiDevice]: "some-value", nodelock.NodeLockKey: "lock-values,default,nozerorr"}, + }, + } + + // Add the node to the fake clientset + _, err := 
client.KubeClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test node: %v", err) + } + + dev := InitVastaiDevice(VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + }) + err = dev.ReleaseNodeLock(node, tt.pod) + if tt.expectError { + assert.Equal(t, err != nil, true) + } else { + assert.NilError(t, err) + } + node, err = client.KubeClient.CoreV1().Nodes().Get(context.Background(), node.Name, metav1.GetOptions{}) + assert.NilError(t, err) + fmt.Println(node.Annotations) + _, ok := node.Annotations[nodelock.NodeLockKey] + assert.Equal(t, ok, tt.hasLock) + }) + } +} + +func TestDevices_Fit(t *testing.T) { + config := VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + } + dev := InitVastaiDevice(config) + + tests := []struct { + name string + devices []*device.DeviceUsage + request device.ContainerDeviceRequest + annos map[string]string + wantFit bool + wantLen int + wantDevIDs []string + wantReason string + }{ + { + name: "fit success", + devices: []*device.DeviceUsage{ + { + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }, + { + ID: "dev-1", + Index: 1, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }, + }, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: true, + wantLen: 1, + wantDevIDs: []string{"dev-0"}, + wantReason: "", + }, + { + name: "fit fail: type mismatch", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Health: true, + Type: VastaiDevice, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Type: "OtherType", + Memreq: 512, + 
MemPercentagereq: 0, + Coresreq: 50, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardTypeMismatch", + }, + { + name: "fit fail: user assign use uuid mismatch", + devices: []*device.DeviceUsage{{ + ID: "dev-1", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{VastaiUseUUID: "dev-0"}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardUuidMismatch", + }, + { + name: "fit fail: user assign no use uuid match", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{VastaiNoUseUUID: "dev-0"}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardUuidMismatch", + }, + { + name: "fit fail: card overused", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 1, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardTimeSlicingExhausted", + }, + { + name: "fit fail: AllocatedCardsInsufficientRequest", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: 
VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 2, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 AllocatedCardsInsufficientRequest", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + allocated := &device.PodDevices{} + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: test.annos, + }, + } + fit, result, reason := dev.Fit(test.devices, test.request, pod, &device.NodeInfo{}, allocated) + if fit != test.wantFit { + t.Errorf("Fit: got %v, want %v", fit, test.wantFit) + } + if test.wantFit { + if len(result[VastaiDevice]) != test.wantLen { + t.Errorf("expected len: %d, got len %d", test.wantLen, len(result[VastaiDevice])) + } + for idx, id := range test.wantDevIDs { + if id != result[VastaiDevice][idx].UUID { + t.Errorf("expected device id: %s, got device id %s", id, result[VastaiDevice][idx].UUID) + } + } + } + + if reason != test.wantReason { + t.Errorf("expected reason: %s, got reason: %s", test.wantReason, reason) + } + }) + } +} + +func TestDevices_AddResourceUsage(t *testing.T) { + tests := []struct { + name string + deviceUsage *device.DeviceUsage + ctr *device.ContainerDevice + wantErr bool + wantUsage *device.DeviceUsage + }{ + { + name: "test add resource usage", + deviceUsage: &device.DeviceUsage{ + ID: "dev-0", + Used: 1, + Usedcores: 0, + Usedmem: 0, + }, + ctr: &device.ContainerDevice{ + UUID: "dev-0", + Usedcores: 0, + Usedmem: 0, + }, + wantUsage: &device.DeviceUsage{ + ID: "dev-0", + Used: 2, + Usedcores: 0, + Usedmem: 0, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dev := &VastaiDevices{} + if err := dev.AddResourceUsage(&corev1.Pod{}, tt.deviceUsage, tt.ctr); (err != nil) != tt.wantErr { + t.Errorf("AddResourceUsage() error=%v, wantErr %v", err, tt.wantErr) + } + if 
!tt.wantErr { + if tt.deviceUsage.Usedcores != tt.wantUsage.Usedcores { + t.Errorf("expected used cores: %d, got used cores %d", tt.wantUsage.Usedcores, tt.deviceUsage.Usedcores) + } + if tt.deviceUsage.Usedmem != tt.wantUsage.Usedmem { + t.Errorf("expected used mem: %d, got used mem %d", tt.wantUsage.Usedmem, tt.deviceUsage.Usedmem) + } + if tt.deviceUsage.Used != tt.wantUsage.Used { + t.Errorf("expected used: %d, got used %d", tt.wantUsage.Used, tt.deviceUsage.Used) + } + } + }) + } +} diff --git a/pkg/scheduler/config/config.go b/pkg/scheduler/config/config.go index a012eabf0..f131c2325 100644 --- a/pkg/scheduler/config/config.go +++ b/pkg/scheduler/config/config.go @@ -38,6 +38,7 @@ import ( "github.com/Project-HAMi/HAMi/pkg/device/metax" "github.com/Project-HAMi/HAMi/pkg/device/mthreads" "github.com/Project-HAMi/HAMi/pkg/device/nvidia" + "github.com/Project-HAMi/HAMi/pkg/device/vastai" "github.com/Project-HAMi/HAMi/pkg/util" ) @@ -82,6 +83,7 @@ type Config struct { KunlunConfig kunlun.KunlunConfig `yaml:"kunlun"` AWSNeuronConfig awsneuron.AWSNeuronConfig `yaml:"awsneuron"` AMDGPUConfig amd.AMDConfig `yaml:"amd"` + VastaiConfig vastai.VastaiConfig `yaml:"vastai"` VNPUs []ascend.VNPUConfig `yaml:"vnpus"` } @@ -209,6 +211,13 @@ func InitDevicesWithConfig(config *Config) error { } return amd.InitAMDGPUDevice(amdGPUConfig), nil }, config.AMDGPUConfig}, + {vastai.VastaiDevice, vastai.VastaiCommonWord, func(cfg any) (device.Devices, error) { + vastaiConfig, ok := cfg.(vastai.VastaiConfig) + if !ok { + return nil, fmt.Errorf("invalid configuration for %s", vastai.VastaiCommonWord) + } + return vastai.InitVastaiDevice(vastaiConfig), nil + }, config.VastaiConfig}, } // Initialize all devices using the wrapped functions diff --git a/pkg/scheduler/config/config_test.go b/pkg/scheduler/config/config_test.go index 997b0f601..10857a873 100644 --- a/pkg/scheduler/config/config_test.go +++ b/pkg/scheduler/config/config_test.go @@ -39,6 +39,7 @@ import ( 
"github.com/Project-HAMi/HAMi/pkg/device/metax" "github.com/Project-HAMi/HAMi/pkg/device/mthreads" "github.com/Project-HAMi/HAMi/pkg/device/nvidia" + "github.com/Project-HAMi/HAMi/pkg/device/vastai" ) func loadTestConfig() string { @@ -426,6 +427,7 @@ func setupTest(t *testing.T) (map[string]string, map[string]device.Devices) { kunlun.XPUDevice: kunlun.XPUCommonWord, awsneuron.AWSNeuronDevice: awsneuron.AWSNeuronCommonWord, amd.AMDDevice: amd.AMDDevice, + vastai.VastaiDevice: vastai.VastaiCommonWord, } return expectedDevices, device.DevicesMap