mirror of
https://github.com/kubeshark/kubeshark.git
synced 2026-03-04 11:42:14 +00:00
Compare commits
1 Commits
master
...
feature/mc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0014b5ad8c |
158
cmd/mcpRunner.go
158
cmd/mcpRunner.go
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -324,6 +325,16 @@ func (s *mcpServer) invalidateHubMCPCache() {
|
||||
s.cachedHubMCP = nil
|
||||
}
|
||||
|
||||
// getBaseURL returns the hub API base URL by stripping /mcp from hubBaseURL.
|
||||
// The hub URL is always the frontend URL + /api, and hubBaseURL is frontendURL/api/mcp.
|
||||
// Ensures backend connection is established first.
|
||||
func (s *mcpServer) getBaseURL() (string, error) {
|
||||
if errMsg := s.ensureBackendConnection(); errMsg != "" {
|
||||
return "", fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
return strings.TrimSuffix(s.hubBaseURL, "/mcp"), nil
|
||||
}
|
||||
|
||||
func writeErrorToStderr(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format+"\n", args...)
|
||||
}
|
||||
@@ -379,6 +390,14 @@ func (s *mcpServer) handleRequest(req *jsonRPCRequest) {
|
||||
|
||||
func (s *mcpServer) handleInitialize(req *jsonRPCRequest) {
|
||||
var instructions string
|
||||
fileDownloadInstructions := `
|
||||
|
||||
Downloading files (e.g., PCAP exports):
|
||||
When a tool like export_snapshot_pcap returns a relative file path, you MUST use the file tools to retrieve the file:
|
||||
- get_file_url: Resolves the relative path to a full download URL you can share with the user.
|
||||
- download_file: Downloads the file to the local filesystem so it can be opened or analyzed.
|
||||
Typical workflow: call export_snapshot_pcap → receive a relative path → call download_file with that path → share the local file path with the user.`
|
||||
|
||||
if s.urlMode {
|
||||
instructions = fmt.Sprintf(`Kubeshark MCP Server - Connected to: %s
|
||||
|
||||
@@ -392,7 +411,7 @@ Available tools for traffic analysis:
|
||||
- get_api_stats: Get aggregated API statistics
|
||||
- And more - use tools/list to see all available tools
|
||||
|
||||
Use the MCP tools directly - do NOT use kubectl or curl to access Kubeshark.`, s.directURL)
|
||||
Use the MCP tools directly - do NOT use kubectl or curl to access Kubeshark.`, s.directURL) + fileDownloadInstructions
|
||||
} else if s.allowDestructive {
|
||||
instructions = `Kubeshark MCP Server - Proxy Mode (Destructive Operations ENABLED)
|
||||
|
||||
@@ -410,7 +429,7 @@ Safe operations:
|
||||
Traffic analysis tools (require Kubeshark to be running):
|
||||
- list_workloads, list_api_calls, list_l4_flows, get_api_stats, and more
|
||||
|
||||
Use the MCP tools - do NOT use kubectl, helm, or curl directly.`
|
||||
Use the MCP tools - do NOT use kubectl, helm, or curl directly.` + fileDownloadInstructions
|
||||
} else {
|
||||
instructions = `Kubeshark MCP Server - Proxy Mode (Read-Only)
|
||||
|
||||
@@ -425,7 +444,7 @@ Available operations:
|
||||
Traffic analysis tools (require Kubeshark to be running):
|
||||
- list_workloads, list_api_calls, list_l4_flows, get_api_stats, and more
|
||||
|
||||
Use the MCP tools - do NOT use kubectl, helm, or curl directly.`
|
||||
Use the MCP tools - do NOT use kubectl, helm, or curl directly.` + fileDownloadInstructions
|
||||
}
|
||||
|
||||
result := mcpInitializeResult{
|
||||
@@ -456,6 +475,40 @@ func (s *mcpServer) handleListTools(req *jsonRPCRequest) {
|
||||
}`),
|
||||
})
|
||||
|
||||
// Add file URL and download tools - available in all modes
|
||||
tools = append(tools, mcpTool{
|
||||
Name: "get_file_url",
|
||||
Description: "When a tool (e.g., export_snapshot_pcap) returns a relative file path, use this tool to resolve it into a fully-qualified download URL. The URL can be shared with the user for manual download.",
|
||||
InputSchema: json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The relative file path returned by a Hub tool (e.g., '/snapshots/abc/data.pcap')"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}`),
|
||||
})
|
||||
tools = append(tools, mcpTool{
|
||||
Name: "download_file",
|
||||
Description: "When a tool (e.g., export_snapshot_pcap) returns a relative file path, use this tool to download the file to the local filesystem. This is the preferred way to retrieve PCAP exports and other files from Kubeshark.",
|
||||
InputSchema: json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The relative file path returned by a Hub tool (e.g., '/snapshots/abc/data.pcap')"
|
||||
},
|
||||
"dest": {
|
||||
"type": "string",
|
||||
"description": "Local destination file path. If not provided, uses the filename from the path in the current directory."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}`),
|
||||
})
|
||||
|
||||
// Add destructive tools only if --allow-destructive flag was set (and not in URL mode)
|
||||
if !s.urlMode && s.allowDestructive {
|
||||
tools = append(tools, mcpTool{
|
||||
@@ -653,6 +706,20 @@ func (s *mcpServer) handleCallTool(req *jsonRPCRequest) {
|
||||
IsError: isError,
|
||||
})
|
||||
return
|
||||
case "get_file_url":
|
||||
result, isError = s.callGetFileURL(params.Arguments)
|
||||
s.sendResult(req.ID, mcpCallToolResult{
|
||||
Content: []mcpContent{{Type: "text", Text: result}},
|
||||
IsError: isError,
|
||||
})
|
||||
return
|
||||
case "download_file":
|
||||
result, isError = s.callDownloadFile(params.Arguments)
|
||||
s.sendResult(req.ID, mcpCallToolResult{
|
||||
Content: []mcpContent{{Type: "text", Text: result}},
|
||||
IsError: isError,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Forward Hub tools to the API
|
||||
@@ -706,6 +773,82 @@ func (s *mcpServer) callHubTool(toolName string, args map[string]any) (string, b
|
||||
}
|
||||
|
||||
|
||||
func (s *mcpServer) callGetFileURL(args map[string]any) (string, bool) {
|
||||
filePath, _ := args["path"].(string)
|
||||
if filePath == "" {
|
||||
return "Error: 'path' parameter is required", true
|
||||
}
|
||||
|
||||
baseURL, err := s.getBaseURL()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), true
|
||||
}
|
||||
|
||||
// Ensure path starts with /
|
||||
if !strings.HasPrefix(filePath, "/") {
|
||||
filePath = "/" + filePath
|
||||
}
|
||||
|
||||
fullURL := strings.TrimSuffix(baseURL, "/") + filePath
|
||||
return fullURL, false
|
||||
}
|
||||
|
||||
func (s *mcpServer) callDownloadFile(args map[string]any) (string, bool) {
|
||||
filePath, _ := args["path"].(string)
|
||||
if filePath == "" {
|
||||
return "Error: 'path' parameter is required", true
|
||||
}
|
||||
|
||||
baseURL, err := s.getBaseURL()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), true
|
||||
}
|
||||
|
||||
// Ensure path starts with /
|
||||
if !strings.HasPrefix(filePath, "/") {
|
||||
filePath = "/" + filePath
|
||||
}
|
||||
|
||||
fullURL := strings.TrimSuffix(baseURL, "/") + filePath
|
||||
|
||||
// Determine destination file path
|
||||
dest, _ := args["dest"].(string)
|
||||
if dest == "" {
|
||||
dest = path.Base(filePath)
|
||||
}
|
||||
|
||||
// Download the file
|
||||
resp, err := s.httpClient.Get(fullURL)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error downloading file: %v", err), true
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Sprintf("Error downloading file: HTTP %d", resp.StatusCode), true
|
||||
}
|
||||
|
||||
// Write to destination
|
||||
outFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error creating file %s: %v", dest, err), true
|
||||
}
|
||||
defer func() { _ = outFile.Close() }()
|
||||
|
||||
written, err := io.Copy(outFile, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error writing file %s: %v", dest, err), true
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"url": fullURL,
|
||||
"path": dest,
|
||||
"size": written,
|
||||
}
|
||||
resultBytes, _ := json.MarshalIndent(result, "", " ")
|
||||
return string(resultBytes), false
|
||||
}
|
||||
|
||||
func (s *mcpServer) callStartKubeshark(args map[string]any) (string, bool) {
|
||||
// Build the kubeshark tap command
|
||||
cmdArgs := []string{"tap"}
|
||||
@@ -913,6 +1056,11 @@ func listMCPTools(directURL string) {
|
||||
fmt.Printf("URL Mode: %s\n\n", directURL)
|
||||
fmt.Println("Cluster management tools disabled (Kubeshark managed externally)")
|
||||
fmt.Println()
|
||||
fmt.Println("Local Tools:")
|
||||
fmt.Println(" check_kubeshark_status Check if Kubeshark is running")
|
||||
fmt.Println(" get_file_url Resolve a relative path to a full download URL")
|
||||
fmt.Println(" download_file Download a file from Kubeshark to local disk")
|
||||
fmt.Println()
|
||||
|
||||
hubURL := strings.TrimSuffix(directURL, "/") + "/api/mcp"
|
||||
fetchAndDisplayTools(hubURL, 30*time.Second)
|
||||
@@ -925,6 +1073,10 @@ func listMCPTools(directURL string) {
|
||||
fmt.Println(" start_kubeshark Start Kubeshark to capture traffic")
|
||||
fmt.Println(" stop_kubeshark Stop Kubeshark and clean up resources")
|
||||
fmt.Println()
|
||||
fmt.Println("File Tools:")
|
||||
fmt.Println(" get_file_url Resolve a relative path to a full download URL")
|
||||
fmt.Println(" download_file Download a file from Kubeshark to local disk")
|
||||
fmt.Println()
|
||||
|
||||
// Establish proxy connection to Kubeshark
|
||||
fmt.Println("Connecting to Kubeshark...")
|
||||
|
||||
203
cmd/mcp_test.go
203
cmd/mcp_test.go
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -126,8 +128,18 @@ func TestMCP_ToolsList_CLIOnly(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
tools := resp.Result.(map[string]any)["tools"].([]any)
|
||||
if len(tools) != 1 || tools[0].(map[string]any)["name"] != "check_kubeshark_status" {
|
||||
t.Error("Expected only check_kubeshark_status tool")
|
||||
// Should have check_kubeshark_status + get_file_url + download_file = 3 tools
|
||||
if len(tools) != 3 {
|
||||
t.Errorf("Expected 3 tools, got %d", len(tools))
|
||||
}
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.(map[string]any)["name"].(string)] = true
|
||||
}
|
||||
for _, expected := range []string{"check_kubeshark_status", "get_file_url", "download_file"} {
|
||||
if !toolNames[expected] {
|
||||
t.Errorf("Missing expected tool: %s", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,9 +175,9 @@ func TestMCP_ToolsList_WithHubBackend(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
tools := resp.Result.(map[string]any)["tools"].([]any)
|
||||
// Should have CLI tools (3) + Hub tools (2) = 5 tools
|
||||
if len(tools) < 5 {
|
||||
t.Errorf("Expected at least 5 tools, got %d", len(tools))
|
||||
// Should have CLI tools (3) + file tools (2) + Hub tools (2) = 7 tools
|
||||
if len(tools) < 7 {
|
||||
t.Errorf("Expected at least 7 tools, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,6 +475,187 @@ func TestMCP_BackendInitialization_Concurrent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_GetFileURL_ProxyMode(t *testing.T) {
|
||||
s := &mcpServer{
|
||||
httpClient: &http.Client{},
|
||||
stdin: &bytes.Buffer{},
|
||||
stdout: &bytes.Buffer{},
|
||||
hubBaseURL: "http://127.0.0.1:8899/api/mcp",
|
||||
backendInitialized: true,
|
||||
}
|
||||
resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{
|
||||
Name: "get_file_url",
|
||||
Arguments: map[string]any{"path": "/snapshots/abc/data.pcap"},
|
||||
}))
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
text := resp.Result.(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
expected := "http://127.0.0.1:8899/api/snapshots/abc/data.pcap"
|
||||
if text != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_GetFileURL_URLMode(t *testing.T) {
|
||||
s := &mcpServer{
|
||||
httpClient: &http.Client{},
|
||||
stdin: &bytes.Buffer{},
|
||||
stdout: &bytes.Buffer{},
|
||||
hubBaseURL: "https://kubeshark.example.com/api/mcp",
|
||||
backendInitialized: true,
|
||||
urlMode: true,
|
||||
directURL: "https://kubeshark.example.com",
|
||||
}
|
||||
resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{
|
||||
Name: "get_file_url",
|
||||
Arguments: map[string]any{"path": "/snapshots/xyz/export.pcap"},
|
||||
}))
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
text := resp.Result.(map[string]any)["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
expected := "https://kubeshark.example.com/api/snapshots/xyz/export.pcap"
|
||||
if text != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_GetFileURL_MissingPath(t *testing.T) {
|
||||
s := &mcpServer{
|
||||
httpClient: &http.Client{},
|
||||
stdin: &bytes.Buffer{},
|
||||
stdout: &bytes.Buffer{},
|
||||
hubBaseURL: "http://127.0.0.1:8899/api/mcp",
|
||||
backendInitialized: true,
|
||||
}
|
||||
resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{
|
||||
Name: "get_file_url",
|
||||
Arguments: map[string]any{},
|
||||
}))
|
||||
result := resp.Result.(map[string]any)
|
||||
if !result["isError"].(bool) {
|
||||
t.Error("Expected isError=true when path is missing")
|
||||
}
|
||||
text := result["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
if !strings.Contains(text, "path") {
|
||||
t.Error("Error message should mention 'path'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_DownloadFile(t *testing.T) {
|
||||
fileContent := "test pcap data content"
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/snapshots/abc/data.pcap" {
|
||||
_, _ = w.Write([]byte(fileContent))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Use temp dir for download destination
|
||||
tmpDir := t.TempDir()
|
||||
dest := filepath.Join(tmpDir, "downloaded.pcap")
|
||||
|
||||
s := &mcpServer{
|
||||
httpClient: &http.Client{},
|
||||
stdin: &bytes.Buffer{},
|
||||
stdout: &bytes.Buffer{},
|
||||
hubBaseURL: mockServer.URL + "/api/mcp",
|
||||
backendInitialized: true,
|
||||
}
|
||||
resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{
|
||||
Name: "download_file",
|
||||
Arguments: map[string]any{"path": "/snapshots/abc/data.pcap", "dest": dest},
|
||||
}))
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
result := resp.Result.(map[string]any)
|
||||
if result["isError"] != nil && result["isError"].(bool) {
|
||||
t.Fatalf("Expected no error, got: %v", result["content"])
|
||||
}
|
||||
|
||||
text := result["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
var downloadResult map[string]any
|
||||
if err := json.Unmarshal([]byte(text), &downloadResult); err != nil {
|
||||
t.Fatalf("Failed to parse download result JSON: %v", err)
|
||||
}
|
||||
if downloadResult["path"] != dest {
|
||||
t.Errorf("Expected path %q, got %q", dest, downloadResult["path"])
|
||||
}
|
||||
if downloadResult["size"].(float64) != float64(len(fileContent)) {
|
||||
t.Errorf("Expected size %d, got %v", len(fileContent), downloadResult["size"])
|
||||
}
|
||||
|
||||
// Verify the file was actually written
|
||||
content, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
if string(content) != fileContent {
|
||||
t.Errorf("Expected file content %q, got %q", fileContent, string(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_DownloadFile_CustomDest(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("data"))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
customDest := filepath.Join(tmpDir, "custom-name.pcap")
|
||||
|
||||
s := &mcpServer{
|
||||
httpClient: &http.Client{},
|
||||
stdin: &bytes.Buffer{},
|
||||
stdout: &bytes.Buffer{},
|
||||
hubBaseURL: mockServer.URL + "/api/mcp",
|
||||
backendInitialized: true,
|
||||
}
|
||||
resp := parseResponse(t, sendRequest(s, "tools/call", 1, mcpCallToolParams{
|
||||
Name: "download_file",
|
||||
Arguments: map[string]any{"path": "/snapshots/abc/export.pcap", "dest": customDest},
|
||||
}))
|
||||
result := resp.Result.(map[string]any)
|
||||
if result["isError"] != nil && result["isError"].(bool) {
|
||||
t.Fatalf("Expected no error, got: %v", result["content"])
|
||||
}
|
||||
|
||||
text := result["content"].([]any)[0].(map[string]any)["text"].(string)
|
||||
var downloadResult map[string]any
|
||||
if err := json.Unmarshal([]byte(text), &downloadResult); err != nil {
|
||||
t.Fatalf("Failed to parse download result JSON: %v", err)
|
||||
}
|
||||
if downloadResult["path"] != customDest {
|
||||
t.Errorf("Expected path %q, got %q", customDest, downloadResult["path"])
|
||||
}
|
||||
|
||||
if _, err := os.Stat(customDest); os.IsNotExist(err) {
|
||||
t.Error("Expected file to exist at custom destination")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_ToolsList_IncludesFileTools(t *testing.T) {
|
||||
s := newTestMCPServer()
|
||||
resp := parseResponse(t, sendRequest(s, "tools/list", 1, nil))
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Unexpected error: %v", resp.Error)
|
||||
}
|
||||
tools := resp.Result.(map[string]any)["tools"].([]any)
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.(map[string]any)["name"].(string)] = true
|
||||
}
|
||||
for _, expected := range []string{"get_file_url", "download_file"} {
|
||||
if !toolNames[expected] {
|
||||
t.Errorf("Missing expected tool: %s", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCP_FullConversation(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
|
||||
Reference in New Issue
Block a user