mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-07-10 13:54:20 +00:00
* feat: first mcp impl Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: update Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: wip Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: switcheed to stdio transport Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: readme Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat: fix the linter 🤖 Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat: fix the linter 🤖 Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat(mcp): implement MCP server and handler - Implement MCP server and handler - Add MCP server to serve - Add MCP handler to handle MCP requests - Add MCP server to serve - Add MCP handler to handle MCP requests Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat: consolidating code duplication Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat: added http sse support Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: fixed broken tests Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: updated and fixed linter Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: updated and fixed linter Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: updated the linter issues Signed-off-by: Alex Jones <alexsimonjones@gmail.com> --------- Signed-off-by: Alex Jones <alexsimonjones@gmail.com>
417 lines
13 KiB
Go
417 lines
13 KiB
Go
/*
|
|
Copyright 2024 The K8sGPT Authors.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package server
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	schemav1 "buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go/schema/v1"
	"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
	"github.com/k8sgpt-ai/k8sgpt/pkg/analysis"
	"github.com/k8sgpt-ai/k8sgpt/pkg/kubernetes"
	"github.com/k8sgpt-ai/k8sgpt/pkg/server/config"
	mcp_golang "github.com/metoro-io/mcp-golang"
	"github.com/metoro-io/mcp-golang/transport/stdio"
	"github.com/spf13/viper"
	"go.uber.org/zap"
)
|
|
|
|
// MCPServer represents an MCP server for k8sgpt
|
|
type MCPServer struct {
|
|
server *mcp_golang.Server
|
|
port string
|
|
aiProvider *ai.AIProvider
|
|
useHTTP bool
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewMCPServer creates a new MCP server
|
|
func NewMCPServer(port string, aiProvider *ai.AIProvider, useHTTP bool, logger *zap.Logger) (*MCPServer, error) {
|
|
// Create MCP server with stdio transport
|
|
transport := stdio.NewStdioServerTransport()
|
|
|
|
server := mcp_golang.NewServer(transport)
|
|
|
|
return &MCPServer{
|
|
server: server,
|
|
port: port,
|
|
aiProvider: aiProvider,
|
|
useHTTP: useHTTP,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
// Start registers the MCP tools ("analyze", "cluster-info", "config"),
// resources and prompts, optionally launches the HTTP listener, and then
// returns the result of the underlying MCP server's Serve.
//
// Registration failures are returned wrapped with %w, so callers can use
// errors.Is / errors.As on the result. The HTTP listener, when enabled, runs
// in a background goroutine; a failure to bind or serve is logged, not
// returned.
func (s *MCPServer) Start() error {
	if s.server == nil {
		return fmt.Errorf("server not initialized")
	}

	if err := s.server.RegisterTool("analyze", "Analyze Kubernetes resources", s.handleAnalyze); err != nil {
		return fmt.Errorf("failed to register analyze tool: %w", err)
	}

	if err := s.server.RegisterTool("cluster-info", "Get Kubernetes cluster information", s.handleClusterInfo); err != nil {
		return fmt.Errorf("failed to register cluster-info tool: %w", err)
	}

	if err := s.server.RegisterTool("config", "Configure K8sGPT settings", s.handleConfig); err != nil {
		return fmt.Errorf("failed to register config tool: %w", err)
	}

	if err := s.registerResources(); err != nil {
		return fmt.Errorf("failed to register resources: %w", err)
	}

	if err := s.registerPrompts(); err != nil {
		return fmt.Errorf("failed to register prompts: %w", err)
	}

	if s.useHTTP {
		// A private mux instead of http.DefaultServeMux: registering the same
		// pattern twice on the global mux panics (e.g. a second Start), and it
		// would also expose these handlers to any other server in the process.
		mux := http.NewServeMux()
		mux.HandleFunc("/mcp/analyze", s.handleAnalyzeHTTP)
		mux.HandleFunc("/mcp", s.handleSSE)

		httpServer := &http.Server{
			Addr:    fmt.Sprintf(":%s", s.port),
			Handler: mux,
			// Bound how long a client may take to send request headers so idle
			// or slow connections cannot pin server resources indefinitely.
			// No WriteTimeout: /mcp is a long-lived SSE stream.
			ReadHeaderTimeout: 10 * time.Second,
		}

		go func() {
			s.logger.Info("Starting MCP server on port", zap.String("port", s.port))
			if err := httpServer.ListenAndServe(); err != nil {
				s.logger.Error("Error starting HTTP server", zap.Error(err))
			}
		}()
	}

	return s.server.Serve()
}
|
|
|
|
// AnalyzeRequest is the JSON input of the "analyze" tool (and of the
// /mcp/analyze HTTP endpoint). Every field is optional; handleAnalyze fills
// in defaults for some of them before building the analysis.
type AnalyzeRequest struct {
	// Namespace is passed through to analysis.NewAnalysis.
	Namespace string `json:"namespace,omitempty"`
	// Backend names the AI backend; when empty, handleAnalyze picks one
	// from the stored configuration.
	Backend string `json:"backend,omitempty"`
	// Language is passed through to analysis.NewAnalysis.
	Language string `json:"language,omitempty"`
	// Filters selects analyzers; when empty, handleAnalyze uses the
	// "active_filters" configuration value.
	Filters []string `json:"filters,omitempty"`
	// LabelSelector is passed through to analysis.NewAnalysis.
	LabelSelector string `json:"labelSelector,omitempty"`
	// NoCache is passed through to analysis.NewAnalysis.
	NoCache bool `json:"noCache,omitempty"`
	// Explain is overridden to true by handleAnalyze.
	Explain bool `json:"explain,omitempty"`
	// MaxConcurrency is clamped by validateMaxConcurrency.
	MaxConcurrency int `json:"maxConcurrency,omitempty"`
	// WithDoc is passed through to analysis.NewAnalysis.
	WithDoc bool `json:"withDoc,omitempty"`
	// InteractiveMode is passed through to analysis.NewAnalysis.
	InteractiveMode bool `json:"interactiveMode,omitempty"`
	// CustomHeaders is passed through to analysis.NewAnalysis.
	CustomHeaders []string `json:"customHeaders,omitempty"`
	// WithStats is passed through to analysis.NewAnalysis.
	WithStats bool `json:"withStats,omitempty"`
}
|
|
|
|
// AnalyzeResponse is the JSON shape of an "analyze" result: the analysis
// output carried as a single string.
type AnalyzeResponse struct {
	Results string `json:"results"` // analysis output text
}
|
|
|
|
// ClusterInfoRequest is the (empty) input of the "cluster-info" tool; the
// tool takes no parameters.
type ClusterInfoRequest struct{}
|
|
|
|
// ClusterInfoResponse is the JSON shape of a "cluster-info" result: a single
// human-readable description of the cluster.
type ClusterInfoResponse struct {
	Info string `json:"info"` // cluster description text
}
|
|
|
|
// ConfigRequest is the JSON input of the "config" tool. handleConfig turns it
// into a schemav1.AddConfigRequest.
type ConfigRequest struct {
	// CustomAnalyzers lists external analyzers, each identified by name and
	// reached via a URL and numeric port.
	CustomAnalyzers []struct {
		Name       string `json:"name"`
		Connection struct {
			Url  string `json:"url"`
			Port int    `json:"port"`
		} `json:"connection"`
	} `json:"customAnalyzers,omitempty"`

	// Cache selects a remote cache backend via Type ("s3", "azure" or "gcs"
	// are the values handleConfig recognises); only the fields relevant to
	// that backend are read.
	//
	// NOTE: encoding/json ignores omitempty on struct-typed fields, so this
	// object is always present when a ConfigRequest is marshalled.
	Cache struct {
		Type string `json:"type"`

		// S3 backend (BucketName and Region are also read for GCS).
		BucketName string `json:"bucketName,omitempty"`
		Region     string `json:"region,omitempty"`
		Endpoint   string `json:"endpoint,omitempty"`
		Insecure   bool   `json:"insecure,omitempty"`

		// Azure backend.
		StorageAccount string `json:"storageAccount,omitempty"`
		ContainerName  string `json:"containerName,omitempty"`

		// GCS backend.
		ProjectId string `json:"projectId,omitempty"`
	} `json:"cache,omitempty"`
}
|
|
|
|
// ConfigResponse is the JSON shape of a "config" result: a status message.
type ConfigResponse struct {
	Status string `json:"status"` // outcome description
}
|
|
|
|
// handleAnalyze implements the "analyze" MCP tool.
//
// Defaults are applied before the analysis is built:
//   - Backend: the DefaultProvider of the stored "ai" configuration, else the
//     first configured provider, else "openai".
//   - Filters: the "active_filters" configuration value.
//   - Explain: always forced to true.
//   - MaxConcurrency: clamped via validateMaxConcurrency.
//
// The analysis output is returned as JSON text. Failures are reported to the
// MCP client as text content inside a tool response; the returned error is
// always nil. The caller's request is not modified: defaults are applied to a
// local copy.
func (s *MCPServer) handleAnalyze(ctx context.Context, request *AnalyzeRequest) (*mcp_golang.ToolResponse, error) {
	var configAI ai.AIConfiguration
	if err := viper.UnmarshalKey("ai", &configAI); err != nil {
		return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(fmt.Sprintf("Failed to load AI configuration: %v", err))), nil
	}

	// Work on a copy so filling in defaults does not mutate the caller's struct.
	req := *request

	if req.Backend == "" {
		switch {
		case configAI.DefaultProvider != "":
			req.Backend = configAI.DefaultProvider
		case len(configAI.Providers) > 0:
			req.Backend = configAI.Providers[0].Name
		default:
			req.Backend = "openai" // fallback default
		}
	}

	// The MCP tool always requests AI explanations, regardless of input.
	req.Explain = true

	if len(req.Filters) == 0 {
		req.Filters = viper.GetStringSlice("active_filters")
	}

	// Validate MaxConcurrency to prevent excessive memory allocation.
	req.MaxConcurrency = validateMaxConcurrency(req.MaxConcurrency)

	// Named "a" (not "analysis") so the analysis package is not shadowed.
	a, err := analysis.NewAnalysis(
		req.Backend,
		req.Language,
		req.Filters,
		req.Namespace,
		req.LabelSelector,
		req.NoCache,
		req.Explain,
		req.MaxConcurrency,
		req.WithDoc,
		req.InteractiveMode,
		req.CustomHeaders,
		req.WithStats,
	)
	if err != nil {
		return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(fmt.Sprintf("Failed to create analysis: %v", err))), nil
	}
	defer a.Close()

	a.RunAnalysis()

	output, err := a.PrintOutput("json")
	if err != nil {
		return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(fmt.Sprintf("Failed to print output: %v", err))), nil
	}

	return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(string(output))), nil
}
|
|
|
|
// validateMaxConcurrency normalises a caller-supplied concurrency limit.
// Non-positive values become the default of 10, values above 100 are capped
// at 100, and anything in the range 1..100 is returned unchanged.
func validateMaxConcurrency(maxConcurrency int) int {
	const (
		defaultConcurrency    = 10  // used when the caller did not set a limit
		maxAllowedConcurrency = 100 // hard upper bound
	)

	switch {
	case maxConcurrency <= 0:
		return defaultConcurrency
	case maxConcurrency > maxAllowedConcurrency:
		return maxAllowedConcurrency
	default:
		return maxConcurrency
	}
}
|
|
|
|
// handleClusterInfo implements the "cluster-info" MCP tool. It builds a client
// with kubernetes.NewClient("", ""), asks the API server for its version and
// replies with "Kubernetes <gitVersion>". Failures are returned as text
// content with a nil error so the MCP client sees the message.
func (s *MCPServer) handleClusterInfo(ctx context.Context, request *ClusterInfoRequest) (*mcp_golang.ToolResponse, error) {
	reply := func(text string) (*mcp_golang.ToolResponse, error) {
		return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(text)), nil
	}

	client, err := kubernetes.NewClient("", "")
	if err != nil {
		return reply(fmt.Sprintf("failed to create Kubernetes client: %v", err))
	}

	serverVersion, err := client.Client.Discovery().ServerVersion()
	if err != nil {
		return reply(fmt.Sprintf("failed to get cluster version: %v", err))
	}

	return reply("Kubernetes " + serverVersion.GitVersion)
}
|
|
|
|
// handleConfig implements the "config" MCP tool. It translates the request
// into a schemav1.AddConfigRequest and applies it with
// config.Handler.ApplyConfig.
//
// Custom analyzers are copied one-to-one (the port is converted to the
// string form the schema uses). A cache section is added only when
// request.Cache.Type is non-empty; the recognised types are "s3", "azure"
// and "gcs", and any other value is rejected before anything is applied.
// Outcomes are reported as text content with a nil error.
func (s *MCPServer) handleConfig(ctx context.Context, request *ConfigRequest) (*mcp_golang.ToolResponse, error) {
	handler := &config.Handler{}

	addConfigReq := &schemav1.AddConfigRequest{
		CustomAnalyzers: make([]*schemav1.CustomAnalyzer, 0, len(request.CustomAnalyzers)),
	}

	for _, ca := range request.CustomAnalyzers {
		addConfigReq.CustomAnalyzers = append(addConfigReq.CustomAnalyzers, &schemav1.CustomAnalyzer{
			Name: ca.Name,
			Connection: &schemav1.Connection{
				Url: ca.Connection.Url,
				// schemav1.Connection carries the port as a string.
				Port: fmt.Sprintf("%d", ca.Connection.Port),
			},
		})
	}

	if request.Cache.Type != "" {
		cacheConfig := &schemav1.Cache{}
		switch request.Cache.Type {
		case "s3":
			cacheConfig.CacheType = &schemav1.Cache_S3Cache{
				S3Cache: &schemav1.S3Cache{
					BucketName: request.Cache.BucketName,
					Region:     request.Cache.Region,
					Endpoint:   request.Cache.Endpoint,
					Insecure:   request.Cache.Insecure,
				},
			}
		case "azure":
			cacheConfig.CacheType = &schemav1.Cache_AzureCache{
				AzureCache: &schemav1.AzureCache{
					StorageAccount: request.Cache.StorageAccount,
					ContainerName:  request.Cache.ContainerName,
				},
			}
		case "gcs":
			cacheConfig.CacheType = &schemav1.Cache_GcsCache{
				GcsCache: &schemav1.GCSCache{
					BucketName: request.Cache.BucketName,
					Region:     request.Cache.Region,
					ProjectId:  request.Cache.ProjectId,
				},
			}
		default:
			// Reject here: otherwise an empty Cache with no CacheType set would
			// be forwarded to ApplyConfig and the typo would go unnoticed.
			return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(fmt.Sprintf("Unsupported cache type %q: expected one of s3, azure, gcs", request.Cache.Type))), nil
		}
		addConfigReq.Cache = cacheConfig
	}

	// Apply the configuration using the shared function.
	if err := handler.ApplyConfig(ctx, addConfigReq); err != nil {
		return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(fmt.Sprintf("Failed to add config: %v", err))), nil
	}

	return mcp_golang.NewToolResponse(mcp_golang.NewTextContent("Successfully added configuration")), nil
}
|
|
|
|
// registerPrompts registers the prompts for the MCP server
|
|
func (s *MCPServer) registerPrompts() error {
|
|
// Register any prompts needed for the MCP server
|
|
return nil
|
|
}
|
|
|
|
// registerResources registers the MCP resources served by this server.
// Currently that is a single "cluster-info" resource backed by
// getClusterInfo, registered with "text" as (presumably) its MIME type.
// A registration failure is returned wrapped with %w.
func (s *MCPServer) registerResources() error {
	if err := s.server.RegisterResource("cluster-info", "Get cluster information", "Get information about the Kubernetes cluster", "text", s.getClusterInfo); err != nil {
		return fmt.Errorf("failed to register cluster-info resource: %w", err)
	}
	return nil
}
|
|
|
|
// getClusterInfo backs the "cluster-info" MCP resource. It builds a client
// with kubernetes.NewClient("", ""), queries the API server version and
// returns a map with "version", "platform" and "gitVersion" keys.
// Errors are wrapped with %w so callers can inspect the underlying cause.
func (s *MCPServer) getClusterInfo(ctx context.Context) (interface{}, error) {
	client, err := kubernetes.NewClient("", "")
	if err != nil {
		return nil, fmt.Errorf("failed to create Kubernetes client: %w", err)
	}

	version, err := client.Client.Discovery().ServerVersion()
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster version: %w", err)
	}

	return map[string]string{
		// NOTE(review): version.String() appears to yield the same git
		// version as "gitVersion" — confirm whether a richer value was intended.
		"version":    version.String(),
		"platform":   version.Platform,
		"gitVersion": version.GitVersion,
	}, nil
}
|
|
|
|
// handleSSE serves the /mcp Server-Sent Events stream.
//
// It responds 500 if the ResponseWriter cannot stream (no http.Flusher).
// Otherwise it sends the SSE headers and forwards each message received on
// msgChan as a "data:" event. It returns when the client disconnects (the
// request context is cancelled), when msgChan is closed, or on a write error.
//
// The HTTP<->stdio bridge that would feed msgChan is not implemented yet, so
// today the stream stays open without events until the client goes away.
func (s *MCPServer) handleSSE(w http.ResponseWriter, r *http.Request) {
	// Check up front instead of an unchecked assertion that would panic
	// mid-stream on writers that do not support flushing.
	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")

	// Send the headers immediately so the client sees the stream open.
	flusher.Flush()

	// TODO: Implement message handling between HTTP and stdio transport.
	// That requires a custom transport bridging HTTP and stdio; the bridge
	// goroutine should own msgChan (send on it and close it when done).
	msgChan := make(chan string)

	for {
		select {
		case <-r.Context().Done():
			// Client went away: stop instead of blocking forever.
			return
		case msg, ok := <-msgChan:
			if !ok {
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", msg); err != nil {
				s.logger.Error("Failed to write SSE message", zap.Error(err))
				return
			}
			flusher.Flush()
		}
	}
}
|
|
|
|
// handleAnalyzeHTTP serves POST /mcp/analyze. It decodes an AnalyzeRequest
// from the JSON body, runs it through handleAnalyze and writes the resulting
// tool response as JSON with status 200.
//
// Non-POST requests get 405 (with an Allow header), and an undecodable or
// oversized body gets 400. Analysis failures are carried inside the tool
// response, since handleAnalyze reports them as text content.
func (s *MCPServer) handleAnalyzeHTTP(w http.ResponseWriter, r *http.Request) {
	// An AnalyzeRequest is a handful of short fields; 1 MiB is generous while
	// preventing a client from streaming an unbounded body into the decoder.
	const maxRequestBodyBytes = 1 << 20

	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req AnalyzeRequest
	body := http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
	if err := json.NewDecoder(body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("Failed to decode request: %v", err), http.StatusBadRequest)
		return
	}

	// Validate MaxConcurrency to prevent excessive memory allocation.
	req.MaxConcurrency = validateMaxConcurrency(req.MaxConcurrency)

	resp, err := s.handleAnalyze(r.Context(), &req)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to analyze: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	// Headers are already sent, so an encode failure can only be logged.
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		s.logger.Error("Failed to encode response", zap.Error(err))
	}
}
|
|
|
|
// Close closes the MCP server and releases resources
|
|
func (s *MCPServer) Close() error {
|
|
return nil
|
|
}
|