diff --git a/cmd/discovery/client/client.go b/cmd/discovery/client/client.go index a3b3f50..548733f 100644 --- a/cmd/discovery/client/client.go +++ b/cmd/discovery/client/client.go @@ -9,6 +9,7 @@ import ( "github.com/jaypipes/ghw/pkg/block" "github.com/kairos-io/kairos-challenger/pkg/constants" + "github.com/kairos-io/kairos-challenger/pkg/payload" "github.com/kairos-io/kcrypt/pkg/bus" "github.com/kairos-io/tpm-helpers" "github.com/mudler/go-pluggable" @@ -57,6 +58,28 @@ func (c *Client) Start() error { return factory.Run(pluggable.EventType(os.Args[1]), os.Stdin, os.Stdout) } +func (c *Client) generatePass(postEndpoint string, p *block.Partition) error { + + rand := utils.RandomString(32) + pass, err := tpm.EncryptBlob([]byte(rand)) + if err != nil { + return err + } + bpass := base64.RawURLEncoding.EncodeToString(pass) + + opts := []tpm.Option{ + tpm.WithAdditionalHeader("label", p.Label), + tpm.WithAdditionalHeader("name", p.Name), + tpm.WithAdditionalHeader("uuid", p.UUID), + } + conn, err := tpm.Connection(postEndpoint, opts...) + if err != nil { + return err + } + + return conn.WriteJSON(payload.Data{Passphrase: bpass, GeneratedBy: constants.TPMSecret}) +} + func (c *Client) waitPass(p *block.Partition, attempts int) (pass string, err error) { // IF we don't have any server configured, just do local if c.Config.Kcrypt.Challenger.Server == "" { @@ -66,33 +89,19 @@ func (c *Client) waitPass(p *block.Partition, attempts int) (pass string, err er challengeEndpoint := fmt.Sprintf("%s/getPass", c.Config.Kcrypt.Challenger.Server) postEndpoint := fmt.Sprintf("%s/postPass", c.Config.Kcrypt.Challenger.Server) - // IF server doesn't have a pass for us, then we generate one and we set it - if _, _, err := getPass(challengeEndpoint, p); err == errPartNotFound { - rand := utils.RandomString(32) - pass, err := tpm.EncryptBlob([]byte(rand)) - if err != nil { - return "", err - } - bpass := base64.RawURLEncoding.EncodeToString(pass) - - opts := []tpm.Option{ - tpm.WithAdditionalHeader("label", p.Label), - tpm.WithAdditionalHeader("name", p.Name), - tpm.WithAdditionalHeader("uuid", p.UUID), - } - conn, err := tpm.Connection(postEndpoint, opts...) - if err != nil { - return "", err - } - err = conn.WriteJSON(map[string]string{"passphrase": bpass, constants.GeneratedByKey: constants.TPMSecret}) - if err != nil { - return rand, err - } - } - for tries := 0; tries < attempts; tries++ { var generated bool pass, generated, err = getPass(challengeEndpoint, p) + if err == errPartNotFound { + // IF server doesn't have a pass for us, then we generate one and we set it + err = c.generatePass(postEndpoint, p) + if err != nil { + return + } + // Attempt to fetch again - validate that the server has it now + tries = 0 + continue + } if generated { // passphrase is encrypted return c.decryptPassphrase(pass) } diff --git a/cmd/discovery/client/enc.go b/cmd/discovery/client/enc.go index 3313267..d74c077 100644 --- a/cmd/discovery/client/enc.go +++ b/cmd/discovery/client/enc.go @@ -3,8 +3,10 @@ package client import ( "encoding/json" "fmt" + "strings" "github.com/kairos-io/kairos-challenger/pkg/constants" + "github.com/kairos-io/kairos-challenger/pkg/payload" "github.com/jaypipes/ghw/pkg/block" "github.com/kairos-io/tpm-helpers" @@ -22,16 +24,21 @@ func getPass(server string, partition *block.Partition) (string, bool, error) { if err != nil { return "", false, err } - result := map[string]interface{}{} + result := payload.Data{} err = json.Unmarshal(msg, &result) if err != nil { return "", false, errors.Wrap(err, string(msg)) } - generatedBy, generated := result[constants.GeneratedByKey] - p, ok := result["passphrase"] - if ok { - return fmt.Sprint(p), generated && generatedBy == constants.TPMSecret, nil + + if result.HasPassphrase() { + return fmt.Sprint(result.Passphrase), result.HasBeenGenerated() && result.GeneratedBy == constants.TPMSecret, nil + } else if result.HasError() { + if strings.Contains(result.Error, "No secret found for") { + return "", false, errPartNotFound + } + return "", false, fmt.Errorf(result.Error) } + return "", false, errPartNotFound } diff --git a/pkg/challenger/challenger.go b/pkg/challenger/challenger.go index 869c795..a1b5be2 100644 --- a/pkg/challenger/challenger.go +++ b/pkg/challenger/challenger.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "time" keyserverv1alpha1 "github.com/kairos-io/kairos-challenger/api/v1alpha1" "github.com/kairos-io/kairos-challenger/pkg/constants" + "github.com/kairos-io/kairos-challenger/pkg/payload" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kairos-io/kairos-challenger/controllers" @@ -86,14 +88,26 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr m := http.NewServeMux() + errorMessage := func(writer io.WriteCloser, errMsg string) { + err := json.NewEncoder(writer).Encode(payload.Data{Error: errMsg}) + if err != nil { + fmt.Println("error encoding the response to json", err.Error()) + } + fmt.Println(errMsg) + } + m.HandleFunc("/postPass", func(w http.ResponseWriter, r *http.Request) { conn, _ := upgrader.Upgrade(w, r, nil) // error ignored for sake of simplicity for { + + fmt.Println("Receiving passphrase") if err := tpm.AuthRequest(r, conn); err != nil { fmt.Println("error", err.Error()) return } defer conn.Close() + fmt.Println("[Receiving passphrase] auth succeeded") + token := r.Header.Get("Authorization") hashEncoded, err := getPubHash(token) @@ -101,11 +115,12 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr fmt.Println("error decoding pubhash", err.Error()) return } + fmt.Println("[Receiving passphrase] pubhash", hashEncoded) label := r.Header.Get("label") name := r.Header.Get("name") uuid := r.Header.Get("uuid") - v := map[string]string{} + v := &payload.Data{} volumeList := &keyserverv1alpha1.SealedVolumeList{} if err := reconciler.List(ctx, volumeList, &client.ListOptions{Namespace: namespace}); err != nil { @@ -127,13 +142,12 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr return } - if err := conn.ReadJSON(&v); err != nil { + if err := conn.ReadJSON(v); err != nil { fmt.Println("error", err.Error()) return } - pass, ok := v["passphrase"] - if ok { + if v.HasPassphrase() && !v.HasError() { secretName := fmt.Sprintf("%s-%s", sealedVolumeData.VolumeName, sealedVolumeData.PartitionLabel) secretPath := "passphrase" if sealedVolumeData.SecretName != "" { @@ -158,9 +172,9 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr Name: secretName, Namespace: namespace, }, - Data: map[string][]byte{ - secretPath: []byte(pass), - constants.GeneratedByKey: []byte(v[constants.GeneratedByKey]), + StringData: map[string]string{ + secretPath: v.Passphrase, + constants.GeneratedByKey: v.GeneratedBy, }, Type: "Opaque", } @@ -213,10 +227,12 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr }, volumeList) if sealedVolumeData == nil { - fmt.Println("No TPM Hash found for", hashEncoded) + writer, _ := conn.NextWriter(websocket.BinaryMessage) + errorMessage(writer, fmt.Sprintf("Invalid hash: %s", hashEncoded)) conn.Close() return } + writer, _ := conn.NextWriter(websocket.BinaryMessage) if !sealedVolumeData.Quarantined { secretName := fmt.Sprintf("%s-%s", sealedVolumeData.VolumeName, sealedVolumeData.PartitionLabel) @@ -236,12 +252,10 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr secret, err := kclient.CoreV1().Secrets(namespace).Get(ctx, secretName, v1.GetOptions{}) if err == nil { passphrase := secret.Data[secretPath] - generatedBy, generated := secret.Data[constants.GeneratedByKey] - result := map[string]string{"passphrase": string(passphrase)} - if generated { - result[constants.GeneratedByKey] = string(generatedBy) - } - err = json.NewEncoder(writer).Encode(result) + generatedBy := secret.Data[constants.GeneratedByKey] + + p := payload.Data{Passphrase: string(passphrase), GeneratedBy: string(generatedBy)} + err = json.NewEncoder(writer).Encode(p) if err != nil { fmt.Println("error encoding the passphrase to json", err.Error(), string(passphrase)) } @@ -255,9 +269,11 @@ func Start(ctx context.Context, kclient *kubernetes.Clientset, reconciler *contr } return + } else { + errorMessage(writer, fmt.Sprintf("No secret found for %s and %s", hashEncoded, sealedVolumeData.PartitionLabel)) } } else { - fmt.Println("error getting the secret", err.Error()) + errorMessage(writer, fmt.Sprintf("quarantined: %s", sealedVolumeData.PartitionLabel)) if err = conn.Close(); err != nil { fmt.Println("error closing the connection", err.Error()) return diff --git a/pkg/payload/payload.go b/pkg/payload/payload.go new file mode 100644 index 0000000..ced91e4 --- /dev/null +++ b/pkg/payload/payload.go @@ -0,0 +1,19 @@ +package payload + +type Data struct { + Passphrase string `json:"passphrase"` + Error string `json:"error"` + GeneratedBy string `json:"generated_by"` +} + +func (d Data) HasError() bool { + return d.Error != "" +} + +func (d Data) HasPassphrase() bool { + return d.Passphrase != "" +} + +func (d Data) HasBeenGenerated() bool { + return d.GeneratedBy != "" +}