Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,21 @@ import (
)

func main() {
ret := nvml.Init()
ret, err := nvml.Init()
if err != nil {
log.Fatalf("Unable to open NVML library: %v", err)
}
if ret != nvml.SUCCESS {
log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret))
}
defer func() {
ret := nvml.Shutdown()
ret, err := nvml.Shutdown()
if ret != nvml.SUCCESS {
log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret))
}
if err != nil {
log.Fatalf("Unable to close NVML library: %v", err)
}
}()

count, ret := nvml.DeviceGetCount()
Expand Down
10 changes: 8 additions & 2 deletions examples/compute-processes/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,21 @@ import (
)

func main() {
ret := nvml.Init()
ret, err := nvml.Init()
if err != nil {
log.Fatalf("Unable to open NVML library: %v", err)
}
if ret != nvml.SUCCESS {
log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret))
}
defer func() {
ret := nvml.Shutdown()
ret, err := nvml.Shutdown()
if ret != nvml.SUCCESS {
log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret))
}
if err != nil {
log.Fatalf("Unable to close NVML library: %v", err)
}
}()

count, ret := nvml.DeviceGetCount()
Expand Down
10 changes: 8 additions & 2 deletions examples/devices/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,21 @@ import (
)

func main() {
ret := nvml.Init()
ret, err := nvml.Init()
if err != nil {
log.Fatalf("Unable to open NVML library: %v", err)
}
if ret != nvml.SUCCESS {
log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret))
}
defer func() {
ret := nvml.Shutdown()
ret, err := nvml.Shutdown()
if ret != nvml.SUCCESS {
log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret))
}
if err != nil {
log.Fatalf("Unable to close NVML library: %v", err)
}
}()

count, ret := nvml.DeviceGetCount()
Expand Down
22 changes: 10 additions & 12 deletions gen/nvml/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package nvml

import (
"fmt"

"github.com/NVIDIA/go-nvml/pkg/dl"
)

Expand All @@ -30,45 +28,45 @@ const (
var nvml *dl.DynamicLibrary

// nvml.Init()
func Init() Return {
func Init() (Return, error) {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
return ERROR_LIBRARY_NOT_FOUND
return ERROR_LIBRARY_NOT_FOUND, err
}

nvml = lib
updateVersionedSymbols()

return nvmlInit()
return nvmlInit(), nil
}

// nvml.InitWithFlags()
func InitWithFlags(Flags uint32) Return {
func InitWithFlags(Flags uint32) (Return, error) {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
return ERROR_LIBRARY_NOT_FOUND
return ERROR_LIBRARY_NOT_FOUND, err
}

nvml = lib

return nvmlInitWithFlags(Flags)
return nvmlInitWithFlags(Flags), nil
}

// nvml.Shutdown()
func Shutdown() Return {
func Shutdown() (Return, error) {
ret := nvmlShutdown()
if ret != SUCCESS {
return ret
return ret, nil
}

err := nvml.Close()
if err != nil {
panic(fmt.Sprintf("error closing %s: %v", nvmlLibraryName, err))
return ret, err
}

return ret
return ret, nil
}

// Default all versioned APIs to v1 (to infer the types)
Expand Down
10 changes: 8 additions & 2 deletions gen/nvml/nvml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@ import (
)

func TestInit(t *testing.T) {
ret := Init()
ret, err := Init()
if err != nil {
t.Errorf("NVML open: %v", err)
}
if ret != SUCCESS {
t.Errorf("Init: %v", ret)
} else {
t.Logf("Init: %v", ret)
}

ret = Shutdown()
ret, err = Shutdown()
if ret != SUCCESS {
t.Errorf("Shutdown: %v", ret)
} else {
t.Logf("Shutdown: %v", ret)
}
if err != nil {
t.Errorf("NVML close: %v", err)
}
}

func TestSystem(t *testing.T) {
Expand Down
22 changes: 10 additions & 12 deletions pkg/nvml/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package nvml

import (
"fmt"

"github.com/NVIDIA/go-nvml/pkg/dl"
)

Expand All @@ -30,45 +28,45 @@ const (
var nvml *dl.DynamicLibrary

// nvml.Init()
func Init() Return {
func Init() (Return, error) {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
return ERROR_LIBRARY_NOT_FOUND
return ERROR_LIBRARY_NOT_FOUND, err
}

nvml = lib
updateVersionedSymbols()

return nvmlInit()
return nvmlInit(), nil
}

// nvml.InitWithFlags()
func InitWithFlags(Flags uint32) Return {
func InitWithFlags(Flags uint32) (Return, error) {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
return ERROR_LIBRARY_NOT_FOUND
return ERROR_LIBRARY_NOT_FOUND, err
}

nvml = lib

return nvmlInitWithFlags(Flags)
return nvmlInitWithFlags(Flags), nil
}

// nvml.Shutdown()
func Shutdown() Return {
func Shutdown() (Return, error) {
ret := nvmlShutdown()
if ret != SUCCESS {
return ret
return ret, nil
}

err := nvml.Close()
if err != nil {
panic(fmt.Sprintf("error closing %s: %v", nvmlLibraryName, err))
return ret, err
}

return ret
return ret, nil
}

// Default all versioned APIs to v1 (to infer the types)
Expand Down
10 changes: 8 additions & 2 deletions pkg/nvml/nvml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@ import (
)

func TestInit(t *testing.T) {
ret := Init()
ret, err := Init()
if err != nil {
t.Errorf("NVML open: %v", err)
}
if ret != SUCCESS {
t.Errorf("Init: %v", ret)
} else {
t.Logf("Init: %v", ret)
}

ret = Shutdown()
ret, err = Shutdown()
if ret != SUCCESS {
t.Errorf("Shutdown: %v", ret)
} else {
t.Logf("Shutdown: %v", ret)
}
if err != nil {
t.Errorf("NVML close: %v", err)
}
}

func TestSystem(t *testing.T) {
Expand Down