Skip to content

Commit 94a00c8

Browse files
committed
implement NRI plugin server to inject management CDI devices
Signed-off-by: Tariq Ibrahim <tibrahim@nvidia.com>
1 parent 34882b2 commit 94a00c8

File tree

589 files changed

+165236
-24
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

589 files changed

+165236
-24
lines changed

cmd/nvidia-ctk-installer/container/container.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,15 @@ type Options struct {
4949
// mount.
5050
ExecutablePath string
5151
// EnableCDI indicates whether CDI should be enabled.
52-
EnableCDI bool
53-
RuntimeName string
54-
RuntimeDir string
55-
SetAsDefault bool
56-
RestartMode string
57-
HostRootMount string
52+
EnableCDI bool
53+
EnableNRI bool
54+
RuntimeName string
55+
RuntimeDir string
56+
SetAsDefault bool
57+
RestartMode string
58+
HostRootMount string
59+
NRIPluginIndex string
60+
NRISocket string
5861

5962
ConfigSources []string
6063
}
@@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
128131
cfg.EnableCDI()
129132
}
130133

134+
if o.EnableNRI {
135+
cfg.EnableNRI()
136+
}
137+
131138
return nil
132139
}
133140

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package nri
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
8+
"github.com/containerd/nri/pkg/api"
9+
nriplugin "github.com/containerd/nri/pkg/stub"
10+
"sigs.k8s.io/yaml"
11+
12+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
13+
)
14+
15+
// Compile-time interface checks
16+
var (
17+
_ nriplugin.Plugin = (*Plugin)(nil)
18+
)
19+
20+
const (
21+
// nodeResourceCDIDeviceKey is the prefix of the key used for CDI device annotations.
22+
nodeResourceCDIDeviceKey = "cdi-devices.noderesource.dev"
23+
// nriCDIDeviceKey is the older prefix of the key used for CDI device annotations; it is consulted after nodeResourceCDIDeviceKey.
24+
nriCDIDeviceKey = "cdi-devices.nri.io"
25+
// defaultNRISocket represents the default path of the NRI socket
26+
defaultNRISocket = "/var/run/nri/nri.sock"
27+
)
28+
29+
// Plugin is the NRI plugin that injects CDI devices into containers at
// creation time. It satisfies the nriplugin.Plugin interface.
type Plugin struct {
30+
// logger receives the plugin's debug and info messages.
logger logger.Interface
31+
32+
// stub is the NRI stub created by Start and stopped by Stop; NewPlugin
// leaves it unset.
stub nriplugin.Stub
33+
}
34+
35+
// NewPlugin creates a new NRI plugin for injecting CDI devices
36+
func NewPlugin(logger logger.Interface) *Plugin {
37+
return &Plugin{
38+
logger: logger,
39+
}
40+
}
41+
42+
// CreateContainer handles container creation requests.
43+
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
44+
adjust := &api.ContainerAdjustment{}
45+
46+
if err := p.injectCDIDevices(pod, ctr, adjust); err != nil {
47+
return nil, nil, err
48+
}
49+
50+
return adjust, nil, nil
51+
}
52+
53+
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
54+
devices, err := parseCDIDevices(ctr.Name, pod.Annotations)
55+
if err != nil {
56+
return err
57+
}
58+
59+
if len(devices) == 0 {
60+
p.logger.Debugf("%s: no CDI devices annotated...", containerName(pod, ctr))
61+
return nil
62+
}
63+
64+
for _, name := range devices {
65+
a.AddCDIDevice(
66+
&api.CDIDevice{
67+
Name: name,
68+
},
69+
)
70+
p.logger.Infof("%s: injected CDI device %q...", containerName(pod, ctr), name)
71+
}
72+
73+
return nil
74+
}
75+
76+
func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
77+
var (
78+
cdiDevices []string
79+
)
80+
81+
annotation := getAnnotation(annotations, nodeResourceCDIDeviceKey, nriCDIDeviceKey, ctr)
82+
if len(annotation) == 0 {
83+
return nil, nil
84+
}
85+
86+
if err := yaml.Unmarshal(annotation, &cdiDevices); err != nil {
87+
return nil, fmt.Errorf("invalid CDI device annotation %q: %w", string(annotation), err)
88+
}
89+
90+
return cdiDevices, nil
91+
}
92+
93+
// getAnnotation returns the value of the first annotation found for the
// container ctr, or nil when none of the candidate keys is present.
//
// Keys are tried from most to least specific, and at each level mainKey is
// preferred over oldKey:
//
//	<key>/container.<ctr>, then <key>/pod, then <key> on its own.
func getAnnotation(annotations map[string]string, mainKey, oldKey, ctr string) []byte {
	containerSuffix := "/container." + ctr
	candidates := [...]string{
		mainKey + containerSuffix,
		oldKey + containerSuffix,
		mainKey + "/pod",
		oldKey + "/pod",
		mainKey,
		oldKey,
	}

	for i := 0; i < len(candidates); i++ {
		value, found := annotations[candidates[i]]
		if !found {
			continue
		}
		return []byte(value)
	}

	return nil
}
109+
110+
// Construct a container name for log messages.
111+
func containerName(pod *api.PodSandbox, container *api.Container) string {
112+
if pod != nil {
113+
return pod.Name + "/" + container.Name
114+
}
115+
return container.Name
116+
}
117+
118+
// Start starts the NRI plugin
119+
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
120+
if len(nriSocketPath) == 0 {
121+
nriSocketPath = defaultNRISocket
122+
}
123+
_, err := os.Stat(nriSocketPath)
124+
if err != nil {
125+
return fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, err)
126+
}
127+
128+
var pluginOpts []nriplugin.Option
129+
pluginOpts = append(pluginOpts, nriplugin.WithPluginIdx(nriPluginIdx))
130+
pluginOpts = append(pluginOpts, nriplugin.WithSocketPath(nriSocketPath))
131+
if p.stub, err = nriplugin.New(p, pluginOpts...); err != nil {
132+
return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
133+
}
134+
err = p.stub.Start(ctx)
135+
if err != nil {
136+
return fmt.Errorf("plugin exited with error: %w", err)
137+
}
138+
return nil
139+
}
140+
141+
// Stop stops the NRI plugin
142+
func (p *Plugin) Stop() {
143+
if p != nil && p.stub != nil {
144+
p.stub.Stop()
145+
}
146+
}

