Skip to content

Commit 575b731

Browse files
update main packages in embeddings examples
1 parent 6ca511a commit 575b731

File tree

7 files changed

+167
-41
lines changed

7 files changed

+167
-41
lines changed

examples/embeddings/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Downloads the embeddings OCI artifact and installs the vector.db directory to `~
1010

1111
```bash
1212
# From repository root
13-
go run ./examples/embeddings/pull.go
13+
go run ./examples/embeddings/pull/main.go
1414
```
1515

1616
The Pull function will:
@@ -28,14 +28,14 @@ Creates an OCI artifact from a local vector.db directory and pushes it to a regi
2828

2929
```bash
3030
# From repository root
31-
go run ./examples/embeddings/push.go <vector-db-path> <oci-ref>
31+
go run ./examples/embeddings/push/main.go <vector-db-path> <oci-ref>
3232
```
3333

3434
### Example
3535

3636
```bash
3737
# Push the local vectors.db to your own registry
38-
go run ./examples/embeddings/push.go ~/.docker/mcp/vectors.db jimclark106/embeddings:v1.0
38+
go run ./examples/embeddings/push/main.go ~/.docker/mcp/vectors.db jimclark106/embeddings:v1.0
3939
```
4040

4141
The Push function will:
File renamed without changes.
File renamed without changes.

pkg/gateway/embeddings/oci.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,27 @@ func extractLayer(layer interface{ Uncompressed() (io.ReadCloser, error) }, dest
130130
}
131131

132132
target := filepath.Join(destDir, header.Name)
133-
// Clean the path to resolve any ".." elements and ensure it's within destDir
134-
cleanedTarget := filepath.Clean(target)
133+
134+
// Resolve any previously-extracted symbolic links in the target path
135+
// This prevents symlink chaining attacks where a symlink could be used
136+
// to escape the destination directory
137+
resolvedTarget, err := filepath.EvalSymlinks(target)
138+
if err != nil {
139+
// If EvalSymlinks fails (e.g., path doesn't exist yet), fall back to Clean
140+
// This is expected for new files/dirs that haven't been created yet
141+
resolvedTarget = filepath.Clean(target)
142+
}
143+
135144
cleanedDestDir := filepath.Clean(destDir)
136145

137-
// Use filepath.Rel to check if target is within destDir
146+
// Use filepath.Rel to check if resolved target is within destDir
138147
// If the relative path starts with "..", it's trying to escape
139-
relPath, err := filepath.Rel(cleanedDestDir, cleanedTarget)
148+
relPath, err := filepath.Rel(cleanedDestDir, resolvedTarget)
140149
if err != nil || len(relPath) == 0 || (relPath[0] == '.' && len(relPath) > 1 && relPath[1] == '.') {
141150
return fmt.Errorf("invalid tar entry path (potential path traversal): %s", header.Name)
142151
}
143152

144-
target = cleanedTarget
153+
target = filepath.Clean(target)
145154

146155
switch header.Typeflag {
147156
case tar.TypeDir:
@@ -167,7 +176,30 @@ func extractLayer(layer interface{ Uncompressed() (io.ReadCloser, error) }, dest
167176
file.Close()
168177

169178
case tar.TypeSymlink:
170-
// Handle symlinks
179+
// Handle symlinks - validate the link target to prevent symlink attacks
180+
// Reject absolute symlink targets
181+
if filepath.IsAbs(header.Linkname) {
182+
return fmt.Errorf("invalid symlink target (absolute path not allowed): %s -> %s", header.Name, header.Linkname)
183+
}
184+
185+
// Calculate where the symlink target would resolve to
186+
// The symlink is created at 'target', and points to 'header.Linkname'
187+
linkTargetPath := filepath.Join(filepath.Dir(target), header.Linkname)
188+
189+
// Resolve any previously-extracted symbolic links in the target path
190+
// This prevents symlink chaining attacks
191+
resolvedLinkTarget, err := filepath.EvalSymlinks(linkTargetPath)
192+
if err != nil {
193+
// If EvalSymlinks fails, fall back to Clean (target doesn't exist yet)
194+
resolvedLinkTarget = filepath.Clean(linkTargetPath)
195+
}
196+
197+
// Ensure the symlink target is within the destination directory
198+
relLinkPath, err := filepath.Rel(cleanedDestDir, resolvedLinkTarget)
199+
if err != nil || len(relLinkPath) == 0 || (relLinkPath[0] == '.' && len(relLinkPath) > 1 && relLinkPath[1] == '.') {
200+
return fmt.Errorf("invalid symlink target (potential path traversal): %s -> %s", header.Name, header.Linkname)
201+
}
202+
171203
if err := os.Symlink(header.Linkname, target); err != nil {
172204
return fmt.Errorf("failed to create symlink: %w", err)
173205
}

pkg/gateway/embeddings/oci_test.go

Lines changed: 115 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"archive/tar"
55
"bytes"
66
"io"
7-
"os"
87
"path/filepath"
98
"testing"
109
)
@@ -145,45 +144,137 @@ func (m *mockLayer) Uncompressed() (io.ReadCloser, error) {
145144

146145
// TestExtractLayerSymlinkSafety tests that symlinks are handled safely
147146
func TestExtractLayerSymlinkSafety(t *testing.T) {
147+
tests := []struct {
148+
name string
149+
symlinkName string
150+
symlinkDest string
151+
shouldError bool
152+
description string
153+
}{
154+
{
155+
name: "legitimate relative symlink",
156+
symlinkName: "vectors.db/link",
157+
symlinkDest: "data.db",
158+
shouldError: false,
159+
description: "should allow relative symlinks within destination",
160+
},
161+
{
162+
name: "absolute symlink target",
163+
symlinkName: "vectors.db/link",
164+
symlinkDest: "/etc/passwd",
165+
shouldError: true,
166+
description: "should reject absolute symlink targets",
167+
},
168+
{
169+
name: "symlink escaping with ..",
170+
symlinkName: "vectors.db/link",
171+
symlinkDest: "../../etc/passwd",
172+
shouldError: true,
173+
description: "should reject symlinks that escape destination directory",
174+
},
175+
{
176+
name: "symlink to parent that stays within",
177+
symlinkName: "vectors.db/subdir/link",
178+
symlinkDest: "../data.db",
179+
shouldError: false,
180+
description: "should allow .. if it resolves within destination",
181+
},
182+
}
183+
184+
for _, tt := range tests {
185+
t.Run(tt.name, func(t *testing.T) {
186+
destDir := t.TempDir()
187+
188+
// Create a tar with a directory and a symlink
189+
var buf bytes.Buffer
190+
tw := tar.NewWriter(&buf)
191+
192+
// Add the parent directory first
193+
dirHeader := &tar.Header{
194+
Name: "vectors.db/",
195+
Mode: 0o755,
196+
Typeflag: tar.TypeDir,
197+
}
198+
if err := tw.WriteHeader(dirHeader); err != nil {
199+
t.Fatalf("failed to write directory header: %v", err)
200+
}
201+
202+
// Add subdirectory if needed
203+
if filepath.Dir(tt.symlinkName) != "vectors.db" {
204+
subdirHeader := &tar.Header{
205+
Name: filepath.Dir(tt.symlinkName) + "/",
206+
Mode: 0o755,
207+
Typeflag: tar.TypeDir,
208+
}
209+
if err := tw.WriteHeader(subdirHeader); err != nil {
210+
t.Fatalf("failed to write subdirectory header: %v", err)
211+
}
212+
}
213+
214+
// Add the symlink
215+
header := &tar.Header{
216+
Name: tt.symlinkName,
217+
Linkname: tt.symlinkDest,
218+
Typeflag: tar.TypeSymlink,
219+
}
220+
if err := tw.WriteHeader(header); err != nil {
221+
t.Fatalf("failed to write symlink header: %v", err)
222+
}
223+
224+
tw.Close()
225+
226+
layer := &mockLayer{data: buf.Bytes()}
227+
228+
// Extract and check result
229+
err := extractLayer(layer, destDir)
230+
231+
if tt.shouldError {
232+
if err == nil {
233+
t.Errorf("%s: expected error but got none", tt.description)
234+
}
235+
} else {
236+
if err != nil {
237+
t.Errorf("%s: unexpected error: %v", tt.description, err)
238+
}
239+
}
240+
})
241+
}
242+
}
243+
244+
// TestExtractLayerSymlinkChaining tests protection against symlink chaining attacks
245+
func TestExtractLayerSymlinkChaining(t *testing.T) {
148246
destDir := t.TempDir()
149247

150-
// Create a tar with a directory and a symlink
248+
// Create a malicious tar with symlink chaining:
249+
// 1. vectors.db/link -> ../.. (points outside destDir, to its parent directory)
250+
// 2. vectors.db/escape -> link/.. (would chain through the symlink to escape further; this entry is not written to the archive because extraction must already fail at step 1)
151251
var buf bytes.Buffer
152252
tw := tar.NewWriter(&buf)
153253

154-
// Add the parent directory first
155-
dirHeader := &tar.Header{
156-
Name: "vectors.db/",
157-
Mode: 0o755,
158-
Typeflag: tar.TypeDir,
159-
}
160-
if err := tw.WriteHeader(dirHeader); err != nil {
161-
t.Fatalf("failed to write directory header: %v", err)
254+
// Add directory
255+
if err := tw.WriteHeader(&tar.Header{Name: "vectors.db/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil {
256+
t.Fatalf("failed to write header: %v", err)
162257
}
163258

164-
// Add a symlink
165-
header := &tar.Header{
259+
// Add first symlink that points outside: vectors.db/link -> ../..
260+
// This creates: destDir/vectors.db/link -> ../.. which resolves to parent of destDir
261+
if err := tw.WriteHeader(&tar.Header{
166262
Name: "vectors.db/link",
167-
Linkname: "/etc/passwd",
263+
Linkname: "../..",
168264
Typeflag: tar.TypeSymlink,
169-
}
170-
if err := tw.WriteHeader(header); err != nil {
265+
}); err != nil {
171266
t.Fatalf("failed to write symlink header: %v", err)
172267
}
173268

174269
tw.Close()
175270

176271
layer := &mockLayer{data: buf.Bytes()}
177272

178-
// Extract should succeed (we extract the symlink but validate the path)
273+
// This should fail because the symlink escapes the destination directory
179274
err := extractLayer(layer, destDir)
180-
if err != nil {
181-
t.Errorf("unexpected error extracting symlink: %v", err)
182-
}
183-
184-
// Verify the symlink was created in the destination
185-
linkPath := filepath.Join(destDir, "vectors.db", "link")
186-
if _, err := os.Lstat(linkPath); err != nil {
187-
t.Errorf("symlink was not created: %v", err)
275+
if err == nil {
276+
t.Error("Expected error for symlink chaining attack, but extraction succeeded")
277+
} else {
278+
t.Logf("Symlink chaining attack correctly blocked: %v", err)
188279
}
189280
}

pkg/gateway/findmcps.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func keywordStrategy(configuration Configuration) mcp.ToolHandler {
3232
return func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
3333
// Parse parameters
3434
var params struct {
35-
Prompt string `json:"prompt"`
35+
Query string `json:"query"`
3636
Limit int `json:"limit"`
3737
}
3838

@@ -49,7 +49,7 @@ func keywordStrategy(configuration Configuration) mcp.ToolHandler {
4949
return nil, fmt.Errorf("failed to parse arguments: %w", err)
5050
}
5151

52-
if params.Prompt == "" {
52+
if params.Query == "" {
5353
return nil, fmt.Errorf("query parameter is required")
5454
}
5555

@@ -58,7 +58,7 @@ func keywordStrategy(configuration Configuration) mcp.ToolHandler {
5858
}
5959

6060
// Search through the catalog servers
61-
query := strings.ToLower(strings.TrimSpace(params.Prompt))
61+
query := strings.ToLower(strings.TrimSpace(params.Query))
6262
var matches []ServerMatch
6363

6464
for serverName, server := range configuration.servers {
@@ -177,7 +177,7 @@ func keywordStrategy(configuration Configuration) mcp.ToolHandler {
177177
}
178178

179179
response := map[string]any{
180-
"prompt": params.Prompt,
180+
"prompt": params.Query,
181181
"total_matches": len(results),
182182
"servers": results,
183183
}
@@ -197,7 +197,7 @@ func embeddingStrategy(g *Gateway) mcp.ToolHandler {
197197
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
198198
// Parse parameters
199199
var params struct {
200-
Prompt string `json:"prompt"`
200+
Query string `json:"query"`
201201
Limit int `json:"limit"`
202202
}
203203

@@ -214,7 +214,7 @@ func embeddingStrategy(g *Gateway) mcp.ToolHandler {
214214
return nil, fmt.Errorf("failed to parse arguments: %w", err)
215215
}
216216

217-
if params.Prompt == "" {
217+
if params.Query == "" {
218218
return nil, fmt.Errorf("query parameter is required")
219219
}
220220

@@ -223,13 +223,13 @@ func embeddingStrategy(g *Gateway) mcp.ToolHandler {
223223
}
224224

225225
// Use vector similarity search to find relevant servers
226-
results, err := g.findServersByEmbedding(ctx, params.Prompt, params.Limit)
226+
results, err := g.findServersByEmbedding(ctx, params.Query, params.Limit)
227227
if err != nil {
228228
return nil, fmt.Errorf("failed to find servers: %w", err)
229229
}
230230

231231
response := map[string]any{
232-
"prompt": params.Prompt,
232+
"prompt": params.Query,
233233
"total_matches": len(results),
234234
"servers": results,
235235
}

pkg/gateway/registry.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
)
1717

1818
// readServersFromURL fetches and parses server definitions from a URL
19+
//
20+
//nolint:unused // TODO: This function will be used when registry import feature is enabled
1921
func (g *Gateway) readServersFromURL(ctx context.Context, url string) (map[string]catalog.Server, error) {
2022
servers := make(map[string]catalog.Server)
2123

@@ -63,6 +65,7 @@ func (g *Gateway) readServersFromURL(ctx context.Context, url string) (map[strin
6365
return nil, fmt.Errorf("unable to parse response as OCI catalog or direct catalog format")
6466
}
6567

68+
//nolint:unused // TODO: This handler will be used when registry import feature is enabled
6669
func registryImportHandler(g *Gateway, configuration Configuration) mcp.ToolHandler {
6770
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
6871
// Parse parameters

0 commit comments

Comments
 (0)