diff --git a/README.md b/README.md index 46ee7e6..db434d0 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,8 @@ Unused: > huggingface > noopai > googlevertexai +> watsonxai +> customrest > ibmwatsonxai ``` diff --git a/pkg/ai/customrest.go b/pkg/ai/customrest.go new file mode 100644 index 0000000..22acc5c --- /dev/null +++ b/pkg/ai/customrest.go @@ -0,0 +1,147 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const CustomRestClientName = "customrest" + +type CustomRestClient struct { + nopCloser + client *http.Client + base *url.URL + token string + model string + temperature float32 + topP float32 + topK int32 +} + +type CustomRestRequest struct { + Model string `json:"model"` + + // Prompt is the textual prompt to send to the model. + Prompt string `json:"prompt"` + + // Options lists model-specific options. For example, temperature can be + // set through this field, if the model supports it. + Options map[string]interface{} `json:"options"` +} + +type CustomRestResponse struct { + // Model is the model name that generated the response. + Model string `json:"model"` + + // CreatedAt is the timestamp of the response. + CreatedAt time.Time `json:"created_at"` + + // Response is the textual response itself. + Response string `json:"response"` +} + +func (c *CustomRestClient) Configure(config IAIConfig) error { + baseURL := config.GetBaseURL() + if baseURL == "" { + baseURL = defaultBaseURL + } + c.token = config.GetPassword() + baseClientURL, err := url.Parse(baseURL) + if err != nil { + return err + } + c.base = baseClientURL + + proxyEndpoint := config.GetProxyEndpoint() + c.client = http.DefaultClient + if proxyEndpoint != "" { + proxyUrl, err := url.Parse(proxyEndpoint) + if err != nil { + return err + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + c.client = &http.Client{ + Transport: transport, + } + } + + c.model = config.GetModel() + if c.model == "" { + c.model = defaultModel + } + c.temperature = config.GetTemperature() + c.topP = config.GetTopP() + c.topK = config.GetTopK() + return nil +} + +func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + var promptDetail struct { + Language string `json:"language,omitempty"` + Message string `json:"message"` + Prompt string `json:"prompt,omitempty"` + } + prompt = strings.NewReplacer("\n", "\\n", "\t", "\\t").Replace(prompt) + if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil { + return "", err + } + generateRequest := &CustomRestRequest{ + Model: c.model, + Prompt: promptDetail.Prompt, + Options: map[string]interface{}{ + "temperature": c.temperature, + "top_p": c.topP, + "top_k": c.topK, + "message": promptDetail.Message, + "language": promptDetail.Language, + }, + } + requestBody, err := json.Marshal(generateRequest) + if err != nil { + return "", err + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(requestBody)) + if err != nil { + return "", err + } + if c.token != "" { + request.Header.Set("Authorization", "Bearer "+c.token) + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/x-ndjson") + + response, err := c.client.Do(request) + if err != nil { + return "", err + } + defer response.Body.Close() + + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return "", fmt.Errorf("could not read response body: %w", err) + } + + if response.StatusCode >= http.StatusBadRequest { + return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, responseBody) + } + + var result CustomRestResponse + if err := json.Unmarshal(responseBody, &result); err != nil { + return "", err + } + return result.Response, nil +} + +func (c *CustomRestClient) GetName() string { + return CustomRestClientName +} diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index e9a5618..e1e7169 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -32,6 +32,7 @@ var ( &HuggingfaceClient{}, &GoogleVertexAIClient{}, &OCIGenAIClient{}, + &CustomRestClient{}, &IBMWatsonxAIClient{}, } Backends = []string{ @@ -47,6 +48,7 @@ var ( huggingfaceAIClientName, googleVertexAIClientName, ociClientName, + CustomRestClientName, ibmWatsonxAIClientName, } ) @@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header { return p.CustomHeaders } -var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} +var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders { diff --git a/pkg/ai/prompts.go b/pkg/ai/prompts.go index e41defc..f620741 100644 --- a/pkg/ai/prompts.go +++ b/pkg/ai/prompts.go @@ -56,9 +56,11 @@ const ( Solution: {kubectl command} ` + raw_promt = `{"language": "%s","message": "%s","prompt": "%s"}` ) var PromptMap = map[string]string{ + "raw": raw_promt, "default": default_prompt, "PrometheusConfigValidate": prom_conf_prompt, "PrometheusConfigRelabelReport": prom_relabel_prompt, diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index 79f46d5..227c343 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -405,6 +405,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st // Process template. prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey) + if a.AIClient.GetName() == ai.CustomRestClientName { + prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt) + } response, err := a.AIClient.GetCompletion(a.Context, prompt) if err != nil { return "", err