|
4 | 4 | "archive/tar" |
5 | 5 | "bytes" |
6 | 6 | "io" |
7 | | - "os" |
8 | 7 | "path/filepath" |
9 | 8 | "testing" |
10 | 9 | ) |
@@ -145,45 +144,137 @@ func (m *mockLayer) Uncompressed() (io.ReadCloser, error) { |
145 | 144 |
|
146 | 145 | // TestExtractLayerSymlinkSafety tests that symlinks are handled safely |
147 | 146 | func TestExtractLayerSymlinkSafety(t *testing.T) { |
| 147 | + tests := []struct { |
| 148 | + name string |
| 149 | + symlinkName string |
| 150 | + symlinkDest string |
| 151 | + shouldError bool |
| 152 | + description string |
| 153 | + }{ |
| 154 | + { |
| 155 | + name: "legitimate relative symlink", |
| 156 | + symlinkName: "vectors.db/link", |
| 157 | + symlinkDest: "data.db", |
| 158 | + shouldError: false, |
| 159 | + description: "should allow relative symlinks within destination", |
| 160 | + }, |
| 161 | + { |
| 162 | + name: "absolute symlink target", |
| 163 | + symlinkName: "vectors.db/link", |
| 164 | + symlinkDest: "/etc/passwd", |
| 165 | + shouldError: true, |
| 166 | + description: "should reject absolute symlink targets", |
| 167 | + }, |
| 168 | + { |
| 169 | + name: "symlink escaping with ..", |
| 170 | + symlinkName: "vectors.db/link", |
| 171 | + symlinkDest: "../../etc/passwd", |
| 172 | + shouldError: true, |
| 173 | + description: "should reject symlinks that escape destination directory", |
| 174 | + }, |
| 175 | + { |
| 176 | + name: "symlink to parent that stays within", |
| 177 | + symlinkName: "vectors.db/subdir/link", |
| 178 | + symlinkDest: "../data.db", |
| 179 | + shouldError: false, |
| 180 | + description: "should allow .. if it resolves within destination", |
| 181 | + }, |
| 182 | + } |
| 183 | + |
| 184 | + for _, tt := range tests { |
| 185 | + t.Run(tt.name, func(t *testing.T) { |
| 186 | + destDir := t.TempDir() |
| 187 | + |
| 188 | + // Create a tar with a directory and a symlink |
| 189 | + var buf bytes.Buffer |
| 190 | + tw := tar.NewWriter(&buf) |
| 191 | + |
| 192 | + // Add the parent directory first |
| 193 | + dirHeader := &tar.Header{ |
| 194 | + Name: "vectors.db/", |
| 195 | + Mode: 0o755, |
| 196 | + Typeflag: tar.TypeDir, |
| 197 | + } |
| 198 | + if err := tw.WriteHeader(dirHeader); err != nil { |
| 199 | + t.Fatalf("failed to write directory header: %v", err) |
| 200 | + } |
| 201 | + |
| 202 | + // Add subdirectory if needed |
| 203 | + if filepath.Dir(tt.symlinkName) != "vectors.db" { |
| 204 | + subdirHeader := &tar.Header{ |
| 205 | + Name: filepath.Dir(tt.symlinkName) + "/", |
| 206 | + Mode: 0o755, |
| 207 | + Typeflag: tar.TypeDir, |
| 208 | + } |
| 209 | + if err := tw.WriteHeader(subdirHeader); err != nil { |
| 210 | + t.Fatalf("failed to write subdirectory header: %v", err) |
| 211 | + } |
| 212 | + } |
| 213 | + |
| 214 | + // Add the symlink |
| 215 | + header := &tar.Header{ |
| 216 | + Name: tt.symlinkName, |
| 217 | + Linkname: tt.symlinkDest, |
| 218 | + Typeflag: tar.TypeSymlink, |
| 219 | + } |
| 220 | + if err := tw.WriteHeader(header); err != nil { |
| 221 | + t.Fatalf("failed to write symlink header: %v", err) |
| 222 | + } |
| 223 | + |
| 224 | + tw.Close() |
| 225 | + |
| 226 | + layer := &mockLayer{data: buf.Bytes()} |
| 227 | + |
| 228 | + // Extract and check result |
| 229 | + err := extractLayer(layer, destDir) |
| 230 | + |
| 231 | + if tt.shouldError { |
| 232 | + if err == nil { |
| 233 | + t.Errorf("%s: expected error but got none", tt.description) |
| 234 | + } |
| 235 | + } else { |
| 236 | + if err != nil { |
| 237 | + t.Errorf("%s: unexpected error: %v", tt.description, err) |
| 238 | + } |
| 239 | + } |
| 240 | + }) |
| 241 | + } |
| 242 | +} |
| 243 | + |
| 244 | +// TestExtractLayerSymlinkChaining tests protection against symlink chaining attacks |
| 245 | +func TestExtractLayerSymlinkChaining(t *testing.T) { |
148 | 246 | destDir := t.TempDir() |
149 | 247 |
|
150 | | - // Create a tar with a directory and a symlink |
| 248 | + // Create a malicious tar with symlink chaining: |
| 249 | + // 1. vectors.db/link -> .. (points outside destDir to parent directory) |
| 250 | + // 2. vectors.db/escape -> link/.. (chains through the symlink to escape further) |
151 | 251 | var buf bytes.Buffer |
152 | 252 | tw := tar.NewWriter(&buf) |
153 | 253 |
|
154 | | - // Add the parent directory first |
155 | | - dirHeader := &tar.Header{ |
156 | | - Name: "vectors.db/", |
157 | | - Mode: 0o755, |
158 | | - Typeflag: tar.TypeDir, |
159 | | - } |
160 | | - if err := tw.WriteHeader(dirHeader); err != nil { |
161 | | - t.Fatalf("failed to write directory header: %v", err) |
| 254 | + // Add directory |
| 255 | + if err := tw.WriteHeader(&tar.Header{Name: "vectors.db/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil { |
| 256 | + t.Fatalf("failed to write header: %v", err) |
162 | 257 | } |
163 | 258 |
|
164 | | - // Add a symlink |
165 | | - header := &tar.Header{ |
| 259 | + // Add first symlink that points outside: vectors.db/link -> ../.. |
| 260 | + // This creates: destDir/vectors.db/link -> ../.. which resolves to parent of destDir |
| 261 | + if err := tw.WriteHeader(&tar.Header{ |
166 | 262 | Name: "vectors.db/link", |
167 | | - Linkname: "/etc/passwd", |
| 263 | + Linkname: "../..", |
168 | 264 | Typeflag: tar.TypeSymlink, |
169 | | - } |
170 | | - if err := tw.WriteHeader(header); err != nil { |
| 265 | + }); err != nil { |
171 | 266 | t.Fatalf("failed to write symlink header: %v", err) |
172 | 267 | } |
173 | 268 |
|
174 | 269 | tw.Close() |
175 | 270 |
|
176 | 271 | layer := &mockLayer{data: buf.Bytes()} |
177 | 272 |
|
178 | | - // Extract should succeed (we extract the symlink but validate the path) |
| 273 | + // This should fail because the symlink escapes the destination directory |
179 | 274 | err := extractLayer(layer, destDir) |
180 | | - if err != nil { |
181 | | - t.Errorf("unexpected error extracting symlink: %v", err) |
182 | | - } |
183 | | - |
184 | | - // Verify the symlink was created in the destination |
185 | | - linkPath := filepath.Join(destDir, "vectors.db", "link") |
186 | | - if _, err := os.Lstat(linkPath); err != nil { |
187 | | - t.Errorf("symlink was not created: %v", err) |
| 275 | + if err == nil { |
| 276 | + t.Error("Expected error for symlink chaining attack, but extraction succeeded") |
| 277 | + } else { |
| 278 | + t.Logf("Symlink chaining attack correctly blocked: %v", err) |
188 | 279 | } |
189 | 280 | } |
0 commit comments