1919package shim
2020
2121import (
22+ "bufio"
2223 "context"
2324 "crypto/sha256"
2425 "fmt"
26+ "io"
27+ "math"
2528 "net"
2629 "os"
2730 "path/filepath"
2831 "runtime"
32+ "strconv"
2933 "strings"
3034 "syscall"
3135 "time"
3236
37+ "github.com/containerd/log"
38+ "github.com/mdlayher/vsock"
39+
3340 "github.com/containerd/containerd/v2/defaults"
3441 "github.com/containerd/containerd/v2/pkg/namespaces"
3542 "github.com/containerd/containerd/v2/pkg/sys"
@@ -38,6 +45,9 @@ import (
3845const (
3946 shimBinaryFormat = "containerd-shim-%s-%s"
4047 socketPathLimit = 106
48+ protoVsock = "vsock"
49+ protoHybridVsock = "hvsock"
50+ protoUnix = "unix"
4151)
4252
4353func getSysProcAttr () * syscall.SysProcAttr {
@@ -76,7 +86,21 @@ func SocketAddress(ctx context.Context, socketPath, id string) (string, error) {
7686
7787// AnonDialer returns a dialer for a socket
7888func AnonDialer (address string , timeout time.Duration ) (net.Conn , error ) {
79- return net .DialTimeout ("unix" , socket (address ).path (), timeout )
89+ proto , addr , ok := strings .Cut (address , "://" )
90+ if ! ok {
91+ return net .DialTimeout ("unix" , socket (address ).path (), timeout )
92+ }
93+ switch proto {
94+ case protoVsock :
95+ // vsock dialer can not set timeout
96+ return dialVsock (addr )
97+ case protoHybridVsock :
98+ return dialHybridVsock (addr , timeout )
99+ case protoUnix :
100+ return net .DialTimeout ("unix" , socket (address ).path (), timeout )
101+ default :
102+ return nil , fmt .Errorf ("unsupported protocol: %s" , proto )
103+ }
80104}
81105
82106// AnonReconnectDialer returns a dialer for an existing socket on reconnection
@@ -177,3 +201,88 @@ func CanConnect(address string) bool {
177201 conn .Close ()
178202 return true
179203}
204+
205+ func hybridVsockDialer (addr string , port uint64 , timeout time.Duration ) (net.Conn , error ) {
206+ timeoutCh := time .After (timeout )
207+ // Do 10 retries before timeout
208+ retryInterval := timeout / 10
209+ for {
210+ conn , err := net .DialTimeout ("unix" , addr , timeout )
211+ if err != nil {
212+ return nil , err
213+ }
214+ if _ , err = conn .Write ([]byte (fmt .Sprintf ("CONNECT %d\n " , port ))); err != nil {
215+ conn .Close ()
216+ return nil , err
217+ }
218+ errChan := make (chan error , 1 )
219+ go func () {
220+ reader := bufio .NewReader (conn )
221+ response , err := reader .ReadString ('\n' )
222+ if err != nil {
223+ errChan <- err
224+ return
225+ }
226+ if strings .Contains (response , "OK" ) {
227+ errChan <- nil
228+ } else {
229+ errChan <- fmt .Errorf ("hybrid vsock handshake response error: %s" , response )
230+ }
231+ }()
232+ select {
233+ case err = <- errChan :
234+ if err != nil {
235+ conn .Close ()
236+ // When it is EOF, maybe the server side is not ready.
237+ if err == io .EOF {
238+ log .G (context .Background ()).Warnf ("Read hybrid vsock got EOF, server may not ready" )
239+ time .Sleep (retryInterval )
240+ continue
241+ }
242+ return nil , err
243+ }
244+ return conn , nil
245+ case <- timeoutCh :
246+ conn .Close ()
247+ return nil , fmt .Errorf ("timeout waiting for hybrid vsocket handshake of %s:%d" , addr , port )
248+ }
249+ }
250+
251+ }
252+
253+ func dialVsock (address string ) (net.Conn , error ) {
254+ contextIDString , portString , ok := strings .Cut (address , ":" )
255+ if ! ok {
256+ return nil , fmt .Errorf ("invalid vsock address %s" , address )
257+ }
258+ contextID , err := strconv .ParseUint (contextIDString , 10 , 0 )
259+ if err != nil {
260+ return nil , fmt .Errorf ("failed to parse vsock context id %s, %v" , contextIDString , err )
261+ }
262+ if contextID > math .MaxUint32 {
263+ return nil , fmt .Errorf ("vsock context id %d is invalid" , contextID )
264+ }
265+ port , err := strconv .ParseUint (portString , 10 , 0 )
266+ if err != nil {
267+ return nil , fmt .Errorf ("failed to parse vsock port %s, %v" , portString , err )
268+ }
269+ if port > math .MaxUint32 {
270+ return nil , fmt .Errorf ("vsock port %d is invalid" , port )
271+ }
272+ return vsock .Dial (uint32 (contextID ), uint32 (port ), & vsock.Config {})
273+ }
274+
275+ func dialHybridVsock (address string , timeout time.Duration ) (net.Conn , error ) {
276+ addr , portString , ok := strings .Cut (address , ":" )
277+ if ! ok {
278+ return nil , fmt .Errorf ("invalid hybrid vsock address %s" , address )
279+ }
280+ port , err := strconv .ParseUint (portString , 10 , 0 )
281+ if err != nil {
282+ return nil , fmt .Errorf ("failed to parse hybrid vsock port %s, %v" , portString , err )
283+ }
284+ if port > math .MaxUint32 {
285+ return nil , fmt .Errorf ("hybrid vsock port %d is invalid" , port )
286+ }
287+ return hybridVsockDialer (addr , port , timeout )
288+ }
0 commit comments