From 40aabc8f7d53371f7518825919f5f4366a1a99a9 Mon Sep 17 00:00:00 2001 From: Vadym Fedorov Date: Tue, 14 Oct 2025 09:51:34 -0500 Subject: [PATCH] pass complete info instead of truncating Signed-off-by: Vadym Fedorov --- pkg/dcgm/diag.go | 9 +- pkg/dcgm/diag_test.go | 212 ++++++++++++++++++++++++++++++++++ pkg/dcgm/diag_test_helpers.go | 91 +++++++++++++++ 3 files changed, 307 insertions(+), 5 deletions(-) create mode 100644 pkg/dcgm/diag_test.go create mode 100644 pkg/dcgm/diag_test_helpers.go diff --git a/pkg/dcgm/diag.go b/pkg/dcgm/diag.go index f075b46..3e00e66 100644 --- a/pkg/dcgm/diag.go +++ b/pkg/dcgm/diag.go @@ -7,6 +7,7 @@ package dcgm import "C" import ( + "strings" "unsafe" ) @@ -122,16 +123,14 @@ func getErrorMsg(entityId uint, testId uint, response C.dcgmDiagResponse_v12) (m } func getInfoMsg(entityId uint, testId uint, response C.dcgmDiagResponse_v12) string { + var msgs []string for i := 0; i < int(response.numInfo); i++ { if uint(response.info[i].entity.entityId) != entityId || uint(response.info[i].testId) != testId { continue } - - msg := C.GoString((*C.char)(unsafe.Pointer(&response.info[i].msg))) - return msg + msgs = append(msgs, C.GoString((*C.char)(unsafe.Pointer(&response.info[i].msg)))) } - - return "" + return strings.Join(msgs, " | ") } func getTestName(resultIdx uint, response C.dcgmDiagResponse_v12) string { diff --git a/pkg/dcgm/diag_test.go b/pkg/dcgm/diag_test.go new file mode 100644 index 0000000..af4fb67 --- /dev/null +++ b/pkg/dcgm/diag_test.go @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dcgm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetInfoMsg_NoMessages verifies getInfoMsg returns empty string when no info messages exist +func TestGetInfoMsg_NoMessages(t *testing.T) { + response := createTestDiagResponse() + + result := getInfoMsg(0, 0, response) + + assert.Empty(t, result, "expected empty string when no info messages exist") +} + +// TestGetInfoMsg_SingleMessage verifies getInfoMsg returns the single message without separator +func TestGetInfoMsg_SingleMessage(t *testing.T) { + response := createTestDiagResponse() + + expectedMsg := "Allocated 83618558100 bytes (98.4%)" + addInfoMessage(&response, 0, testMemoryIndex, expectedMsg) + + result := getInfoMsg(0, testMemoryIndex, response) + + assert.Equal(t, expectedMsg, result, "expected single message to be returned as-is") +} + +// TestGetInfoMsg_MultipleMessages verifies all matching info messages are concatenated +func TestGetInfoMsg_MultipleMessages(t *testing.T) { + response := createTestDiagResponse() + + entityID := uint(0) + testID := uint(testPCIIndex) + + messages := []string{ + "GPU to Host bandwidth: 28.27 GB/s", + "Host to GPU bandwidth: 27.65 GB/s", + "bidirectional bandwidth: 50.59 GB/s", + "GPU to Host latency: 1.305 us", + "Host to GPU latency: 2.097 us", + "bidirectional latency: 2.666 us", + } + + for _, msg := range messages { + addInfoMessage(&response, entityID, testID, msg) + } + + result := getInfoMsg(entityID, testID, response) + + expected := "GPU to Host bandwidth: 
28.27 GB/s | Host to GPU bandwidth: 27.65 GB/s | bidirectional bandwidth: 50.59 GB/s | GPU to Host latency: 1.305 us | Host to GPU latency: 2.097 us | bidirectional latency: 2.666 us" + assert.Equal(t, expected, result, "expected all messages to be concatenated with ' | ' separator") +} + +// TestGetInfoMsg_FiltersByEntityID verifies only messages matching entityId are returned +func TestGetInfoMsg_FiltersByEntityID(t *testing.T) { + response := createTestDiagResponse() + + targetEntityID := uint(0) + testID := uint(testMemoryIndex) + + // Add messages for different entities + addInfoMessage(&response, targetEntityID, testID, "Message for entity 0") + addInfoMessage(&response, 1, testID, "Message for entity 1") + addInfoMessage(&response, targetEntityID, testID, "Another message for entity 0") + + result := getInfoMsg(targetEntityID, testID, response) + + expected := "Message for entity 0 | Another message for entity 0" + assert.Equal(t, expected, result, "expected only messages matching entityId to be included") + assert.NotContains(t, result, "entity 1", "should not contain messages from different entity") +} + +// TestGetInfoMsg_FiltersByTestID verifies only messages matching testId are returned +func TestGetInfoMsg_FiltersByTestID(t *testing.T) { + response := createTestDiagResponse() + + entityID := uint(0) + targetTestID := uint(testMemoryIndex) + + // Add messages for different test IDs + addInfoMessage(&response, entityID, targetTestID, "Memory test message 1") + addInfoMessage(&response, entityID, testPCIIndex, "PCIe test message") + addInfoMessage(&response, entityID, targetTestID, "Memory test message 2") + + result := getInfoMsg(entityID, targetTestID, response) + + expected := "Memory test message 1 | Memory test message 2" + assert.Equal(t, expected, result, "expected only messages matching testId to be included") + assert.NotContains(t, result, "PCIe", "should not contain messages from different test") +} + +// TestGetInfoMsg_NoMatchingMessages 
verifies empty string when no messages match filters +func TestGetInfoMsg_NoMatchingMessages(t *testing.T) { + response := createTestDiagResponse() + + // Add messages that don't match the query + addInfoMessage(&response, 0, testMemoryIndex, "Some message") + addInfoMessage(&response, 1, testPCIIndex, "Another message") + + // Query with different entityId and testId + result := getInfoMsg(99, 99, response) + + assert.Empty(t, result, "expected empty string when no messages match the filters") +} + +// TestDiagResultString verifies diagResultString conversion +func TestDiagResultString(t *testing.T) { + tests := []struct { + name string + input int + expected string + }{ + {"pass", testDiagResultPass, "pass"}, + {"skip", testDiagResultSkip, "skipped"}, + {"warn", testDiagResultWarn, "warn"}, + {"fail", testDiagResultFail, "fail"}, + {"not run", testDiagResultNotRun, "notrun"}, + {"unknown", 999, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diagResultString(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestGpuTestName verifies gpuTestName conversion +func TestGpuTestName(t *testing.T) { + tests := []struct { + name string + input int + expected string + }{ + {"memory", testMemoryIndex, "memory"}, + {"diagnostic", testDiagnosticIndex, "diagnostic"}, + {"pcie", testPCIIndex, "pcie"}, + {"sm stress", testSMStressIndex, "sm stress"}, + {"targeted stress", testTargetedStressIndex, "targeted stress"}, + {"targeted power", testTargetedPowerIndex, "targeted power"}, + {"memory bandwidth", testMemoryBandwidthIndex, "memory bandwidth"}, + {"memtest", testMemtestIndex, "memtest"}, + {"pulse", testPulseTestIndex, "pulse"}, + {"eud", testEUDTestIndex, "eud"}, + {"software", testSoftwareIndex, "software"}, + {"context create", testContextCreateIndex, "context create"}, + {"unknown", 999, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := gpuTestName(tt.input) + 
assert.Equal(t, tt.expected, result) + }) + } +} + +// TestNewDiagResult verifies DiagResult construction with multiple info messages +func TestNewDiagResult(t *testing.T) { + response := createTestDiagResponse() + + entityID := uint(0) + testID := uint(testPCIIndex) + serialNumber := "1652923033635" + + // Setup result + addDiagResult(&response, entityID, testID, testDiagResultPass) + + // Setup multiple info messages + messages := []string{ + "GPU to Host bandwidth: 28.27 GB/s", + "Host to GPU bandwidth: 27.65 GB/s", + "bidirectional bandwidth: 50.59 GB/s", + } + for _, msg := range messages { + addInfoMessage(&response, entityID, testID, msg) + } + + // Setup entity with serial number + addEntityWithSerial(&response, entityID, serialNumber) + + result := newDiagResult(0, response) + + require.NotNil(t, result) + assert.Equal(t, "pass", result.Status) + assert.Equal(t, "pcie", result.TestName) + assert.Equal(t, "GPU to Host bandwidth: 28.27 GB/s | Host to GPU bandwidth: 27.65 GB/s | bidirectional bandwidth: 50.59 GB/s", result.TestOutput) + assert.Equal(t, uint(0), result.ErrorCode) + assert.Empty(t, result.ErrorMessage) + assert.Equal(t, serialNumber, result.SerialNumber) + assert.Equal(t, entityID, result.EntityID) +} diff --git a/pkg/dcgm/diag_test_helpers.go b/pkg/dcgm/diag_test_helpers.go new file mode 100644 index 0000000..4f4b668 --- /dev/null +++ b/pkg/dcgm/diag_test_helpers.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dcgm + +/* +#include <stdlib.h> +#include <string.h> +#include "dcgm_agent.h" +#include "dcgm_structs.h" +*/ +import "C" + +import ( + "unsafe" +) + +// createTestDiagResponse creates a dcgmDiagResponse_v12 for testing +func createTestDiagResponse() C.dcgmDiagResponse_v12 { + var response C.dcgmDiagResponse_v12 + response.version = C.dcgmDiagResponse_version12 + return response +} + +// addInfoMessage adds an info message to a dcgmDiagResponse_v12 for testing +func addInfoMessage(response *C.dcgmDiagResponse_v12, entityID uint, testID uint, message string) { + idx := response.numInfo + cStr := C.CString(message) + defer C.free(unsafe.Pointer(cStr)) + C.strcpy(&response.info[idx].msg[0], cStr) + response.info[idx].entity.entityId = C.uint(entityID) + response.info[idx].entity.entityGroupId = C.DCGM_FE_GPU + response.info[idx].testId = C.uint(testID) + response.numInfo++ +} + +// addDiagResult adds a diagnostic result to a dcgmDiagResponse_v12 for testing +func addDiagResult(response *C.dcgmDiagResponse_v12, entityID uint, testID uint, result int) { + idx := response.numResults + response.results[idx].entity.entityId = C.uint(entityID) + response.results[idx].entity.entityGroupId = C.DCGM_FE_GPU + response.results[idx].testId = C.uint(testID) + response.results[idx].result = C.dcgmDiagResult_t(result) + response.numResults++ +} + +// addEntityWithSerial adds an entity with serial number to a dcgmDiagResponse_v12 for testing +func addEntityWithSerial(response *C.dcgmDiagResponse_v12, entityID uint, serialNumber string) { + idx := response.numEntities + cStr := C.CString(serialNumber) + defer C.free(unsafe.Pointer(cStr)) + C.strcpy(&response.entities[idx].serialNum[0], cStr) + response.entities[idx].entity.entityId = C.uint(entityID) + response.entities[idx].entity.entityGroupId = C.DCGM_FE_GPU + response.numEntities++ +} + +// Test constants exposed for testing +const
( + testDiagResultPass = C.DCGM_DIAG_RESULT_PASS + testDiagResultSkip = C.DCGM_DIAG_RESULT_SKIP + testDiagResultWarn = C.DCGM_DIAG_RESULT_WARN + testDiagResultFail = C.DCGM_DIAG_RESULT_FAIL + testDiagResultNotRun = C.DCGM_DIAG_RESULT_NOT_RUN + + testMemoryIndex = C.DCGM_MEMORY_INDEX + testDiagnosticIndex = C.DCGM_DIAGNOSTIC_INDEX + testPCIIndex = C.DCGM_PCI_INDEX + testSMStressIndex = C.DCGM_SM_STRESS_INDEX + testTargetedStressIndex = C.DCGM_TARGETED_STRESS_INDEX + testTargetedPowerIndex = C.DCGM_TARGETED_POWER_INDEX + testMemoryBandwidthIndex = C.DCGM_MEMORY_BANDWIDTH_INDEX + testMemtestIndex = C.DCGM_MEMTEST_INDEX + testPulseTestIndex = C.DCGM_PULSE_TEST_INDEX + testEUDTestIndex = C.DCGM_EUD_TEST_INDEX + testSoftwareIndex = C.DCGM_SOFTWARE_INDEX + testContextCreateIndex = C.DCGM_CONTEXT_CREATE_INDEX +)