Skip to content

Commit cbd8cf9

Browse files
committed
implement NRI plugin server to inject management CDI devices
Signed-off-by: Tariq Ibrahim <tibrahim@nvidia.com>
1 parent 34882b2 commit cbd8cf9

File tree

589 files changed

+165224
-24
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

589 files changed

+165224
-24
lines changed

cmd/nvidia-ctk-installer/container/container.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,15 @@ type Options struct {
4949
// mount.
5050
ExecutablePath string
5151
// EnabledCDI indicates whether CDI should be enabled.
52-
EnableCDI bool
53-
RuntimeName string
54-
RuntimeDir string
55-
SetAsDefault bool
56-
RestartMode string
57-
HostRootMount string
52+
EnableCDI bool
53+
EnableNRI bool
54+
RuntimeName string
55+
RuntimeDir string
56+
SetAsDefault bool
57+
RestartMode string
58+
HostRootMount string
59+
NRIPluginIndex string
60+
NRISocket string
5861

5962
ConfigSources []string
6063
}
@@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
128131
cfg.EnableCDI()
129132
}
130133

134+
if o.EnableNRI {
135+
cfg.EnableNRI()
136+
}
137+
131138
return nil
132139
}
133140

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package nri
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
8+
"github.com/containerd/nri/pkg/api"
9+
nriplugin "github.com/containerd/nri/pkg/stub"
10+
"sigs.k8s.io/yaml"
11+
12+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
13+
)
14+
15+
// Compile-time interface checks
16+
var (
17+
_ nriplugin.Plugin = (*Plugin)(nil)
18+
)
19+
20+
const (
21+
// nodeResourceCDIDeviceKey is the prefix of the key used for CDI device annotations.
22+
nodeResourceCDIDeviceKey = "cdi-devices.noderesource.dev"
23+
// nriCDIDeviceKey is the prefix of the key used for CDI device annotations.
24+
nriCDIDeviceKey = "cdi-devices.nri.io"
25+
// defaultNRISocket represents the default path of the NRI socket
26+
defaultNRISocket = "/var/run/nri/nri.sock"
27+
)
28+
29+
type Plugin struct {
30+
logger logger.Interface
31+
32+
stub nriplugin.Stub
33+
}
34+
35+
// NewPlugin creates a new NRI plugin for injecting CDI devices
36+
func NewPlugin(logger logger.Interface) *Plugin {
37+
return &Plugin{
38+
logger: logger,
39+
}
40+
}
41+
42+
// CreateContainer handles container creation requests.
43+
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
44+
adjust := &api.ContainerAdjustment{}
45+
46+
if err := p.injectCDIDevices(pod, ctr, adjust); err != nil {
47+
return nil, nil, err
48+
}
49+
50+
return adjust, nil, nil
51+
}
52+
53+
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
54+
devices, err := parseCDIDevices(ctr.Name, pod.Annotations)
55+
if err != nil {
56+
return err
57+
}
58+
59+
if len(devices) == 0 {
60+
p.logger.Debugf("%s: no CDI devices annotated...", containerName(pod, ctr))
61+
return nil
62+
}
63+
64+
for _, name := range devices {
65+
a.AddCDIDevice(
66+
&api.CDIDevice{
67+
Name: name,
68+
},
69+
)
70+
p.logger.Infof("%s: injected CDI device %q...", containerName(pod, ctr), name)
71+
}
72+
73+
return nil
74+
}
75+
76+
func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
77+
var (
78+
cdiDevices []string
79+
)
80+
81+
annotation := getAnnotation(annotations, nodeResourceCDIDeviceKey, nriCDIDeviceKey, ctr)
82+
if len(annotation) == 0 {
83+
return nil, nil
84+
}
85+
86+
if err := yaml.Unmarshal(annotation, &cdiDevices); err != nil {
87+
return nil, fmt.Errorf("invalid CDI device annotation %q: %w", string(annotation), err)
88+
}
89+
90+
return cdiDevices, nil
91+
}
92+
93+
func getAnnotation(annotations map[string]string, mainKey, oldKey, ctr string) []byte {
94+
for _, key := range []string{
95+
mainKey + "/container." + ctr,
96+
oldKey + "/container." + ctr,
97+
mainKey + "/pod",
98+
oldKey + "/pod",
99+
mainKey,
100+
oldKey,
101+
} {
102+
if value, ok := annotations[key]; ok {
103+
return []byte(value)
104+
}
105+
}
106+
107+
return nil
108+
}
109+
110+
// Construct a container name for log messages.
111+
func containerName(pod *api.PodSandbox, container *api.Container) string {
112+
if pod != nil {
113+
return pod.Name + "/" + container.Name
114+
}
115+
return container.Name
116+
}
117+
118+
// Start starts the NRI plugin
119+
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
120+
if len(nriSocketPath) == 0 {
121+
nriSocketPath = defaultNRISocket
122+
}
123+
_, err := os.Stat(nriSocketPath)
124+
if err != nil {
125+
return fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, err)
126+
}
127+
128+
var pluginOpts []nriplugin.Option
129+
pluginOpts = append(pluginOpts, nriplugin.WithPluginIdx(nriPluginIdx))
130+
pluginOpts = append(pluginOpts, nriplugin.WithSocketPath(nriSocketPath))
131+
if p.stub, err = nriplugin.New(p, pluginOpts...); err != nil {
132+
return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
133+
}
134+
err = p.stub.Start(ctx)
135+
if err != nil {
136+
return fmt.Errorf("plugin exited with error: %w", err)
137+
}
138+
return nil
139+
}
140+
141+
// Stop stops the NRI plugin
142+
func (p *Plugin) Stop() {
143+
if p != nil && p.stub != nil {
144+
p.stub.Stop()
145+
}
146+
}

