diff --git a/pkg/dcgm/api.go b/pkg/dcgm/api.go index 7dd9acf..3a88cbb 100644 --- a/pkg/dcgm/api.go +++ b/pkg/dcgm/api.go @@ -113,14 +113,14 @@ func HealthCheckByGpuId(gpuID uint) (DeviceHealth, error) { // ListenForPolicyViolations sets up monitoring for the specified policy conditions on all GPUs // Returns a channel that receives policy violations and any error encountered -func ListenForPolicyViolations(ctx context.Context, typ ...policyCondition) (<-chan PolicyViolation, error) { +func ListenForPolicyViolations(ctx context.Context, typ ...PolicyCondition) (<-chan PolicyViolation, error) { groupID := GroupAllGPUs() return ListenForPolicyViolationsForGroup(ctx, groupID, typ...) } // ListenForPolicyViolationsForGroup sets up policy monitoring for the specified GPU group // Returns a channel that receives policy violations and any error encountered -func ListenForPolicyViolationsForGroup(ctx context.Context, group GroupHandle, typ ...policyCondition) (<-chan PolicyViolation, error) { +func ListenForPolicyViolationsForGroup(ctx context.Context, group GroupHandle, typ ...PolicyCondition) (<-chan PolicyViolation, error) { return registerPolicy(ctx, group, typ...) } @@ -143,3 +143,23 @@ func GetNvLinkLinkStatus() ([]NvLinkStatus, error) { func GetNvLinkP2PStatus() (NvLinkP2PStatus, error) { return getNvLinkP2PStatus() } + +// SetPolicyForGroup configures policies with optional custom thresholds and actions for a GPU group +func SetPolicyForGroup(group GroupHandle, configs ...PolicyConfig) error { + return setPolicyForGroupWithConfig(group, configs...) +} + +// GetPolicyForGroup retrieves the current policy configuration for a GPU group +func GetPolicyForGroup(group GroupHandle) (*PolicyStatus, error) { + return getPolicyForGroup(group) +} + +// ClearPolicyForGroup clears all policy conditions for a GPU group +func ClearPolicyForGroup(group GroupHandle) error { + return clearPolicyForGroup(group) +} + +// WatchPolicyViolationsForGroup registers to receive violation notifications for a specific GPU group +func WatchPolicyViolationsForGroup(ctx context.Context, group GroupHandle, typ ...PolicyCondition) (<-chan PolicyViolation, error) { + return registerPolicyOnly(ctx, group, typ...) +} diff --git a/pkg/dcgm/policy.go b/pkg/dcgm/policy.go index 466481a..79172a5 100644 --- a/pkg/dcgm/policy.go +++ b/pkg/dcgm/policy.go @@ -20,36 +20,100 @@ import ( ) // PolicyCondition represents a type of policy violation that can be monitored -type policyCondition string +type PolicyCondition string + +// This alias is maintained for backward compatibility. +type policyCondition = PolicyCondition // Policy condition types const ( // DbePolicy represents a Double-bit ECC error policy condition - DbePolicy = policyCondition("Double-bit ECC error") + DbePolicy = PolicyCondition("Double-bit ECC error") // PCIePolicy represents a PCI error policy condition - PCIePolicy = policyCondition("PCI error") + PCIePolicy = PolicyCondition("PCI error") // MaxRtPgPolicy represents a Maximum Retired Pages Limit policy condition - MaxRtPgPolicy = policyCondition("Max Retired Pages Limit") + MaxRtPgPolicy = PolicyCondition("Max Retired Pages Limit") // ThermalPolicy represents a Thermal Limit policy condition - ThermalPolicy = policyCondition("Thermal Limit") + ThermalPolicy = PolicyCondition("Thermal Limit") // PowerPolicy represents a Power Limit policy condition - PowerPolicy = policyCondition("Power Limit") + PowerPolicy = PolicyCondition("Power Limit") // NvlinkPolicy represents an NVLink error policy condition - NvlinkPolicy = policyCondition("Nvlink Error") + NvlinkPolicy = PolicyCondition("Nvlink Error") // XidPolicy represents an XID error policy condition - XidPolicy = policyCondition("XID Error") + XidPolicy = PolicyCondition("XID Error") +) + +// Default policy thresholds matching dcgmi defaults +const ( + // DefaultMaxRetiredPages is the default threshold for retired pages (matches dcgmi default) + DefaultMaxRetiredPages = 10 + + // DefaultMaxTemperature is the default threshold for temperature in Celsius (matches dcgmi default) + DefaultMaxTemperature = 100 + + // DefaultMaxPower is the default threshold for power in Watts (matches dcgmi default) + DefaultMaxPower = 250 +) + +// PolicyAction specifies the action to take when a policy violation occurs +type PolicyAction uint32 + +const ( + // PolicyActionNone indicates no action should be taken on violation (default) + PolicyActionNone PolicyAction = 0 + + // PolicyActionGPUReset indicates the GPU should be reset on violation + PolicyActionGPUReset PolicyAction = 1 ) +// PolicyValidation specifies the validation to perform after a policy action +type PolicyValidation uint32 + +const ( + // PolicyValidationNone indicates no validation after action (default) + PolicyValidationNone PolicyValidation = 0 + + // PolicyValidationShort indicates a short system validation should be performed + PolicyValidationShort PolicyValidation = 1 + + // PolicyValidationMedium indicates a medium system validation should be performed + PolicyValidationMedium PolicyValidation = 2 + + // PolicyValidationLong indicates a long system validation should be performed + PolicyValidationLong PolicyValidation = 3 +) + +// PolicyConfig configures a policy condition with optional custom thresholds and actions +type PolicyConfig struct { + // Condition specifies the type of policy to monitor + Condition PolicyCondition + + // Action specifies what action to take when this policy violation occurs (optional, defaults to PolicyActionNone) + Action *PolicyAction + + // Validation specifies what validation to perform after the action (optional, defaults to PolicyValidationNone) + Validation *PolicyValidation + + // MaxRetiredPages specifies the threshold for MaxRtPgPolicy (optional, defaults to DefaultMaxRetiredPages) + MaxRetiredPages *uint32 + + // MaxTemperature specifies the threshold for ThermalPolicy in Celsius (optional, defaults to DefaultMaxTemperature) + MaxTemperature *uint32 + + // MaxPower specifies the threshold for PowerPolicy in Watts (optional, defaults to DefaultMaxPower) + MaxPower *uint32 +} + // PolicyViolation represents a detected violation of a policy condition type PolicyViolation struct { // Condition specifies the type of policy that was violated - Condition policyCondition + Condition PolicyCondition // Timestamp indicates when the violation occurred Timestamp time.Time // Data contains violation-specific details @@ -200,7 +264,7 @@ func makePolicyParmsMap() { // //export ViolationRegistration func ViolationRegistration(data unsafe.Pointer) int { - var con policyCondition + var con PolicyCondition var timestamp time.Time var val any @@ -324,7 +388,257 @@ func setPolicy(groupID GroupHandle, condition C.dcgmPolicyCondition_t, paramList return } -func registerPolicy(ctx context.Context, groupID GroupHandle, typ ...policyCondition) (<-chan PolicyViolation, error) { +func setPolicyInternal(groupID GroupHandle, condition C.dcgmPolicyCondition_t, configs []policyConfigInternal, action PolicyAction, validation PolicyValidation) (err error) { + var policy C.dcgmPolicy_t + policy.version = makeVersion1(unsafe.Sizeof(policy)) + policy.mode = C.dcgmPolicyMode_t(C.DCGM_OPERATION_MODE_AUTO) + policy.action = C.dcgmPolicyAction_t(action) + policy.isolation = C.DCGM_POLICY_ISOLATION_NONE + policy.validation = C.dcgmPolicyValidation_t(validation) + policy.condition = condition + + // iterate on configs for given policy conditions + for _, cfg := range configs { + // set policy condition parameters + // set condition type (bool or longlong) + policy.parms[cfg.index].tag = cfg.param.typ + + // set condition val (violation threshold) + // policy.parms.val is a C union type + // cgo docs: Go doesn't have support for C's union type + // C union types are represented as a Go byte array + binary.LittleEndian.PutUint32(policy.parms[cfg.index].val[:], cfg.param.value) + } + + var statusHandle C.dcgmStatus_t + + result := C.dcgmPolicySet(handle.handle, groupID.handle, &policy, statusHandle) + if err = errorString(result); err != nil { + return fmt.Errorf("error setting policies: %s", err) + } + + return +} + +type policyConfigInternal struct { + index policyIndex + param policyConditionParam +} + +// PolicyStatus represents the current policy configuration for a group +type PolicyStatus struct { + // Mode indicates the operation mode (automatic or manual) + Mode uint32 + + // Action specifies what action is taken on violation + Action PolicyAction + + // Validation specifies what validation is performed after action + Validation PolicyValidation + + // Conditions is a map of enabled policy conditions with their thresholds + // Key is the PolicyCondition, value is the threshold (if applicable) + Conditions map[PolicyCondition]interface{} +} + +func getPolicyForGroup(groupID GroupHandle) (*PolicyStatus, error) { + var policy C.dcgmPolicy_t + policy.version = makeVersion1(unsafe.Sizeof(policy)) + + var statusHandle C.dcgmStatus_t + + result := C.dcgmPolicyGet(handle.handle, groupID.handle, 1, &policy, statusHandle) + if err := errorString(result); err != nil { + return nil, fmt.Errorf("error getting policy: %s", err) + } + + status := &PolicyStatus{ + Mode: uint32(policy.mode), + Action: PolicyAction(policy.action), + Validation: PolicyValidation(policy.validation), + Conditions: make(map[PolicyCondition]interface{}), + } + + condition := policy.condition + + // Check each condition bit and extract its parameters + if condition&C.DCGM_POLICY_COND_DBE != 0 { + status.Conditions[DbePolicy] = true + } + + if condition&C.DCGM_POLICY_COND_PCI != 0 { + status.Conditions[PCIePolicy] = true + } + + if condition&C.DCGM_POLICY_COND_MAX_PAGES_RETIRED != 0 { + param := policy.parms[maxRtPgPolicyIndex] + if param.tag == 1 { // LLONG type + threshold := binary.LittleEndian.Uint32(param.val[:]) + status.Conditions[MaxRtPgPolicy] = threshold + } + } + + if condition&C.DCGM_POLICY_COND_THERMAL != 0 { + param := policy.parms[thermalPolicyIndex] + if param.tag == 1 { // LLONG type + threshold := binary.LittleEndian.Uint32(param.val[:]) + status.Conditions[ThermalPolicy] = threshold + } + } + + if condition&C.DCGM_POLICY_COND_POWER != 0 { + param := policy.parms[powerPolicyIndex] + if param.tag == 1 { // LLONG type + threshold := binary.LittleEndian.Uint32(param.val[:]) + status.Conditions[PowerPolicy] = threshold + } + } + + if condition&C.DCGM_POLICY_COND_NVLINK != 0 { + status.Conditions[NvlinkPolicy] = true + } + + if condition&C.DCGM_POLICY_COND_XID != 0 { + status.Conditions[XidPolicy] = true + } + + return status, nil +} + +func clearPolicyForGroup(groupID GroupHandle) error { + // Clear all policies by setting condition to 0 (no conditions enabled) + var policy C.dcgmPolicy_t + policy.version = makeVersion1(unsafe.Sizeof(policy)) + policy.mode = C.dcgmPolicyMode_t(C.DCGM_OPERATION_MODE_AUTO) + policy.action = C.DCGM_POLICY_ACTION_NONE + policy.isolation = C.DCGM_POLICY_ISOLATION_NONE + policy.validation = C.DCGM_POLICY_VALID_NONE + policy.condition = 0 // No conditions - clears all policies + + var statusHandle C.dcgmStatus_t + + result := C.dcgmPolicySet(handle.handle, groupID.handle, &policy, statusHandle) + if err := errorString(result); err != nil { + return fmt.Errorf("error clearing policies: %s", err) + } + + return nil +} + +func setPolicyForGroupWithConfig(groupID GroupHandle, configs ...PolicyConfig) error { + const ( + policyFieldTypeBool = 0 + policyFieldTypeLong = 1 + policyBoolValue = 1 + ) + + if len(configs) == 0 { + return fmt.Errorf("at least one policy config must be provided") + } + + // Extract action and validation from first config (applies to all conditions) + // This matches dcgmi behavior where --set actn,val applies to the entire policy set + action := PolicyActionNone + validation := PolicyValidationNone + + if configs[0].Action != nil { + action = *configs[0].Action + } + if configs[0].Validation != nil { + validation = *configs[0].Validation + } + + // Build internal configs with custom or default thresholds + internalConfigs := make([]policyConfigInternal, len(configs)) + var condition C.dcgmPolicyCondition_t = 0 + + for i, cfg := range configs { + var idx policyIndex + var param policyConditionParam + + switch cfg.Condition { + case DbePolicy: + idx = dbePolicyIndex + condition |= C.DCGM_POLICY_COND_DBE + param = policyConditionParam{ + typ: policyFieldTypeBool, + value: policyBoolValue, + } + + case PCIePolicy: + idx = pciePolicyIndex + condition |= C.DCGM_POLICY_COND_PCI + param = policyConditionParam{ + typ: policyFieldTypeBool, + value: policyBoolValue, + } + + case MaxRtPgPolicy: + idx = maxRtPgPolicyIndex + condition |= C.DCGM_POLICY_COND_MAX_PAGES_RETIRED + threshold := uint32(DefaultMaxRetiredPages) + if cfg.MaxRetiredPages != nil { + threshold = *cfg.MaxRetiredPages + } + param = policyConditionParam{ + typ: policyFieldTypeLong, + value: threshold, + } + + case ThermalPolicy: + idx = thermalPolicyIndex + condition |= C.DCGM_POLICY_COND_THERMAL + threshold := uint32(DefaultMaxTemperature) + if cfg.MaxTemperature != nil { + threshold = *cfg.MaxTemperature + } + param = policyConditionParam{ + typ: policyFieldTypeLong, + value: threshold, + } + + case PowerPolicy: + idx = powerPolicyIndex + condition |= C.DCGM_POLICY_COND_POWER + threshold := uint32(DefaultMaxPower) + if cfg.MaxPower != nil { + threshold = *cfg.MaxPower + } + param = policyConditionParam{ + typ: policyFieldTypeLong, + value: threshold, + } + + case NvlinkPolicy: + idx = nvlinkPolicyIndex + condition |= C.DCGM_POLICY_COND_NVLINK + param = policyConditionParam{ + typ: policyFieldTypeBool, + value: policyBoolValue, + } + + case XidPolicy: + idx = xidPolicyIndex + condition |= C.DCGM_POLICY_COND_XID + param = policyConditionParam{ + typ: policyFieldTypeBool, + value: policyBoolValue, + } + + default: + return fmt.Errorf("unknown policy condition: %s", cfg.Condition) + } + + internalConfigs[i] = policyConfigInternal{ + index: idx, + param: param, + } + } + + return setPolicyInternal(groupID, condition, internalConfigs, action, validation) +} + +func registerPolicy(ctx context.Context, groupID GroupHandle, typ ...PolicyCondition) (<-chan PolicyViolation, error) { var err error // init policy globals for internal API makePolicyChannels() @@ -408,6 +722,73 @@ func registerPolicy(ctx context.Context, groupID GroupHandle, typ ...policyCondi return violation, err } +func registerPolicyOnly(ctx context.Context, groupID GroupHandle, typ ...PolicyCondition) (<-chan PolicyViolation, error) { + var err error + // init policy globals for internal API + makePolicyChannels() + + // get all conditions to listen for + var condition C.dcgmPolicyCondition_t = 0 + + for _, t := range typ { + switch t { + case DbePolicy: + condition |= C.DCGM_POLICY_COND_DBE + case PCIePolicy: + condition |= C.DCGM_POLICY_COND_PCI + case MaxRtPgPolicy: + condition |= C.DCGM_POLICY_COND_MAX_PAGES_RETIRED + case ThermalPolicy: + condition |= C.DCGM_POLICY_COND_THERMAL + case PowerPolicy: + condition |= C.DCGM_POLICY_COND_POWER + case NvlinkPolicy: + condition |= C.DCGM_POLICY_COND_NVLINK + case XidPolicy: + condition |= C.DCGM_POLICY_COND_XID + } + } + + // Register for violations without setting policies + result := C.dcgmPolicyRegister_v2(handle.handle, groupID.handle, condition, C.fpRecvUpdates(C.violationNotify), C.ulong(0)) + + if err = errorString(result); err != nil { + return nil, &Error{msg: C.GoString(C.errorString(result)), Code: result} + } + + violation := make(chan PolicyViolation, len(typ)) + + go func() { + defer func() { + close(violation) + unregisterPolicy(groupID, condition) + }() + + for { + select { + case dbe := <-callbacks["dbe"]: + violation <- dbe + case pcie := <-callbacks["pcie"]: + violation <- pcie + case maxrtpg := <-callbacks["maxrtpg"]: + violation <- maxrtpg + case thermal := <-callbacks["thermal"]: + violation <- thermal + case power := <-callbacks["power"]: + violation <- power + case nvlink := <-callbacks["nvlink"]: + violation <- nvlink + case xid := <-callbacks["xid"]: + violation <- xid + case <-ctx.Done(): + return + } + } + }() + + return violation, err +} + func unregisterPolicy(groupID GroupHandle, condition C.dcgmPolicyCondition_t) { result := C.dcgmPolicyUnregister(handle.handle, groupID.handle, condition) diff --git a/pkg/dcgm/policy_test.go b/pkg/dcgm/policy_test.go index 0f4fdde..5dc52ce 100644 --- a/pkg/dcgm/policy_test.go +++ b/pkg/dcgm/policy_test.go @@ -356,3 +356,430 @@ func joinPolicy(policy []policyCondition, sep string) string { return result.String() } + +func TestSetAndGetPolicy(t *testing.T) { + t.Log("Initializing DCGM in Embedded mode...") + cleanup, err := Init(Embedded) + require.NoError(t, err) + defer cleanup() + t.Log("DCGM initialized successfully") + + group := GroupAllGPUs() + t.Logf("Created group handle for all GPUs: %+v", group) + + // Check how many GPUs we have + gpuCount, err := GetAllDeviceCount() + require.NoError(t, err) + t.Logf("Found %d GPU(s) in the system", gpuCount) + + action := PolicyActionNone + validation := PolicyValidationNone + + // Test cases for each policy type + testCases := []struct { + name string + config PolicyConfig + expected interface{} + conditionID PolicyCondition + }{ + { + name: "ThermalPolicy", + config: PolicyConfig{ + Condition: ThermalPolicy, + Action: &action, + Validation: &validation, + MaxTemperature: ptrUint32(85), + }, + expected: uint32(85), + conditionID: ThermalPolicy, + }, + { + name: "PowerPolicy", + config: PolicyConfig{ + Condition: PowerPolicy, + Action: &action, + Validation: &validation, + MaxPower: ptrUint32(300), + }, + expected: uint32(300), + conditionID: PowerPolicy, + }, + { + name: "MaxRtPgPolicy", + config: PolicyConfig{ + Condition: MaxRtPgPolicy, + Action: &action, + Validation: &validation, + MaxRetiredPages: ptrUint32(15), + }, + expected: uint32(15), + conditionID: MaxRtPgPolicy, + }, + { + name: "DbePolicy", + config: PolicyConfig{ + Condition: DbePolicy, + Action: &action, + Validation: &validation, + }, + expected: true, + conditionID: DbePolicy, + }, + { + name: "PCIePolicy", + config: PolicyConfig{ + Condition: PCIePolicy, + Action: &action, + Validation: &validation, + }, + expected: true, + conditionID: PCIePolicy, + }, + { + name: "NvlinkPolicy", + config: PolicyConfig{ + Condition: NvlinkPolicy, + Action: &action, + Validation: &validation, + }, + expected: true, + conditionID: NvlinkPolicy, + }, + { + name: "XidPolicy", + config: PolicyConfig{ + Condition: XidPolicy, + Action: &action, + Validation: &validation, + }, + expected: true, + conditionID: XidPolicy, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Logf("Setting %s policy...", tc.name) + + err := SetPolicyForGroup(group, tc.config) + require.NoError(t, err) + t.Logf("%s policy set successfully", tc.name) + + // Get the policy and verify it was set correctly + t.Log("Retrieving policy configuration...") + status, err := GetPolicyForGroup(group) + require.NoError(t, err) + require.NotNil(t, status) + t.Logf("Policy retrieved - Mode: %d, Action: %v, Validation: %v, Conditions: %v", + status.Mode, status.Action, status.Validation, status.Conditions) + + // Verify the policy is set + assert.Contains(t, status.Conditions, tc.conditionID) + assert.Equal(t, tc.expected, status.Conditions[tc.conditionID]) + assert.Equal(t, action, status.Action) + assert.Equal(t, validation, status.Validation) + + t.Logf("%s policy assertions passed", tc.name) + }) + } +} + +func TestSetAndGetMultiplePolicies(t *testing.T) { + t.Log("Initializing DCGM in Embedded mode...") + cleanup, err := Init(Embedded) + require.NoError(t, err) + defer cleanup() + t.Log("DCGM initialized successfully") + + group := GroupAllGPUs() + t.Logf("Created group handle for all GPUs: %+v", group) + + // Check how many GPUs we have + gpuCount, err := GetAllDeviceCount() + require.NoError(t, err) + t.Logf("Found %d GPU(s) in the system", gpuCount) + + action := PolicyActionNone + validation := PolicyValidationNone + + // Set multiple policies at once + t.Log("Setting multiple policies simultaneously...") + thermalThreshold := uint32(90) + powerThreshold := uint32(350) + maxRetiredPages := uint32(20) + + err = SetPolicyForGroup(group, + PolicyConfig{ + Condition: ThermalPolicy, + Action: &action, + Validation: &validation, + MaxTemperature: &thermalThreshold, + }, + PolicyConfig{ + Condition: PowerPolicy, + Action: &action, + Validation: &validation, + MaxPower: &powerThreshold, + }, + PolicyConfig{ + Condition: MaxRtPgPolicy, + Action: &action, + Validation: &validation, + MaxRetiredPages: &maxRetiredPages, + }, + PolicyConfig{ + Condition: DbePolicy, + Action: &action, + Validation: &validation, + }, + PolicyConfig{ + Condition: XidPolicy, + Action: &action, + Validation: &validation, + }, + ) + require.NoError(t, err) + t.Log("Multiple policies set successfully") + + // Get the policy and verify all were set correctly + t.Log("Retrieving policy configuration...") + status, err := GetPolicyForGroup(group) + require.NoError(t, err) + require.NotNil(t, status) + t.Logf("Policy retrieved - Mode: %d, Action: %v, Validation: %v, Conditions: %v", + status.Mode, status.Action, status.Validation, status.Conditions) + + // Verify all policies are present + t.Log("Verifying all policies were set correctly...") + require.Len(t, status.Conditions, 5, "Expected 5 policy conditions to be set") + + // Verify each policy individually + assert.Contains(t, status.Conditions, ThermalPolicy) + assert.Equal(t, thermalThreshold, status.Conditions[ThermalPolicy]) + t.Logf("✓ ThermalPolicy: %d°C", thermalThreshold) + + assert.Contains(t, status.Conditions, PowerPolicy) + assert.Equal(t, powerThreshold, status.Conditions[PowerPolicy]) + t.Logf("✓ PowerPolicy: %dW", powerThreshold) + + assert.Contains(t, status.Conditions, MaxRtPgPolicy) + assert.Equal(t, maxRetiredPages, status.Conditions[MaxRtPgPolicy]) + t.Logf("✓ MaxRtPgPolicy: %d pages", maxRetiredPages) + + assert.Contains(t, status.Conditions, DbePolicy) + assert.Equal(t, true, status.Conditions[DbePolicy]) + t.Log("✓ DbePolicy: enabled") + + assert.Contains(t, status.Conditions, XidPolicy) + assert.Equal(t, true, status.Conditions[XidPolicy]) + t.Log("✓ XidPolicy: enabled") + + // Verify action and validation apply to all + assert.Equal(t, action, status.Action) + assert.Equal(t, validation, status.Validation) + + t.Log("All multiple policy assertions passed") +} + +func TestSetPolicyAndWatchViolations(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + time.Sleep(100 * time.Millisecond) + }() + + t.Log("Initializing DCGM in Embedded mode...") + cleanup, err := Init(Embedded) + require.NoError(t, err) + defer cleanup() + t.Log("DCGM initialized successfully") + + numGPUs, err := GetAllDeviceCount() + require.NoError(t, err) + t.Logf("Found %d GPU(s) in the system", numGPUs) + + if numGPUs+1 > MAX_NUM_DEVICES { + t.Skipf("Unable to add fake GPU with more than %d gpus", MAX_NUM_DEVICES) + } + + // Create fake GPUs for testing + t.Log("Creating fake GPU entities for testing...") + entityList := []MigHierarchyInfo{ + {Entity: GroupEntityPair{EntityGroupId: FE_GPU}}, + {Entity: GroupEntityPair{EntityGroupId: FE_GPU}}, + {Entity: GroupEntityPair{EntityGroupId: FE_GPU}}, + {Entity: GroupEntityPair{EntityGroupId: FE_GPU}}, + } + _, err = CreateFakeEntities(entityList) + require.NoError(t, err) + t.Log("Fake GPU entities created") + + group := GroupAllGPUs() + t.Logf("Created group handle for all GPUs: %+v", group) + + // Set up policies with thresholds using SetPolicyForGroup + action := PolicyActionNone + validation := PolicyValidationNone + thermalThreshold := uint32(100) + powerThreshold := uint32(250) + + t.Log("Setting thermal and power policies with SetPolicyForGroup...") + err = SetPolicyForGroup(group, + PolicyConfig{ + Condition: ThermalPolicy, + Action: &action, + Validation: &validation, + MaxTemperature: &thermalThreshold, + }, + PolicyConfig{ + Condition: PowerPolicy, + Action: &action, + Validation: &validation, + MaxPower: &powerThreshold, + }, + ) + require.NoError(t, err) + t.Log("Policies set successfully with SetPolicyForGroup") + + // Watch for policy violations using WatchPolicyViolationsForGroup + t.Log("Starting to watch for policy violations with WatchPolicyViolationsForGroup...") + violations, err := WatchPolicyViolationsForGroup(ctx, group, ThermalPolicy, PowerPolicy) + require.NoError(t, err) + t.Log("Watching for violations") + + // Test 1: Inject thermal violation + t.Run("ThermalViolation", func(t *testing.T) { + gpu, _ := secureRandomUint(4) + t.Logf("Injecting thermal violation for GPU %d (threshold: %d°C)", gpu, thermalThreshold) + + err := InjectFieldValue(gpu, + DCGM_FI_DEV_GPU_TEMP, + DCGM_FT_INT64, + 0, + time.Now().Add(60*time.Second).UnixMicro(), + int64(thermalThreshold+1), // Exceed threshold + ) + require.NoError(t, err) + + // Wait for violation + select { + case violation := <-violations: + t.Logf("Received violation: %+v", violation) + assert.Equal(t, ThermalPolicy, violation.Condition) + require.IsType(t, ThermalPolicyCondition{}, violation.Data) + thermalData := violation.Data.(ThermalPolicyCondition) + assert.Equal(t, uint(thermalThreshold+1), thermalData.ThermalViolation) + t.Logf("✓ Thermal violation detected: %d°C", thermalData.ThermalViolation) + case <-time.After(20 * time.Second): + t.Fatal("Timeout waiting for thermal violation") + } + }) + + // Test 2: Inject power violation + t.Run("PowerViolation", func(t *testing.T) { + gpu, _ := secureRandomUint(4) + t.Logf("Injecting power violation for GPU %d (threshold: %dW)", gpu, powerThreshold) + + err := InjectFieldValue(gpu, + DCGM_FI_DEV_POWER_USAGE, + DCGM_FT_DOUBLE, + 0, + time.Now().Add(60*time.Second).UnixMicro(), + float64(powerThreshold+50), // Exceed threshold + ) + require.NoError(t, err) + + // Wait for violation + select { + case violation := <-violations: + t.Logf("Received violation: %+v", violation) + assert.Equal(t, PowerPolicy, violation.Condition) + require.IsType(t, PowerPolicyCondition{}, violation.Data) + powerData := violation.Data.(PowerPolicyCondition) + assert.Equal(t, uint(powerThreshold+50), powerData.PowerViolation) + t.Logf("✓ Power violation detected: %dW", powerData.PowerViolation) + case <-time.After(20 * time.Second): + t.Fatal("Timeout waiting for power violation") + } + }) + + t.Log("All SetPolicyForGroup + WatchPolicyViolationsForGroup tests passed") +} + +func TestClearPolicyForGroup(t *testing.T) { + t.Log("Initializing DCGM in Embedded mode...") + cleanup, err := Init(Embedded) + require.NoError(t, err) + defer cleanup() + t.Log("DCGM initialized successfully") + + group := GroupAllGPUs() + t.Logf("Created group handle for all GPUs: %+v", group) + + // Check how many GPUs we have + gpuCount, err := GetAllDeviceCount() + require.NoError(t, err) + t.Logf("Found %d GPU(s) in the system", gpuCount) + + action := PolicyActionNone + validation := PolicyValidationNone + + // Step 1: Set some policies + t.Log("Step 1: Setting multiple policies...") + thermalThreshold := uint32(90) + powerThreshold := uint32(350) + + err = SetPolicyForGroup(group, + PolicyConfig{ + Condition: ThermalPolicy, + Action: &action, + Validation: &validation, + MaxTemperature: &thermalThreshold, + }, + PolicyConfig{ + Condition: PowerPolicy, + Action: &action, + Validation: &validation, + MaxPower: &powerThreshold, + }, + PolicyConfig{ + Condition: DbePolicy, + Action: &action, + Validation: &validation, + }, + ) + require.NoError(t, err) + t.Log("Policies set successfully") + + // Step 2: Verify policies were set + t.Log("Step 2: Verifying policies were set...") + status, err := GetPolicyForGroup(group) + require.NoError(t, err) + require.NotNil(t, status) + require.Len(t, status.Conditions, 3, "Expected 3 policies to be set") + assert.Contains(t, status.Conditions, ThermalPolicy) + assert.Contains(t, status.Conditions, PowerPolicy) + assert.Contains(t, status.Conditions, DbePolicy) + t.Logf("Verified 3 policies are active: %v", status.Conditions) + + // Step 3: Clear all policies + t.Log("Step 3: Clearing all policies...") + err = ClearPolicyForGroup(group) + require.NoError(t, err) + t.Log("Policies cleared successfully") + + // Step 4: Verify policies were cleared + t.Log("Step 4: Verifying policies were cleared...") + status, err = GetPolicyForGroup(group) + require.NoError(t, err) + require.NotNil(t, status) + assert.Empty(t, status.Conditions, "Expected no policies after clear") + t.Logf("Verified all policies cleared. Conditions map: %v", status.Conditions) + + t.Log("All clear policy tests passed") +} + +// Helper function to create pointer to uint32 +func ptrUint32(v uint32) *uint32 { + return &v +}