mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-10 03:29:18 +00:00
feat: add custom restful backend for complex scenarios (e.g, rag) (#1228)
* feat: add custom restful backend for complex scenarios (e.g, rag) Signed-off-by: popsiclexu <zhenxuexu@gmail.com> * chore: rebased chore: removed trivy Signed-off-by: AlexsJones <alexsimonjones@gmail.com> * chore: updated deps Signed-off-by: AlexsJones <alexsimonjones@gmail.com> * chore: resolved issues Signed-off-by: AlexsJones <alexsimonjones@gmail.com> --------- Signed-off-by: popsiclexu <zhenxuexu@gmail.com> Signed-off-by: popsiclexu <ZhenxueXu@gmail.com> Signed-off-by: AlexsJones <alexsimonjones@gmail.com> Co-authored-by: popsiclexu <zhenxue.xu@mthreads.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
@@ -366,6 +366,8 @@ Unused:
|
|||||||
> huggingface
|
> huggingface
|
||||||
> noopai
|
> noopai
|
||||||
> googlevertexai
|
> googlevertexai
|
||||||
|
> watsonxai
|
||||||
|
> customrest
|
||||||
> ibmwatsonxai
|
> ibmwatsonxai
|
||||||
```
|
```
|
||||||
|
|
||||||
|
147
pkg/ai/customrest.go
Normal file
147
pkg/ai/customrest.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const CustomRestClientName = "customrest"
|
||||||
|
|
||||||
|
type CustomRestClient struct {
|
||||||
|
nopCloser
|
||||||
|
client *http.Client
|
||||||
|
base *url.URL
|
||||||
|
token string
|
||||||
|
model string
|
||||||
|
temperature float32
|
||||||
|
topP float32
|
||||||
|
topK int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomRestRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// Prompt is the textual prompt to send to the model.
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
// Options lists model-specific options. For example, temperature can be
|
||||||
|
// set through this field, if the model supports it.
|
||||||
|
Options map[string]interface{} `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomRestResponse struct {
|
||||||
|
// Model is the model name that generated the response.
|
||||||
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// CreatedAt is the timestamp of the response.
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
|
// Response is the textual response itself.
|
||||||
|
Response string `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CustomRestClient) Configure(config IAIConfig) error {
|
||||||
|
baseURL := config.GetBaseURL()
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = defaultBaseURL
|
||||||
|
}
|
||||||
|
c.token = config.GetPassword()
|
||||||
|
baseClientURL, err := url.Parse(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.base = baseClientURL
|
||||||
|
|
||||||
|
proxyEndpoint := config.GetProxyEndpoint()
|
||||||
|
c.client = http.DefaultClient
|
||||||
|
if proxyEndpoint != "" {
|
||||||
|
proxyUrl, err := url.Parse(proxyEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyUrl),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.client = &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.model = config.GetModel()
|
||||||
|
if c.model == "" {
|
||||||
|
c.model = defaultModel
|
||||||
|
}
|
||||||
|
c.temperature = config.GetTemperature()
|
||||||
|
c.topP = config.GetTopP()
|
||||||
|
c.topK = config.GetTopK()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
|
||||||
|
var promptDetail struct {
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
}
|
||||||
|
prompt = strings.NewReplacer("\n", "\\n", "\t", "\\t").Replace(prompt)
|
||||||
|
if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
generateRequest := &CustomRestRequest{
|
||||||
|
Model: c.model,
|
||||||
|
Prompt: promptDetail.Prompt,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": c.temperature,
|
||||||
|
"top_p": c.topP,
|
||||||
|
"top_k": c.topK,
|
||||||
|
"message": promptDetail.Message,
|
||||||
|
"language": promptDetail.Language,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
requestBody, err := json.Marshal(generateRequest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(requestBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if c.token != "" {
|
||||||
|
request.Header.Set("Authorization", "Bearer "+c.token)
|
||||||
|
}
|
||||||
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
request.Header.Set("Accept", "application/x-ndjson")
|
||||||
|
|
||||||
|
response, err := c.client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode >= http.StatusBadRequest {
|
||||||
|
return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, responseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CustomRestResponse
|
||||||
|
if err := json.Unmarshal(responseBody, &result); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result.Response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CustomRestClient) GetName() string {
|
||||||
|
return CustomRestClientName
|
||||||
|
}
|
@@ -32,6 +32,7 @@ var (
|
|||||||
&HuggingfaceClient{},
|
&HuggingfaceClient{},
|
||||||
&GoogleVertexAIClient{},
|
&GoogleVertexAIClient{},
|
||||||
&OCIGenAIClient{},
|
&OCIGenAIClient{},
|
||||||
|
&CustomRestClient{},
|
||||||
&IBMWatsonxAIClient{},
|
&IBMWatsonxAIClient{},
|
||||||
}
|
}
|
||||||
Backends = []string{
|
Backends = []string{
|
||||||
@@ -47,6 +48,7 @@ var (
|
|||||||
huggingfaceAIClientName,
|
huggingfaceAIClientName,
|
||||||
googleVertexAIClientName,
|
googleVertexAIClientName,
|
||||||
ociClientName,
|
ociClientName,
|
||||||
|
CustomRestClientName,
|
||||||
ibmWatsonxAIClientName,
|
ibmWatsonxAIClientName,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
|
|||||||
return p.CustomHeaders
|
return p.CustomHeaders
|
||||||
}
|
}
|
||||||
|
|
||||||
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
|
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"}
|
||||||
|
|
||||||
func NeedPassword(backend string) bool {
|
func NeedPassword(backend string) bool {
|
||||||
for _, b := range passwordlessProviders {
|
for _, b := range passwordlessProviders {
|
||||||
|
@@ -56,9 +56,11 @@ const (
|
|||||||
|
|
||||||
Solution: {kubectl command}
|
Solution: {kubectl command}
|
||||||
`
|
`
|
||||||
|
raw_promt = `{"language": "%s","message": "%s","prompt": "%s"}`
|
||||||
)
|
)
|
||||||
|
|
||||||
var PromptMap = map[string]string{
|
var PromptMap = map[string]string{
|
||||||
|
"raw": raw_promt,
|
||||||
"default": default_prompt,
|
"default": default_prompt,
|
||||||
"PrometheusConfigValidate": prom_conf_prompt,
|
"PrometheusConfigValidate": prom_conf_prompt,
|
||||||
"PrometheusConfigRelabelReport": prom_relabel_prompt,
|
"PrometheusConfigRelabelReport": prom_relabel_prompt,
|
||||||
|
@@ -405,6 +405,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st
|
|||||||
|
|
||||||
// Process template.
|
// Process template.
|
||||||
prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
|
prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
|
||||||
|
if a.AIClient.GetName() == ai.CustomRestClientName {
|
||||||
|
prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt)
|
||||||
|
}
|
||||||
response, err := a.AIClient.GetCompletion(a.Context, prompt)
|
response, err := a.AIClient.GetCompletion(a.Context, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
Reference in New Issue
Block a user