chore: customized prompt template for integration plugins (#403)

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
This commit is contained in:
Peter Pan
2023-06-13 04:14:12 +08:00
committed by GitHub
parent ad2a5fd5fc
commit c85203bccd
6 changed files with 29 additions and 13 deletions

View File

@@ -36,7 +36,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
return nil return nil
} }
func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request // Create a completion request
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model, Model: c.model,
@@ -53,7 +53,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (strin
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ") inputKey := strings.Join(prompt, " ")
// Check for cached data // Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
@@ -74,7 +74,7 @@ func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.
} }
} }
response, err := a.GetCompletion(ctx, inputKey) response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -36,8 +36,8 @@ var (
type IAI interface { type IAI interface {
Configure(config IAIConfig, language string) error Configure(config IAIConfig, language string) error
GetCompletion(ctx context.Context, prompt string) (string, error) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error)
Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error)
GetName() string GetName() string
} }

View File

@@ -38,20 +38,20 @@ func (c *NoOpAIClient) Configure(config IAIConfig, language string) error {
return nil return nil
} }
func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request // Create a completion request
response := "I am a noop response to the prompt " + prompt response := "I am a noop response to the prompt " + prompt
return response, nil return response, nil
} }
func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
// parse the text with the AI backend // parse the text with the AI backend
inputKey := strings.Join(prompt, " ") inputKey := strings.Join(prompt, " ")
// Check for cached data // Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey)) sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc) cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
response, err := a.GetCompletion(ctx, inputKey) response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil { if err != nil {
color.Red("error getting completion: %v", err) color.Red("error getting completion: %v", err)
return "", err return "", err

View File

@@ -53,14 +53,17 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
return nil return nil
} }
func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
// Create a completion request // Create a completion request
if len(promptTmpl) == 0 {
promptTmpl = PromptMap["default"]
}
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model, Model: c.model,
Messages: []openai.ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: "user", Role: "user",
Content: fmt.Sprintf(default_prompt, c.language, prompt), Content: fmt.Sprintf(promptTmpl, c.language, prompt),
}, },
}, },
}) })
@@ -70,7 +73,7 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
inputKey := strings.Join(prompt, " ") inputKey := strings.Join(prompt, " ")
// Check for cached data // Check for cached data
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey)
@@ -91,7 +94,7 @@ func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.I
} }
} }
response, err := a.GetCompletion(ctx, inputKey) response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -6,4 +6,10 @@ const (
Error: {Explain error here} Error: {Explain error here}
Solution: {Step by step solution here} Solution: {Step by step solution here}
` `
trivy_prompt = "Explain the following trivy scan result and the detail risk or root cause of the CVE ID, then provide a solution. Response in %s: %s"
) )
var PromptMap = map[string]string{
"default": default_prompt,
"VulnerabilityReport": trivy_prompt, // for Trivy intergration, the key should match `Result.Kind` in pkg/common/types.go
}

View File

@@ -261,7 +261,14 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error {
} }
texts = append(texts, failure.Text) texts = append(texts, failure.Text)
} }
parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache) // If the resource `Kind` comes from a "integration plugin", maybe a customized prompt template will be involved.
var promptTemplate string
if prompt, ok := ai.PromptMap[analysis.Kind]; ok {
promptTemplate = prompt
} else {
promptTemplate = ai.PromptMap["default"]
}
parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache, promptTemplate)
if err != nil { if err != nil {
// FIXME: can we avoid checking if output is json multiple times? // FIXME: can we avoid checking if output is json multiple times?
// maybe implement the progress bar better? // maybe implement the progress bar better?