cmd/nvidia-ctk-installer/container/runtime/runtime.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ const (
3434
// defaultRuntimeName specifies the NVIDIA runtime to be used as the default runtime if setting the default runtime is enabled
3535
defaultRuntimeName = "nvidia"
3636
defaultHostRootMount = "/host"
37+
defaultNRIPluginIdx = "10"
38+
defaultNRISocket = "/var/run/nri/nri.sock"
3739

3840
runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
3941
)
@@ -94,6 +96,27 @@ func Flags(opts *Options) []cli.Flag {
9496
Destination: &opts.EnableCDI,
9597
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
9698
},
99+
&cli.BoolFlag{
100+
Name: "enable-nri-in-runtime",
101+
Usage: "Enable NRI in the configured runtime",
102+
Destination: &opts.EnableNRI,
103+
Value: true,
104+
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
105+
},
106+
&cli.StringFlag{
107+
Name: "nri-plugin-index",
108+
Usage: "Specify the plugin index to register to NRI",
109+
Value: defaultNRIPluginIdx,
110+
Destination: &opts.NRIPluginIndex,
111+
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
112+
},
113+
&cli.StringFlag{
114+
Name: "nri-socket",
115+
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
116+
Value: defaultNRISocket,
117+
Destination: &opts.NRISocket,
118+
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
119+
},
97120
&cli.StringFlag{
98121
Name: "host-root",
99122
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",

cmd/nvidia-ctk-installer/main.go

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@ package main
33
import (
44
"context"
55
"fmt"
6+
"net"
67
"os"
78
"os/signal"
89
"path/filepath"
910
"syscall"
11+
"time"
1012

1113
"github.com/urfave/cli/v3"
1214
"golang.org/x/sys/unix"
1315

1416
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
17+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
1518
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
1619
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
1720
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@@ -26,6 +29,9 @@ const (
2629
toolkitSubDir = "toolkit"
2730

2831
defaultRuntime = "docker"
32+
33+
retryBackoff = 2 * time.Second
34+
maxRetryAttempts = 5
2935
)
3036

3137
var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
@@ -70,13 +76,15 @@ func main() {
7076
type app struct {
7177
logger logger.Interface
7278

73-
toolkit *toolkit.Installer
79+
nriPlugin *nri.Plugin
80+
toolkit *toolkit.Installer
7481
}
7582

7683
// NewApp creates the CLI app for the specified options.
7784
func NewApp(logger logger.Interface) *cli.Command {
7885
a := app{
79-
logger: logger,
86+
logger: logger,
87+
nriPlugin: nri.NewPlugin(logger),
8088
}
8189
return a.build()
8290
}
@@ -93,8 +101,8 @@ func (a app) build() *cli.Command {
93101
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
94102
return ctx, a.Before(cmd, &options)
95103
},
96-
Action: func(_ context.Context, cmd *cli.Command) error {
97-
return a.Run(cmd, &options)
104+
Action: func(ctx context.Context, cmd *cli.Command) error {
105+
return a.Run(ctx, cmd, &options)
98106
},
99107
Flags: []cli.Flag{
100108
&cli.BoolFlag{
@@ -194,7 +202,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
194202
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
195203
// If the application is run as a daemon, the application waits and unconfigures
196204
// the runtime on termination.
197-
func (a *app) Run(c *cli.Command, o *options) error {
205+
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
198206
err := a.initialize(o.pidFile)
199207
if err != nil {
200208
return fmt.Errorf("unable to initialize: %v", err)
@@ -222,6 +230,15 @@ func (a *app) Run(c *cli.Command, o *options) error {
222230
}
223231

224232
if !o.noDaemon {
233+
if o.runtimeOptions.EnableNRI {
234+
if err = a.waitForRuntimeDaemon(o.runtimeOptions.Socket); err != nil {
235+
return fmt.Errorf("failed to connect to the %s daemon in %s: %v", o.runtime, o.runtimeOptions.Socket, err)
236+
}
237+
err = a.nriPlugin.Start(ctx, o.runtimeOptions.NRISocket, o.runtimeOptions.NRIPluginIndex)
238+
if err != nil {
239+
a.logger.Errorf("unable to start NRI plugin server: %v", err)
240+
}
241+
}
225242
err = a.waitForSignal()
226243
if err != nil {
227244
return fmt.Errorf("unable to wait for signal: %v", err)
@@ -287,9 +304,45 @@ func (a *app) waitForSignal() error {
287304
return nil
288305
}
289306

307+
func (a *app) waitForRuntimeDaemon(socket string) error {
308+
a.logger.Infof("Waiting for runtime daemon")
309+
310+
runtimeDaemonConnect := func() error {
311+
conn, err := net.Dial("unix", socket)
312+
a.logger.Infof("conn: %v, error: %v", conn, err)
313+
defer func() {
314+
if err := conn.Close(); err != nil {
315+
a.logger.Warningf("error closing connection: %v", err)
316+
}
317+
}()
318+
return err
319+
}
320+
var err error
321+
for i := 0; i < maxRetryAttempts; i++ {
322+
err = runtimeDaemonConnect()
323+
if err == nil {
324+
break
325+
}
326+
if i == maxRetryAttempts-1 {
327+
break
328+
}
329+
time.Sleep(retryBackoff)
330+
}
331+
if err != nil {
332+
a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
333+
return err
334+
}
335+
return nil
336+
}
337+
290338
func (a *app) shutdown(pidFile string) {
291339
a.logger.Infof("Shutting Down")
292340

341+
if a.nriPlugin != nil {
342+
a.logger.Infof("Stopping NRI plugin server...")
343+
a.nriPlugin.Stop()
344+
}
345+
293346
err := os.Remove(pidFile)
294347
if err != nil {
295348
a.logger.Warningf("Unable to remove pidfile: %v", err)

cmd/nvidia-ctk-installer/main_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ version = 2
444444
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
445445
"--restart-mode=none",
446446
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
447+
"--enable-nri-in-runtime=false",
447448
}
448449

449450
err := app.Run(context.Background(), append(testArgs, tc.args...))

go.mod

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.25.0
55
require (
66
github.com/NVIDIA/go-nvlib v0.8.1
77
github.com/NVIDIA/go-nvml v0.13.0-1
8+
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
89
github.com/google/uuid v1.6.0
910
github.com/moby/sys/mountinfo v0.7.2
1011
github.com/moby/sys/reexec v0.1.0
@@ -19,24 +20,31 @@ require (
1920
github.com/urfave/cli/v3 v3.6.1
2021
golang.org/x/mod v0.30.0
2122
golang.org/x/sys v0.38.0
23+
sigs.k8s.io/yaml v1.4.0
2224
tags.cncf.io/container-device-interface v1.0.2-0.20251114135136-1b24d969689f
2325
tags.cncf.io/container-device-interface/specs-go v1.0.0
2426
)
2527

2628
require (
2729
cyphar.com/go-pathrs v0.2.1 // indirect
30+
github.com/containerd/log v0.1.0 // indirect
31+
github.com/containerd/ttrpc v1.2.7 // indirect
2832
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
2933
github.com/davecgh/go-spew v1.1.1 // indirect
3034
github.com/fsnotify/fsnotify v1.7.0 // indirect
35+
github.com/golang/protobuf v1.5.3 // indirect
3136
github.com/hashicorp/errwrap v1.1.0 // indirect
32-
github.com/kr/pretty v0.3.1 // indirect
37+
github.com/knqyf263/go-plugin v0.9.0 // indirect
38+
github.com/kr/text v0.2.0 // indirect
3339
github.com/moby/sys/capability v0.4.0 // indirect
3440
github.com/opencontainers/cgroups v0.0.4 // indirect
3541
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
3642
github.com/pmezard/go-difflib v1.0.0 // indirect
3743
github.com/rogpeppe/go-internal v1.11.0 // indirect
44+
github.com/tetratelabs/wazero v1.9.0 // indirect
3845
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
39-
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
46+
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
47+
google.golang.org/grpc v1.57.1 // indirect
48+
google.golang.org/protobuf v1.36.5 // indirect
4049
gopkg.in/yaml.v3 v3.0.1 // indirect
41-
sigs.k8s.io/yaml v1.4.0 // indirect
4250
)

0 commit comments

Comments
 (0)