diff --git a/cmd/mcpRunner.go b/cmd/mcpRunner.go index 2a02102b2..ebe8c55bb 100644 --- a/cmd/mcpRunner.go +++ b/cmd/mcpRunner.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "os/exec" + "path" "strings" "sync" "time" @@ -324,6 +325,16 @@ func (s *mcpServer) invalidateHubMCPCache() { s.cachedHubMCP = nil } +// getBaseURL returns the hub API base URL by stripping /mcp from hubBaseURL. +// The hub URL is always the frontend URL + /api, and hubBaseURL is frontendURL/api/mcp. +// Ensures backend connection is established first. +func (s *mcpServer) getBaseURL() (string, error) { + if errMsg := s.ensureBackendConnection(); errMsg != "" { + return "", fmt.Errorf("%s", errMsg) + } + return strings.TrimSuffix(s.hubBaseURL, "/mcp"), nil +} + func writeErrorToStderr(format string, args ...any) { fmt.Fprintf(os.Stderr, format+"\n", args...) } @@ -379,6 +390,14 @@ func (s *mcpServer) handleRequest(req *jsonRPCRequest) { func (s *mcpServer) handleInitialize(req *jsonRPCRequest) { var instructions string + fileDownloadInstructions := ` + +Downloading files (e.g., PCAP exports): +When a tool like export_snapshot_pcap returns a relative file path, you MUST use the file tools to retrieve the file: +- get_file_url: Resolves the relative path to a full download URL you can share with the user. +- download_file: Downloads the file to the local filesystem so it can be opened or analyzed. +Typical workflow: call export_snapshot_pcap → receive a relative path → call download_file with that path → share the local file path with the user.` + if s.urlMode { instructions = fmt.Sprintf(`Kubeshark MCP Server - Connected to: %s @@ -392,7 +411,7 @@ Available tools for traffic analysis: - get_api_stats: Get aggregated API statistics - And more - use tools/list to see all available tools -Use the MCP tools directly - do NOT use kubectl or curl to access Kubeshark.`, s.directURL) +Use the MCP tools directly - do NOT use kubectl or curl to access Kubeshark.`, s.directURL) + fileDownloadInstructions } else if s.allowDestructive { instructions = `Kubeshark MCP Server - Proxy Mode (Destructive Operations ENABLED) @@ -410,7 +429,7 @@ Safe operations: Traffic analysis tools (require Kubeshark to be running): - list_workloads, list_api_calls, list_l4_flows, get_api_stats, and more -Use the MCP tools - do NOT use kubectl, helm, or curl directly.` +Use the MCP tools - do NOT use kubectl, helm, or curl directly.` + fileDownloadInstructions } else { instructions = `Kubeshark MCP Server - Proxy Mode (Read-Only) @@ -425,7 +444,7 @@ Available operations: Traffic analysis tools (require Kubeshark to be running): - list_workloads, list_api_calls, list_l4_flows, get_api_stats, and more -Use the MCP tools - do NOT use kubectl, helm, or curl directly.` +Use the MCP tools - do NOT use kubectl, helm, or curl directly.` + fileDownloadInstructions } result := mcpInitializeResult{ @@ -456,6 +475,40 @@ func (s *mcpServer) handleListTools(req *jsonRPCRequest) { }`), }) + // Add file URL and download tools - available in all modes + tools = append(tools, mcpTool{ + Name: "get_file_url", + Description: "When a tool (e.g., export_snapshot_pcap) returns a relative file path, use this tool to resolve it into a fully-qualified download URL. The URL can be shared with the user for manual download.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The relative file path returned by a Hub tool (e.g., '/snapshots/abc/data.pcap')" + } + }, + "required": ["path"] + }`), + }) + tools = append(tools, mcpTool{ + Name: "download_file", + Description: "When a tool (e.g., export_snapshot_pcap) returns a relative file path, use this tool to download the file to the local filesystem. This is the preferred way to retrieve PCAP exports and other files from Kubeshark.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The relative file path returned by a Hub tool (e.g., '/snapshots/abc/data.pcap')" + }, + "dest": { + "type": "string", + "description": "Local destination file path. If not provided, uses the filename from the path in the current directory." + } + }, + "required": ["path"] + }`), + }) + // Add destructive tools only if --allow-destructive flag was set (and not in URL mode) if !s.urlMode && s.allowDestructive { tools = append(tools, mcpTool{ @@ -653,6 +706,20 @@ func (s *mcpServer) handleCallTool(req *jsonRPCRequest) { IsError: isError, }) return + case "get_file_url": + result, isError = s.callGetFileURL(params.Arguments) + s.sendResult(req.ID, mcpCallToolResult{ + Content: []mcpContent{{Type: "text", Text: result}}, + IsError: isError, + }) + return + case "download_file": + result, isError = s.callDownloadFile(params.Arguments) + s.sendResult(req.ID, mcpCallToolResult{ + Content: []mcpContent{{Type: "text", Text: result}}, + IsError: isError, + }) + return } // Forward Hub tools to the API @@ -706,6 +773,91 @@ func (s *mcpServer) callHubTool(toolName string, args map[string]any) (string, b } +func (s *mcpServer) callGetFileURL(args map[string]any) (string, bool) { + filePath, _ := args["path"].(string) + if filePath == "" { + return "Error: 'path' parameter is required", true + } + + baseURL, err := s.getBaseURL() + if err != nil { + return fmt.Sprintf("Error: %v", err), true + } + + // Ensure path starts with / + if !strings.HasPrefix(filePath, "/") { + filePath = "/" + filePath + } + + fullURL := strings.TrimSuffix(baseURL, "/") + filePath + return fullURL, false +} + +func (s *mcpServer) callDownloadFile(args map[string]any) (string, bool) { + filePath, _ := args["path"].(string) + if filePath == "" { + return "Error: 'path' parameter is required", true + } + + baseURL, err := s.getBaseURL() + if err != nil { + return fmt.Sprintf("Error: %v", err), true + } + + // Ensure path starts with / + if !strings.HasPrefix(filePath, "/") { + filePath = "/" + filePath + } + + fullURL := strings.TrimSuffix(baseURL, "/") + filePath + + // Determine destination file path + dest, _ := args["dest"].(string) + if dest == "" { + dest = path.Base(filePath) + } + + // Use a dedicated HTTP client for file downloads. + // The default s.httpClient has a 30s total timeout which would fail for large files (up to 10GB). + // This client sets only connection-level timeouts and lets the body stream without a deadline. + downloadClient := &http.Client{ + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + }, + } + + resp, err := downloadClient.Get(fullURL) + if err != nil { + return fmt.Sprintf("Error downloading file: %v", err), true + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + return fmt.Sprintf("Error downloading file: HTTP %d", resp.StatusCode), true + } + + // Write to destination + outFile, err := os.Create(dest) + if err != nil { + return fmt.Sprintf("Error creating file %s: %v", dest, err), true + } + defer func() { _ = outFile.Close() }() + + written, err := io.Copy(outFile, resp.Body) + if err != nil { + return fmt.Sprintf("Error writing file %s: %v", dest, err), true + } + + result := map[string]any{ + "url": fullURL, + "path": dest, + "size": written, + } + resultBytes, _ := json.MarshalIndent(result, "", " ") + return string(resultBytes), false +} + func (s *mcpServer) callStartKubeshark(args map[string]any) (string, bool) { // Build the kubeshark tap command cmdArgs := []string{"tap"} @@ -913,6 +1065,11 @@ func listMCPTools(directURL string) { fmt.Printf("URL Mode: %s\n\n", directURL) fmt.Println("Cluster management tools disabled (Kubeshark managed externally)") fmt.Println() + fmt.Println("Local Tools:") + fmt.Println(" check_kubeshark_status Check if Kubeshark is running") + fmt.Println(" get_file_url Resolve a relative path to a full download URL") + fmt.Println(" download_file Download a file from Kubeshark to local disk") + fmt.Println() hubURL := strings.TrimSuffix(directURL, "/") + "/api/mcp" fetchAndDisplayTools(hubURL, 30*time.Second) @@ -925,6 +1082,10 @@ func listMCPTools(directURL string) { fmt.Println(" start_kubeshark Start Kubeshark to capture traffic") fmt.Println(" stop_kubeshark Stop Kubeshark and clean up resources") fmt.Println() + fmt.Println("File Tools:") + fmt.Println(" get_file_url Resolve a relative path to a full download URL") + fmt.Println(" download_file Download a file from Kubeshark to local disk") + fmt.Println() // Establish proxy connection to Kubeshark fmt.Println("Connecting to Kubeshark...") diff --git a/cmd/mcp_test.go b/cmd/mcp_test.go index d7aad1367..62096ed3c 100644 --- a/cmd/mcp_test.go +++ b/cmd/mcp_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" ) @@ -126,8 +128,18 @@ func TestMCP_ToolsList_CLIOnly(t *testing.T) { t.Fatalf("Unexpected error: %v", resp.Error) } tools := resp.Result.(map[string]any)["tools"].([]any) - if len(tools) != 1 || tools[0].(map[string]any)["name"] != "check_kubeshark_status" { - t.Error("Expected only check_kubeshark_status tool") + // Should have check_kubeshark_status + get_file_url + download_file = 3 tools + if len(tools) != 3 { + t.Errorf("Expected 3 tools, got %d", len(tools)) + } + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.(map[string]any)["name"].(string)] = true + } + for _, expected := range []string{"check_kubeshark_status", "get_file_url", "download_file"} { + if !toolNames[expected] { + t.Errorf("Missing expected tool: %s", expected) + } } } @@ -163,9 +175,9 @@ func TestMCP_ToolsList_WithHubBackend(t *testing.T) { t.Fatalf("Unexpected error: %v", resp.Error) } tools := resp.Result.(map[string]any)["tools"].([]any) - // Should have CLI tools (3) + Hub tools (2) = 5 tools - if len(tools) < 5 { - t.Errorf("Expected at least 5 tools, got %d", len(tools)) + // Should have CLI tools (3) + file tools (2) + Hub tools (2) = 7 tools + if len(tools) < 7 { + t.Errorf("Expected at least 7 tools, got %d", len(tools)) } } @@ -463,6 +475,187 @@ func TestMCP_BackendInitialization_Concurrent(t *testing.T) { } } +func TestMCP_GetFileURL_ProxyMode(t *testing.T) { + s := &mcpServer{ + httpClient: &http.Client{}, + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + hubBaseURL: "http://127.0.0.1:8899/api/mcp", + backendInitialized: true, + } + resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{ + Name: "get_file_url", + Arguments: map[string]any{"path": "/snapshots/abc/data.pcap"}, + })) + if resp.Error != nil { + t.Fatalf("Unexpected error: %v", resp.Error) + } + text := resp.Result.(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string) + expected := "http://127.0.0.1:8899/api/snapshots/abc/data.pcap" + if text != expected { + t.Errorf("Expected %q, got %q", expected, text) + } +} + +func TestMCP_GetFileURL_URLMode(t *testing.T) { + s := &mcpServer{ + httpClient: &http.Client{}, + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + hubBaseURL: "https://kubeshark.example.com/api/mcp", + backendInitialized: true, + urlMode: true, + directURL: "https://kubeshark.example.com", + } + resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{ + Name: "get_file_url", + Arguments: map[string]any{"path": "/snapshots/xyz/export.pcap"}, + })) + if resp.Error != nil { + t.Fatalf("Unexpected error: %v", resp.Error) + } + text := resp.Result.(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string) + expected := "https://kubeshark.example.com/api/snapshots/xyz/export.pcap" + if text != expected { + t.Errorf("Expected %q, got %q", expected, text) + } +} + +func TestMCP_GetFileURL_MissingPath(t *testing.T) { + s := &mcpServer{ + httpClient: &http.Client{}, + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + hubBaseURL: "http://127.0.0.1:8899/api/mcp", + backendInitialized: true, + } + resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{ + Name: "get_file_url", + Arguments: map[string]any{}, + })) + result := resp.Result.(map[string]any) + if !result["isError"].(bool) { + t.Error("Expected isError=true when path is missing") + } + text := result["content"].([]any)[0].(map[string]any)["text"].(string) + if !strings.Contains(text, "path") { + t.Error("Error message should mention 'path'") + } +} + +func TestMCP_DownloadFile(t *testing.T) { + fileContent := "test pcap data content" + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/snapshots/abc/data.pcap" { + _, _ = w.Write([]byte(fileContent)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer mockServer.Close() + + // Use temp dir for download destination + tmpDir := t.TempDir() + dest := filepath.Join(tmpDir, "downloaded.pcap") + + s := &mcpServer{ + httpClient: &http.Client{}, + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + hubBaseURL: mockServer.URL + "/api/mcp", + backendInitialized: true, + } + resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{ + Name: "download_file", + Arguments: map[string]any{"path": "/snapshots/abc/data.pcap", "dest": dest}, + })) + if resp.Error != nil { + t.Fatalf("Unexpected error: %v", resp.Error) + } + result := resp.Result.(map[string]any) + if result["isError"] != nil && result["isError"].(bool) { + t.Fatalf("Expected no error, got: %v", result["content"]) + } + + text := result["content"].([]any)[0].(map[string]any)["text"].(string) + var downloadResult map[string]any + if err := json.Unmarshal([]byte(text), &downloadResult); err != nil { + t.Fatalf("Failed to parse download result JSON: %v", err) + } + if downloadResult["path"] != dest { + t.Errorf("Expected path %q, got %q", dest, downloadResult["path"]) + } + if downloadResult["size"].(float64) != float64(len(fileContent)) { + t.Errorf("Expected size %d, got %v", len(fileContent), downloadResult["size"]) + } + + // Verify the file was actually written + content, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + if string(content) != fileContent { + t.Errorf("Expected file content %q, got %q", fileContent, string(content)) + } +} + +func TestMCP_DownloadFile_CustomDest(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("data")) + })) + defer mockServer.Close() + + tmpDir := t.TempDir() + customDest := filepath.Join(tmpDir, "custom-name.pcap") + + s := &mcpServer{ + httpClient: &http.Client{}, + stdin: &bytes.Buffer{}, + stdout: &bytes.Buffer{}, + hubBaseURL: mockServer.URL + "/api/mcp", + backendInitialized: true, + } + resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{ + Name: "download_file", + Arguments: map[string]any{"path": "/snapshots/abc/export.pcap", "dest": customDest}, + })) + result := resp.Result.(map[string]any) + if result["isError"] != nil && result["isError"].(bool) { + t.Fatalf("Expected no error, got: %v", result["content"]) + } + + text := result["content"].([]any)[0].(map[string]any)["text"].(string) + var downloadResult map[string]any + if err := json.Unmarshal([]byte(text), &downloadResult); err != nil { + t.Fatalf("Failed to parse download result JSON: %v", err) + } + if downloadResult["path"] != customDest { + t.Errorf("Expected path %q, got %q", customDest, downloadResult["path"]) + } + + if _, err := os.Stat(customDest); os.IsNotExist(err) { + t.Error("Expected file to exist at custom destination") + } +} + +func TestMCP_ToolsList_IncludesFileTools(t *testing.T) { + s := newTestMCPServer() + resp := parseResponse(t, sendRequest(s, "tools/list", 1, nil)) + if resp.Error != nil { + t.Fatalf("Unexpected error: %v", resp.Error) + } + tools := resp.Result.(map[string]any)["tools"].([]any) + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.(map[string]any)["name"].(string)] = true + } + for _, expected := range []string{"get_file_url", "download_file"} { + if !toolNames[expected] { + t.Errorf("Missing expected tool: %s", expected) + } + } +} + func TestMCP_FullConversation(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" {