Skip to content

Commit 71cf9b0

Browse files
committed
fix: unit test issues
1 parent b00fcdc commit 71cf9b0

File tree

14 files changed

+712
-234
lines changed

14 files changed

+712
-234
lines changed

internal/hypervisor/backend/kubernetes/deviceplugin.go

Lines changed: 101 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ import (
2222
"net"
2323
"os"
2424
"path/filepath"
25+
"strconv"
2526
"sync"
2627
"time"
2728

29+
"github.com/NexusGPU/tensor-fusion/internal/constants"
2830
"github.com/NexusGPU/tensor-fusion/internal/hypervisor/api"
2931
"github.com/NexusGPU/tensor-fusion/internal/hypervisor/framework"
3032
"google.golang.org/grpc"
@@ -117,6 +119,9 @@ func (dp *DevicePlugin) Start() error {
117119
return fmt.Errorf("failed to register with kubelet: %w", err)
118120
}
119121

122+
// Initialize device list with dummy index devices (1-512)
123+
dp.updateDeviceList()
124+
120125
// Start device monitoring
121126
go dp.monitorDevices()
122127

@@ -194,21 +199,20 @@ func (dp *DevicePlugin) monitorDevices() {
194199
}
195200
}
196201

197-
// updateDeviceList updates the list of available devices
202+
// updateDeviceList updates the list of available dummy index devices
203+
// This device plugin registers tensor-fusion.ai/index resource, not real GPU devices.
204+
// We advertise 512 dummy devices (indices 1-512) for pod identification.
205+
// Real GPU devices are allocated by scheduler and set in pod annotations.
198206
func (dp *DevicePlugin) updateDeviceList() {
199-
devices, err := dp.deviceController.ListDevices(dp.ctx)
200-
if err != nil {
201-
klog.Errorf("Failed to list devices: %v", err)
202-
return
203-
}
204-
205207
dp.mu.Lock()
206208
defer dp.mu.Unlock()
207209

208-
pluginDevices := make([]*pluginapi.Device, 0, len(devices))
209-
for _, device := range devices {
210+
// Advertise 512 dummy index devices (1-512) for pod identification
211+
// These are NOT real GPU devices - they're just used to match pods by index
212+
pluginDevices := make([]*pluginapi.Device, 0, 512)
213+
for i := 1; i <= 512; i++ {
210214
pluginDevices = append(pluginDevices, &pluginapi.Device{
211-
ID: device.UUID,
215+
ID: fmt.Sprintf("%d", i), // Index as device ID
212216
Health: pluginapi.Healthy,
213217
})
214218
}
@@ -259,44 +263,91 @@ func (dp *DevicePlugin) ListAndWatch(req *pluginapi.Empty, stream pluginapi.Devi
259263
}
260264

261265
// Allocate handles device allocation requests from kubelet
266+
// IMPORTANT: This device plugin registers tensor-fusion.ai/index as a dummy resource.
267+
// The pod index (1-512) is used to identify which pod is requesting allocation.
268+
// The actual GPU device UUIDs are already set by the centralized scheduler in pod annotations:
269+
// - tensor-fusion.ai/gpu-ids: comma-separated GPU UUIDs (for all isolation modes)
270+
// - tensor-fusion.ai/partition: partition template ID (only for partitioned isolation mode)
271+
//
272+
// Note that len(req.ContainerRequests) is just the number of containers in the pod requesting
273+
// tensor-fusion.ai/index resource - it's NOT the pod index. The pod index comes from
274+
// DevicesIds[0] which contains the index value from resource limits.
275+
//
276+
// We do NOT allocate the fake tensor-fusion.ai/index device - it's only used for pod identification.
277+
// CdiDevices in the response is kept empty to prevent kubelet from allocating the dummy device.
262278
func (dp *DevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
263-
klog.Infof("Allocate called with %d container requests", len(req.ContainerRequests))
279+
// len(req.ContainerRequests) identifies how many containers in the pod are requesting
280+
// tensor-fusion.ai/index resource - this is for logging/identification only
281+
klog.Infof("Allocate called with %d container requests (pod may have multiple containers)", len(req.ContainerRequests))
264282

265283
responses := make([]*pluginapi.ContainerAllocateResponse, 0, len(req.ContainerRequests))
266284

267-
for _, containerReq := range req.ContainerRequests {
268-
// Extract pod UID and namespace from environment variables or annotations
269-
// The kubelet passes these in the container request
270-
podUID := ""
271-
podName := ""
272-
namespace := ""
285+
for containerIdx, containerReq := range req.ContainerRequests {
286+
// Extract pod index from DevicesIds - this contains the index value (1-512) from resource limits
287+
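// Resource limit: tensor-fusion.ai/index: 3 -> DevicesIds: ["3"] (NOTE(review): kubelet normally treats an extended-resource limit as a device count and chooses the IDs itself - confirm this mapping)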
288+
// This is the actual pod index used to match the pod in the pod cache
289+
if len(containerReq.DevicesIds) == 0 {
290+
return nil, fmt.Errorf("container request %d has no DevicesIds (expected pod index value 1-512)", containerIdx)
291+
}
292+
293+
// DevicesIds[0] carries the pod index value (1-512) from resource limits
294+
// This is NOT the device to allocate - it's just the pod identifier
295+
podIndex := containerReq.DevicesIds[0]
296+
if podIndex == "" {
297+
return nil, fmt.Errorf("container request %d has empty DevicesIds (expected pod index)", containerIdx)
298+
}
299+
300+
// Validate index is in valid range (1-512)
301+
indexNum, err := strconv.Atoi(podIndex)
302+
if err != nil {
303+
return nil, fmt.Errorf("container request %d has invalid index format: %s (expected number 1-512)", containerIdx, podIndex)
304+
}
305+
if indexNum < 1 || indexNum > 512 {
306+
return nil, fmt.Errorf("container request %d has index out of range: %d (expected 1-512)", containerIdx, indexNum)
307+
}
308+
309+
klog.V(4).Infof("Processing allocation for container index %d, pod index %s (from DevicesIds)", containerIdx, podIndex)
273310

274-
// Get worker info from kubelet client
275-
workerInfo, err := dp.kubeletClient.GetWorkerInfoForAllocation(ctx, containerReq)
311+
// Get worker info from kubelet client using pod index
312+
workerInfo, err := dp.kubeletClient.GetWorkerInfoForAllocationByIndex(ctx, podIndex)
276313
if err != nil {
277-
klog.Errorf("Failed to get worker info: %v", err)
278-
return nil, fmt.Errorf("failed to get worker info: %w", err)
314+
klog.Errorf("Failed to get worker info for pod index %s: %v", podIndex, err)
315+
return nil, fmt.Errorf("failed to get worker info for pod index %s: %w", podIndex, err)
279316
}
280317

281318
if workerInfo == nil {
282-
return nil, fmt.Errorf("worker info not found for allocation request")
319+
return nil, fmt.Errorf("worker info not found for pod index %s", podIndex)
283320
}
284321

285-
podUID = workerInfo.PodUID
286-
podName = workerInfo.PodName
287-
namespace = workerInfo.Namespace
322+
// Check for duplicate index annotations (multiple pods with same index)
323+
if err := dp.kubeletClient.CheckDuplicateIndex(ctx, podIndex, workerInfo.PodUID); err != nil {
324+
klog.Errorf("Duplicate index detected for pod index %s: %v", podIndex, err)
325+
return nil, fmt.Errorf("duplicate index detected: %w", err)
326+
}
288327

289-
// Compose allocation request
290-
deviceUUIDs := make([]string, 0, len(containerReq.DevicesIds))
291-
deviceUUIDs = append(deviceUUIDs, containerReq.DevicesIds...)
328+
// Device UUIDs are already set by scheduler in annotations, not from DevicesIds
329+
// DevicesIds is just the dummy tensor-fusion.ai/index resource
330+
deviceUUIDs := workerInfo.DeviceUUIDs
331+
if len(deviceUUIDs) == 0 {
332+
return nil, fmt.Errorf("no device UUIDs found in pod annotations for pod %s/%s", workerInfo.Namespace, workerInfo.PodName)
333+
}
292334

335+
// Extract partition template ID if in partitioned mode
336+
templateID := workerInfo.TemplateID
337+
if workerInfo.IsolationMode == api.IsolationModePartitioned {
338+
if partitionID, exists := workerInfo.Annotations[constants.PartitionTemplateIDAnnotation]; exists {
339+
templateID = partitionID
340+
}
341+
}
342+
343+
// Compose allocation request
293344
allocReq := &api.DeviceAllocateRequest{
294-
WorkerUID: podUID,
345+
WorkerUID: workerInfo.PodUID,
295346
DeviceUUIDs: deviceUUIDs,
296347
IsolationMode: workerInfo.IsolationMode,
297348
MemoryLimitBytes: workerInfo.MemoryLimitBytes,
298349
ComputeLimitUnits: workerInfo.ComputeLimitUnits,
299-
TemplateID: workerInfo.TemplateID,
350+
TemplateID: templateID,
300351
}
301352

302353
// Call device controller to allocate
@@ -310,10 +361,13 @@ func (dp *DevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateReq
310361
}
311362

312363
// Build container response
364+
// IMPORTANT: CdiDevices must be empty to prevent dummy tensor-fusion.ai/index device
365+
// from being allocated by kubelet
313366
containerResp := &pluginapi.ContainerAllocateResponse{
314-
Envs: allocResp.EnvVars,
315-
Mounts: make([]*pluginapi.Mount, 0),
316-
Devices: make([]*pluginapi.DeviceSpec, 0),
367+
Envs: allocResp.EnvVars,
368+
Mounts: make([]*pluginapi.Mount, 0),
369+
Devices: make([]*pluginapi.DeviceSpec, 0),
370+
CdiDevices: []*pluginapi.CDIDevice{}, // Empty to prevent dummy device allocation
317371
}
318372

319373
// Add device nodes
@@ -341,22 +395,29 @@ func (dp *DevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateReq
341395

342396
// Store allocation info in kubelet client
343397
allocation := &api.DeviceAllocation{
344-
DeviceUUID: deviceUUIDs[0], // Assuming single device for now
345-
PodUID: podUID,
346-
PodName: podName,
347-
Namespace: namespace,
398+
DeviceUUID: deviceUUIDs[0], // Use first device UUID
399+
PodUID: workerInfo.PodUID,
400+
PodName: workerInfo.PodName,
401+
Namespace: workerInfo.Namespace,
348402
IsolationMode: workerInfo.IsolationMode,
349-
TemplateID: workerInfo.TemplateID,
403+
TemplateID: templateID,
350404
MemoryLimit: workerInfo.MemoryLimitBytes,
351405
ComputeLimit: workerInfo.ComputeLimitUnits,
352-
WorkerID: podUID,
406+
WorkerID: workerInfo.PodUID,
353407
AllocatedAt: time.Now(),
354408
}
355409

356-
if err := dp.kubeletClient.StoreAllocation(podUID, allocation); err != nil {
410+
if err := dp.kubeletClient.StoreAllocation(workerInfo.PodUID, allocation); err != nil {
357411
klog.Warningf("Failed to store allocation: %v", err)
358412
}
359413

414+
// Remove PodIndexAnnotation after successful allocation to release the index
415+
// This prevents the index from being matched to this pod in future allocation cycles
416+
if err := dp.kubeletClient.RemovePodIndexAnnotation(ctx, workerInfo.PodUID, workerInfo.Namespace, workerInfo.PodName); err != nil {
417+
klog.Warningf("Failed to remove pod index annotation for pod %s/%s: %v", workerInfo.Namespace, workerInfo.PodName, err)
418+
// Don't fail allocation if annotation removal fails
419+
}
420+
360421
responses = append(responses, containerResp)
361422
}
362423

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
Copyright 2024.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package kubernetes
18+
19+
import (
20+
"testing"
21+
22+
"github.com/stretchr/testify/assert"
23+
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
24+
)
25+
26+
// TestDevicePluginAllocate_ExtractsIndexFromDevicesIds tests that the device plugin
27+
// correctly extracts the pod index from DevicesIds[0], not from len(req.ContainerRequests)
28+
// This is a key test to verify the device plugin implementation matches the design:
29+
// - DevicesIds[0] contains the index value (1-512) from resource limits
30+
// - len(req.ContainerRequests) is just the number of containers, NOT the pod index
31+
// - CdiDevices must be empty to prevent dummy device allocation
32+
func TestDevicePluginAllocate_ExtractsIndexFromDevicesIds(t *testing.T) {
33+
// This test verifies the key design principle:
34+
// The pod index comes from DevicesIds[0], which contains the value from
35+
// tensor-fusion.ai/index resource limit, NOT from len(req.ContainerRequests)
36+
37+
req := &pluginapi.AllocateRequest{
38+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
39+
{
40+
DevicesIds: []string{"3"}, // Index "3" from resource limit
41+
},
42+
},
43+
}
44+
45+
// Verify the structure: len(ContainerRequests) = 1, but index is "3" from DevicesIds[0]
46+
assert.Len(t, req.ContainerRequests, 1, "Should have 1 container request")
47+
assert.Equal(t, "3", req.ContainerRequests[0].DevicesIds[0], "Index should come from DevicesIds[0], not from len(ContainerRequests)")
48+
49+
// This demonstrates that len(req.ContainerRequests) is NOT the pod index
50+
// The pod index is extracted from DevicesIds[0]
51+
assert.NotEqual(t, len(req.ContainerRequests), 3, "len(ContainerRequests) should NOT equal the pod index")
52+
}
53+
54+
// TestDevicePluginAllocate_MultipleContainers tests that len(req.ContainerRequests)
55+
// is used for iteration, not for pod index identification
56+
func TestDevicePluginAllocate_MultipleContainers(t *testing.T) {
57+
// Create request with 2 containers, both with index "5"
58+
// len(ContainerRequests) = 2, but pod index is still "5" from DevicesIds
59+
req := &pluginapi.AllocateRequest{
60+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
61+
{
62+
DevicesIds: []string{"5"}, // First container: index 5
63+
},
64+
{
65+
DevicesIds: []string{"5"}, // Second container: same pod, same index
66+
},
67+
},
68+
}
69+
70+
// Verify: len(ContainerRequests) = 2, but index is "5" from DevicesIds
71+
assert.Len(t, req.ContainerRequests, 2, "Should have 2 container requests")
72+
assert.Equal(t, "5", req.ContainerRequests[0].DevicesIds[0], "First container index from DevicesIds")
73+
assert.Equal(t, "5", req.ContainerRequests[1].DevicesIds[0], "Second container index from DevicesIds")
74+
75+
// Key verification: len(ContainerRequests) is NOT the pod index
76+
assert.NotEqual(t, len(req.ContainerRequests), 5, "len(ContainerRequests) should NOT equal the pod index")
77+
78+
// Both containers have the same index because they're in the same pod
79+
assert.Equal(t, req.ContainerRequests[0].DevicesIds[0], req.ContainerRequests[1].DevicesIds[0],
80+
"Both containers should have the same index (same pod)")
81+
}

0 commit comments

Comments
 (0)