cmd/nvidia-ctk-installer/container/runtime/runtime.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ const (
3434
// defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled
3535
defaultRuntimeName = "nvidia"
3636
defaultHostRootMount = "/host"
37+
defaultNRIPluginIdx = "10"
38+
defaultNRISocket = "/var/run/nri/nri.sock"
3739

3840
runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
3941
)
@@ -94,6 +96,27 @@ func Flags(opts *Options) []cli.Flag {
9496
Destination: &opts.EnableCDI,
9597
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
9698
},
99+
&cli.BoolFlag{
100+
Name: "enable-nri-in-runtime",
101+
Usage: "Enable NRI in the configured runtime",
102+
Destination: &opts.EnableNRI,
103+
Value: true,
104+
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
105+
},
106+
&cli.StringFlag{
107+
Name: "nri-plugin-index",
108+
Usage: "Specify the plugin index to register to NRI",
109+
Value: defaultNRIPluginIdx,
110+
Destination: &opts.NRIPluginIndex,
111+
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
112+
},
113+
&cli.StringFlag{
114+
Name: "nri-socket",
115+
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
116+
Value: defaultNRISocket,
117+
Destination: &opts.NRISocket,
118+
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
119+
},
97120
&cli.StringFlag{
98121
Name: "host-root",
99122
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",

cmd/nvidia-ctk-installer/main.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
"os/signal"
88
"path/filepath"
99
"syscall"
10+
"time"
1011

1112
"github.com/urfave/cli/v3"
1213
"golang.org/x/sys/unix"
1314

1415
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
16+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
1517
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
1618
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
1719
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@@ -26,6 +28,9 @@ const (
2628
toolkitSubDir = "toolkit"
2729

2830
defaultRuntime = "docker"
31+
32+
retryBackoff = 2 * time.Second
33+
maxRetryAttempts = 5
2934
)
3035

3136
var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
@@ -70,13 +75,15 @@ func main() {
7075
type app struct {
7176
logger logger.Interface
7277

73-
toolkit *toolkit.Installer
78+
nriPlugin *nri.Plugin
79+
toolkit *toolkit.Installer
7480
}
7581

7682
// NewApp creates the CLI app fro the specified options.
7783
func NewApp(logger logger.Interface) *cli.Command {
7884
a := app{
79-
logger: logger,
85+
logger: logger,
86+
nriPlugin: nri.NewPlugin(logger),
8087
}
8188
return a.build()
8289
}
@@ -93,8 +100,8 @@ func (a app) build() *cli.Command {
93100
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
94101
return ctx, a.Before(cmd, &options)
95102
},
96-
Action: func(_ context.Context, cmd *cli.Command) error {
97-
return a.Run(cmd, &options)
103+
Action: func(ctx context.Context, cmd *cli.Command) error {
104+
return a.Run(ctx, cmd, &options)
98105
},
99106
Flags: []cli.Flag{
100107
&cli.BoolFlag{
@@ -194,7 +201,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
194201
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
195202
// If the application is run as a daemon, the application waits and unconfigures
196203
// the runtime on termination.
197-
func (a *app) Run(c *cli.Command, o *options) error {
204+
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
198205
err := a.initialize(o.pidFile)
199206
if err != nil {
200207
return fmt.Errorf("unable to initialize: %v", err)
@@ -222,6 +229,11 @@ func (a *app) Run(c *cli.Command, o *options) error {
222229
}
223230

224231
if !o.noDaemon {
232+
if o.runtimeOptions.EnableNRI {
233+
if err = a.startNRIPluginServer(ctx, o.runtimeOptions); err != nil {
234+
a.logger.Errorf("unable to start NRI plugin server: %v", err)
235+
}
236+
}
225237
err = a.waitForSignal()
226238
if err != nil {
227239
return fmt.Errorf("unable to wait for signal: %v", err)
@@ -287,9 +299,38 @@ func (a *app) waitForSignal() error {
287299
return nil
288300
}
289301

302+
func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) error {
303+
a.logger.Infof("Starting the NRI Plugin server....")
304+
305+
retriable := func() error {
306+
return a.nriPlugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
307+
}
308+
var err error
309+
for i := 0; i < maxRetryAttempts; i++ {
310+
err = retriable()
311+
if err == nil {
312+
break
313+
}
314+
if i == maxRetryAttempts-1 {
315+
break
316+
}
317+
time.Sleep(retryBackoff)
318+
}
319+
if err != nil {
320+
a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
321+
return err
322+
}
323+
return nil
324+
}
325+
290326
func (a *app) shutdown(pidFile string) {
291327
a.logger.Infof("Shutting Down")
292328

329+
if a.nriPlugin != nil {
330+
a.logger.Infof("Stopping NRI plugin server...")
331+
a.nriPlugin.Stop()
332+
}
333+
293334
err := os.Remove(pidFile)
294335
if err != nil {
295336
a.logger.Warningf("Unable to remove pidfile: %v", err)

cmd/nvidia-ctk-installer/main_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ version = 2
444444
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
445445
"--restart-mode=none",
446446
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
447+
"--enable-nri-in-runtime=false",
447448
}
448449

449450
err := app.Run(context.Background(), append(testArgs, tc.args...))

go.mod

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.25.0
55
require (
66
github.com/NVIDIA/go-nvlib v0.8.1
77
github.com/NVIDIA/go-nvml v0.13.0-1
8+
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
89
github.com/google/uuid v1.6.0
910
github.com/moby/sys/mountinfo v0.7.2
1011
github.com/moby/sys/reexec v0.1.0
@@ -19,24 +20,31 @@ require (
1920
github.com/urfave/cli/v3 v3.6.1
2021
golang.org/x/mod v0.30.0
2122
golang.org/x/sys v0.38.0
23+
sigs.k8s.io/yaml v1.4.0
2224
tags.cncf.io/container-device-interface v1.0.2-0.20251114135136-1b24d969689f
2325
tags.cncf.io/container-device-interface/specs-go v1.0.0
2426
)
2527

2628
require (
2729
cyphar.com/go-pathrs v0.2.1 // indirect
30+
github.com/containerd/log v0.1.0 // indirect
31+
github.com/containerd/ttrpc v1.2.7 // indirect
2832
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
2933
github.com/davecgh/go-spew v1.1.1 // indirect
3034
github.com/fsnotify/fsnotify v1.7.0 // indirect
35+
github.com/golang/protobuf v1.5.3 // indirect
3136
github.com/hashicorp/errwrap v1.1.0 // indirect
32-
github.com/kr/pretty v0.3.1 // indirect
37+
github.com/knqyf263/go-plugin v0.9.0 // indirect
38+
github.com/kr/text v0.2.0 // indirect
3339
github.com/moby/sys/capability v0.4.0 // indirect
3440
github.com/opencontainers/cgroups v0.0.4 // indirect
3541
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
3642
github.com/pmezard/go-difflib v1.0.0 // indirect
3743
github.com/rogpeppe/go-internal v1.11.0 // indirect
44+
github.com/tetratelabs/wazero v1.9.0 // indirect
3845
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
39-
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
46+
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
47+
google.golang.org/grpc v1.57.1 // indirect
48+
google.golang.org/protobuf v1.36.5 // indirect
4049
gopkg.in/yaml.v3 v3.0.1 // indirect
41-
sigs.k8s.io/yaml v1.4.0 // indirect
4250
)

0 commit comments

Comments
 (0)