feat: add custom restful backend for complex scenarios (e.g., RAG) (#1228)

* feat: add custom restful backend for complex scenarios (e.g., RAG)

Signed-off-by: popsiclexu <zhenxuexu@gmail.com>

* chore: rebased
chore: removed trivy

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* chore: updated deps

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

* chore: resolved issues

Signed-off-by: AlexsJones <alexsimonjones@gmail.com>

---------

Signed-off-by: popsiclexu <zhenxuexu@gmail.com>
Signed-off-by: popsiclexu <ZhenxueXu@gmail.com>
Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
Co-authored-by: popsiclexu <zhenxue.xu@mthreads.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
Commit 7540e0084e (parent eb7b36aa27), authored by popsiclexu on 2025-03-17 20:21:38 +08:00 and committed via GitHub.
5 changed files with 157 additions and 1 deletion.


@@ -366,6 +366,8 @@ Unused:
> huggingface
> noopai
> googlevertexai
> watsonxai
> customrest
> ibmwatsonxai
```

pkg/ai/customrest.go (new file, 147 lines)

@@ -0,0 +1,147 @@
package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"
)

const CustomRestClientName = "customrest"

type CustomRestClient struct {
	nopCloser

	client      *http.Client
	base        *url.URL
	token       string
	model       string
	temperature float32
	topP        float32
	topK        int32
}

type CustomRestRequest struct {
	Model string `json:"model"`

	// Prompt is the textual prompt to send to the model.
	Prompt string `json:"prompt"`

	// Options lists model-specific options. For example, temperature can be
	// set through this field, if the model supports it.
	Options map[string]interface{} `json:"options"`
}

type CustomRestResponse struct {
	// Model is the model name that generated the response.
	Model string `json:"model"`

	// CreatedAt is the timestamp of the response.
	CreatedAt time.Time `json:"created_at"`

	// Response is the textual response itself.
	Response string `json:"response"`
}

func (c *CustomRestClient) Configure(config IAIConfig) error {
	baseURL := config.GetBaseURL()
	if baseURL == "" {
		baseURL = defaultBaseURL
	}
	c.token = config.GetPassword()
	baseClientURL, err := url.Parse(baseURL)
	if err != nil {
		return err
	}
	c.base = baseClientURL

	proxyEndpoint := config.GetProxyEndpoint()
	c.client = http.DefaultClient
	if proxyEndpoint != "" {
		proxyUrl, err := url.Parse(proxyEndpoint)
		if err != nil {
			return err
		}
		transport := &http.Transport{
			Proxy: http.ProxyURL(proxyUrl),
		}
		c.client = &http.Client{
			Transport: transport,
		}
	}

	c.model = config.GetModel()
	if c.model == "" {
		c.model = defaultModel
	}
	c.temperature = config.GetTemperature()
	c.topP = config.GetTopP()
	c.topK = config.GetTopK()
	return nil
}

func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
	var promptDetail struct {
		Language string `json:"language,omitempty"`
		Message  string `json:"message"`
		Prompt   string `json:"prompt,omitempty"`
	}

	// The "raw" prompt envelope may embed text containing literal newlines
	// and tabs; escape them so the envelope is valid JSON, then unpack it.
	prompt = strings.NewReplacer("\n", "\\n", "\t", "\\t").Replace(prompt)
	if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil {
		return "", err
	}

	generateRequest := &CustomRestRequest{
		Model:  c.model,
		Prompt: promptDetail.Prompt,
		Options: map[string]interface{}{
			"temperature": c.temperature,
			"top_p":       c.topP,
			"top_k":       c.topK,
			"message":     promptDetail.Message,
			"language":    promptDetail.Language,
		},
	}

	requestBody, err := json.Marshal(generateRequest)
	if err != nil {
		return "", err
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(requestBody))
	if err != nil {
		return "", err
	}
	if c.token != "" {
		request.Header.Set("Authorization", "Bearer "+c.token)
	}
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Accept", "application/x-ndjson")

	response, err := c.client.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	responseBody, err := io.ReadAll(response.Body)
	if err != nil {
		return "", fmt.Errorf("could not read response body: %w", err)
	}
	if response.StatusCode >= http.StatusBadRequest {
		return "", fmt.Errorf("request failed, status code: %d, error message: %s", response.StatusCode, responseBody)
	}

	var result CustomRestResponse
	if err := json.Unmarshal(responseBody, &result); err != nil {
		return "", err
	}
	return result.Response, nil
}

func (c *CustomRestClient) GetName() string {
	return CustomRestClientName
}
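For anyone implementing the other side of this contract, here is a minimal sketch of a compatible endpoint. It assumes only the CustomRestRequest/CustomRestResponse JSON shapes defined above; the handler, route, and port are hypothetical, and a real backend would plug its own RAG or LLM pipeline into the marked spot.

```go
package main

import (
	"encoding/json"
	"net/http"
	"time"
)

// Local mirrors of the JSON shapes used by pkg/ai/customrest.go.
type restRequest struct {
	Model   string                 `json:"model"`
	Prompt  string                 `json:"prompt"`
	Options map[string]interface{} `json:"options"`
}

type restResponse struct {
	Model     string    `json:"model"`
	CreatedAt time.Time `json:"created_at"`
	Response  string    `json:"response"`
}

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		var req restRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// A real backend would run its retrieval/generation pipeline here,
		// optionally using req.Options["message"] and req.Options["language"].
		resp := restResponse{
			Model:     req.Model,
			CreatedAt: time.Now(),
			Response:  "analysis for: " + req.Prompt,
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(resp)
	})
	_ = http.ListenAndServe(":8080", nil) // port is arbitrary
}
```

Pointing the customrest backend's base URL at such a service is enough for GetCompletion to round-trip: it POSTs the request JSON and reads the `response` field back.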


@@ -32,6 +32,7 @@ var (
		&HuggingfaceClient{},
		&GoogleVertexAIClient{},
		&OCIGenAIClient{},
		&CustomRestClient{},
		&IBMWatsonxAIClient{},
	}
	Backends = []string{
@@ -47,6 +48,7 @@ var (
		huggingfaceAIClientName,
		googleVertexAIClientName,
		ociClientName,
		CustomRestClientName,
		ibmWatsonxAIClientName,
	}
)
@@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
	return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"}

func NeedPassword(backend string) bool {
	for _, b := range passwordlessProviders {

@@ -56,9 +56,11 @@ const (
	Solution: {kubectl command}
	`
	raw_prompt = `{"language": "%s","message": "%s","prompt": "%s"}`
)

var PromptMap = map[string]string{
	"raw":                           raw_prompt,
	"default":                       default_prompt,
	"PrometheusConfigValidate":      prom_conf_prompt,
	"PrometheusConfigRelabelReport": prom_relabel_prompt,


@@ -405,6 +405,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st
	// Process template.
	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
	if a.AIClient.GetName() == ai.CustomRestClientName {
		prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt)
	}
	response, err := a.AIClient.GetCompletion(a.Context, prompt)
	if err != nil {
		return "", err