mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2025-09-17 15:52:50 +00:00
feat: amazonsagemaker AI provider (#731)
* feat(amazonsagemaker): Add AmazonSageMaker AI provider
* feat(auth): add top p and max tokens to auth and use them in sagemaker backend
* feat: Updates SageMaker docs, validate topP, ident
* feat: list of passwordlessProviders
* feat: returns err
* fix: remove log.Fatal(err)

Signed-off-by: Damian Kuroczko <7778327+dkuroczk@users.noreply.github.com>
Signed-off-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>
Co-authored-by: Mateusz Zaremba <18630245+zaremb@users.noreply.github.com>
This commit is contained in:
README.md | 62
@@ -386,8 +386,6 @@ In addition to this you will need to set the follow local environmental variable
 k8sgpt auth add --backend amazonbedrock --model anthropic.claude-v2
 ```
 
-TODO: Currently access key will be requested in the CLI, you can enter anything into this.
-
 #### Usage
 
 ```
@@ -398,6 +396,66 @@ k8sgpt analyze -e -b amazonbedrock
 
 You're right, I don't have enough context to determine if a StatefulSet is correctly configured to use a non-existent service. A StatefulSet manages Pods with persistent storage, and the Pods are created from the same spec. The service name referenced in the StatefulSet configuration would need to match an existing Kubernetes service for the Pods to connect to. Without more details on the specific StatefulSet and environment, I can't confirm whether the configuration is valid or not.
 ```
 
+</details>
+
+<details>
+<summary>Amazon SageMaker Provider</summary>
+
+#### Prerequisites
+
+1. **AWS CLI Configuration**: Make sure you have the AWS Command Line Interface (CLI) configured on your machine. If you haven't already configured the AWS CLI, you can follow the official AWS documentation for instructions: [AWS CLI Configuration Guide](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).
+
+2. **SageMaker Instance**: You need to have an Amazon SageMaker instance set up. If you don't have one already, you can follow the step-by-step instructions in this repository to create one: [llm-sagemaker-jumpstart-cdk](https://github.com/zaremb/llm-sagemaker-jumpstart-cdk).
+
+#### Backend Configuration
+
+To add the amazonsagemaker backend, two parameters are required:
+
+* `--endpointname`: the Amazon SageMaker endpoint name.
+* `--providerRegion`: the AWS region where the SageMaker instance is created. `k8sgpt` uses this region to connect to SageMaker (not the one defined with the AWS CLI or environment variables).
+
+To add amazonsagemaker as a backend, run:
+
+```bash
+k8sgpt auth add --backend amazonsagemaker --providerRegion eu-west-1 --endpointname endpoint-xxxxxxxxxx
+```
+
+#### Optional params
+
+Optionally, when adding the backend (and later by editing the configuration file), you can set the following parameters:
+
+`-l, --maxtokens int` Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length (default 2048).
+
+`-t, --temperature float32` The sampling temperature; value ranges between 0 (output is more deterministic) and 1 (more random) (default 0.7).
+
+`-c, --topp float32` Probability cutoff: set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability (default 0.5).
+
+To make amazonsagemaker the default backend, run:
+
+```bash
+k8sgpt auth default -p amazonsagemaker
+```
+
+#### AmazonSageMaker Usage
+
+```bash
+./k8sgpt analyze -e -b amazonsagemaker
+100% |███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| (1/1, 14 it/min)
+AI Provider: amazonsagemaker
+
+0 default/nginx(nginx)
+- Error: Back-off pulling image "nginxx"
+Error: Back-off pulling image "nginxx"
+
+Solution:
+
+1. Check if the image exists in the registry by running `docker image ls nginxx`.
+2. If the image is not found, try pulling it by running `docker pull nginxx`.
+3. If the image is still not available, check if there are any network issues by running `docker network inspect` and `docker network list`.
+4. If the issue persists, try restarting the Docker daemon by running `sudo service docker restart`.
+```
 </details>
 
 <details>
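For reference, the optional parameters are persisted via the `yaml` tags this commit adds to `AIProvider` (`endpointname`, `topp`, `maxtokens`). A sketch of the resulting provider entry in the k8sgpt configuration file, assuming the defaults above; the surrounding file layout is illustrative and may differ between versions:

```yaml
# Illustrative config fragment; only the field names come from the commit's yaml tags.
ai:
  providers:
    - name: amazonsagemaker
      endpointname: endpoint-xxxxxxxxxx
      providerregion: eu-west-1
      temperature: 0.7
      topp: 0.5
      maxtokens: 2048
```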
@@ -41,7 +41,10 @@ var addCmd = &cobra.Command{
 		_ = cmd.MarkFlagRequired("engine")
 		_ = cmd.MarkFlagRequired("baseurl")
 	}
+	if strings.ToLower(backend) == "amazonsagemaker" {
+		_ = cmd.MarkFlagRequired("endpointname")
+		_ = cmd.MarkFlagRequired("providerRegion")
+	}
 },
 Run: func(cmd *cobra.Command, args []string) {
@@ -90,6 +93,10 @@ var addCmd = &cobra.Command{
 			color.Red("Error: temperature ranges from 0 to 1.")
 			os.Exit(1)
 		}
+		if topP > 1.0 || topP < 0.0 {
+			color.Red("Error: topP ranges from 0 to 1.")
+			os.Exit(1)
+		}
 
 		if ai.NeedPassword(backend) && password == "" {
 			fmt.Printf("Enter %s Key: ", backend)
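The bounds check above is inlined in the CLI; the same logic can be sketched as a standalone helper (`validateUnitInterval` is a hypothetical name, not part of the commit):

```go
package main

import "fmt"

// validateUnitInterval mirrors the CLI checks: temperature and topP must lie in [0, 1].
func validateUnitInterval(name string, v float32) error {
	if v > 1.0 || v < 0.0 {
		return fmt.Errorf("%s ranges from 0 to 1", name)
	}
	return nil
}

func main() {
	fmt.Println(validateUnitInterval("temperature", 0.7)) // <nil>
	fmt.Println(validateUnitInterval("topp", 1.5))        // topp ranges from 0 to 1
}
```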
@@ -108,9 +115,12 @@ var addCmd = &cobra.Command{
 			Model:          model,
 			Password:       password,
 			BaseURL:        baseURL,
+			EndpointName:   endpointName,
 			Engine:         engine,
 			Temperature:    temperature,
 			ProviderRegion: providerRegion,
+			TopP:           topP,
+			MaxTokens:      maxTokens,
 		}
 
 		if providerIndex == -1 {
@@ -138,6 +148,12 @@ func init() {
 	addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
 	// add flag for url
 	addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
+	// add flag for endpointName
+	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, (e.g `endpoint-xxxxxxxxxxxx`)")
+	// add flag for topP
+	addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
+	// add flag for max tokens
+	addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
 	// add flag for temperature
 	addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
 	// add flag for azure open ai engine/deployment name
@@ -22,10 +22,13 @@ var (
 	backend        string
 	password       string
 	baseURL        string
+	endpointName   string
 	model          string
 	engine         string
 	temperature    float32
 	providerRegion string
+	topP           float32
+	maxTokens      int
 )
 
 var configAI ai.AIConfiguration
pkg/ai/amazonsagemaker.go | 170 (new file)
@@ -0,0 +1,170 @@
/*
Copyright 2023 The K8sGPT Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/fatih/color"
	"github.com/k8sgpt-ai/k8sgpt/pkg/cache"
	"github.com/k8sgpt-ai/k8sgpt/pkg/util"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sagemakerruntime"
)

type SageMakerAIClient struct {
	client      *sagemakerruntime.SageMakerRuntime
	language    string
	model       string
	temperature float32
	endpoint    string
	topP        float32
	maxTokens   int
}

type Generations []struct {
	Generation struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"generation"`
}

type Request struct {
	Inputs     [][]Message `json:"inputs"`
	Parameters Parameters  `json:"parameters"`
}

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type Parameters struct {
	MaxNewTokens int     `json:"max_new_tokens"`
	TopP         float64 `json:"top_p"`
	Temperature  float64 `json:"temperature"`
}

func (c *SageMakerAIClient) Configure(config IAIConfig, language string) error {
	// Create a new AWS session in the configured provider region
	sess := session.Must(session.NewSessionWithOptions(session.Options{
		Config:            aws.Config{Region: aws.String(config.GetProviderRegion())},
		SharedConfigState: session.SharedConfigEnable,
	}))

	c.language = language
	// Create a new SageMaker runtime client
	c.client = sagemakerruntime.New(sess)
	c.model = config.GetModel()
	c.endpoint = config.GetEndpointName()
	c.temperature = config.GetTemperature()
	c.maxTokens = config.GetMaxTokens()
	c.topP = config.GetTopP()
	return nil
}

func (c *SageMakerAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
	// Create a completion request
	if len(promptTmpl) == 0 {
		promptTmpl = PromptMap["default"]
	}

	request := Request{
		Inputs: [][]Message{
			{
				{Role: "system", Content: "DEFAULT_PROMPT"},
				{Role: "user", Content: fmt.Sprintf(promptTmpl, c.language, prompt)},
			},
		},
		Parameters: Parameters{
			MaxNewTokens: c.maxTokens,
			TopP:         float64(c.topP),
			Temperature:  float64(c.temperature),
		},
	}

	// Convert the request to JSON
	bytesData, err := json.Marshal(request)
	if err != nil {
		return "", err
	}

	// Create an input object
	input := &sagemakerruntime.InvokeEndpointInput{
		Body:             bytesData,
		EndpointName:     aws.String(c.endpoint),
		ContentType:      aws.String("application/json"), // Set the content type as per your model's requirements
		Accept:           aws.String("application/json"), // Set the accept type as per your model's requirements
		CustomAttributes: aws.String("accept_eula=true"),
	}

	// Call the InvokeEndpoint function
	result, err := c.client.InvokeEndpoint(input)
	if err != nil {
		return "", err
	}

	// Decode the response body into a slice of Generations
	var generations Generations
	err = json.Unmarshal(result.Body, &generations)
	if err != nil {
		return "", err
	}
	// Expect exactly one generation
	if len(generations) != 1 {
		return "", fmt.Errorf("expected exactly one generation, but got %d", len(generations))
	}

	// Access the content
	content := generations[0].Generation.Content
	return content, nil
}

func (a *SageMakerAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) {
	// parse the text with the AI backend
	inputKey := strings.Join(prompt, " ")
	// Build the cache key for storing the response
	sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
	cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)

	response, err := a.GetCompletion(ctx, inputKey, promptTmpl)
	if err != nil {
		color.Red("error getting completion: %v", err)
		return "", err
	}

	err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
	if err != nil {
		color.Red("error storing value to cache: %v", err)
		return "", err
	}

	return response, nil
}

func (a *SageMakerAIClient) GetName() string {
	return "amazonsagemaker"
}
@@ -27,6 +27,7 @@ var (
 		&NoOpAIClient{},
 		&CohereClient{},
 		&AmazonBedRockClient{},
+		&SageMakerAIClient{},
 	}
 	Backends = []string{
 		"openai",
@@ -35,6 +36,7 @@ var (
 		"noopai",
 		"cohere",
 		"amazonbedrock",
+		"amazonsagemaker",
 	}
 )
@@ -49,9 +51,12 @@ type IAIConfig interface {
 	GetPassword() string
 	GetModel() string
 	GetBaseURL() string
+	GetEndpointName() string
 	GetEngine() string
 	GetTemperature() float32
 	GetProviderRegion() string
+	GetTopP() float32
+	GetMaxTokens() int
 }
 
 func NewClient(provider string) IAI {
@@ -74,15 +79,30 @@ type AIProvider struct {
 	Model          string  `mapstructure:"model"`
 	Password       string  `mapstructure:"password" yaml:"password,omitempty"`
 	BaseURL        string  `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
+	EndpointName   string  `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
 	Engine         string  `mapstructure:"engine" yaml:"engine,omitempty"`
 	Temperature    float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
 	ProviderRegion string  `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
+	TopP           float32 `mapstructure:"topp" yaml:"topp,omitempty"`
+	MaxTokens      int     `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
 }
 
 func (p *AIProvider) GetBaseURL() string {
 	return p.BaseURL
 }
 
+func (p *AIProvider) GetEndpointName() string {
+	return p.EndpointName
+}
+
+func (p *AIProvider) GetTopP() float32 {
+	return p.TopP
+}
+
+func (p *AIProvider) GetMaxTokens() int {
+	return p.MaxTokens
+}
+
 func (p *AIProvider) GetPassword() string {
 	return p.Password
 }
||||||
@@ -102,6 +122,13 @@ func (p *AIProvider) GetProviderRegion() string {
|
|||||||
return p.ProviderRegion
|
return p.ProviderRegion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}
|
||||||
|
|
||||||
func NeedPassword(backend string) bool {
|
func NeedPassword(backend string) bool {
|
||||||
return backend != "localai"
|
for _, b := range passwordlessProviders {
|
||||||
|
if b == backend {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
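The new lookup can be exercised in isolation; a minimal sketch with the slice and function copied out of the commit:

```go
package main

import "fmt"

// Backends that authenticate without an API key (from this commit).
var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock"}

// NeedPassword reports whether the CLI should still prompt for a key.
func NeedPassword(backend string) bool {
	for _, b := range passwordlessProviders {
		if b == backend {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(NeedPassword("openai"))          // true: still prompts for a key
	fmt.Println(NeedPassword("amazonsagemaker")) // false: auth comes from the AWS session
}
```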