1
0
mirror of https://github.com/rancher/os.git synced 2025-09-08 02:01:27 +00:00
Files
os/pkg/tpm/auth_tpm.go
2021-10-29 23:08:26 -07:00

215 lines
5.2 KiB
Go

package tpm
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/google/certificate-transparency-go/x509"
"github.com/google/go-attestation/attest"
"github.com/gorilla/websocket"
v1 "github.com/rancher/os2/pkg/apis/rancheros.cattle.io/v1"
"github.com/rancher/wrangler/pkg/merr"
"github.com/sirupsen/logrus"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// verifyChain checks the EK certificate against the TPM CA configured for
// namespace.
//
// The CA bundle is read from the tpmCACert secret in namespace (PEM data under
// corev1.TLSCertKey). If that secret does not exist, no CA is configured and
// the check is skipped (nil is returned). Any other lookup failure, a secret
// without parseable certificates, or an EK that carries no certificate is
// reported as an error instead of being dereferenced or silently accepted.
func (a *AuthServer) verifyChain(ek *attest.EK, namespace string) error {
	secret, err := a.secretCache.Get(namespace, tpmCACert)
	if apierrors.IsNotFound(err) {
		return nil
	}
	if err != nil {
		// The returned secret is not usable when the lookup fails.
		return fmt.Errorf("getting TPM CA secret %s/%s: %w", namespace, tpmCACert, err)
	}

	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(secret.Data[corev1.TLSCertKey]) {
		return fmt.Errorf("TPM CA secret %s/%s contains no valid PEM certificates", namespace, tpmCACert)
	}

	// Verify is a method on the certificate; an EK without one cannot be checked.
	if ek.Certificate == nil {
		return fmt.Errorf("EK has no certificate to verify against the TPM CA in namespace %s", namespace)
	}

	// NOTE(review): x509 verification assumes ExtKeyUsageServerAuth when
	// KeyUsages is empty; confirm the EK certificates in use satisfy that, or
	// set KeyUsages explicitly.
	opts := x509.VerifyOptions{
		Roots: roots,
	}
	_, err = ek.Certificate.Verify(opts)
	return err
}
// generateChallenge builds a TPM 2.0 credential-activation challenge for the
// client's AK, bound to its EK.
//
// It returns the secret the client is expected to send back, and the
// JSON-encoded Challenge (wrapping the encrypted credential) to send to the
// client. attestationData is decoded from the client's request, so a missing
// AK is reported as an error rather than causing a nil-pointer panic.
func (a *AuthServer) generateChallenge(ek *attest.EK, attestationData *AttestationData) ([]byte, []byte, error) {
	if attestationData == nil || attestationData.AK == nil {
		return nil, nil, fmt.Errorf("generating challenge: attestation data has no AK")
	}

	ap := attest.ActivationParameters{
		TPMVersion: attest.TPMVersion20,
		EK:         ek.Public,
		AK:         *attestationData.AK,
	}
	secret, ec, err := ap.Generate()
	if err != nil {
		return nil, nil, fmt.Errorf("generating challenge: %w", err)
	}

	challengeBytes, err := json.Marshal(Challenge{EC: ec})
	if err != nil {
		return nil, nil, fmt.Errorf("marshalling challenge: %w", err)
	}
	return secret, challengeBytes, nil
}
// validateChallenge checks the client's answer to a challenge built by
// generateChallenge.
//
// resp is the raw JSON-encoded ChallengeResponse received from the client. The
// call succeeds only if it decodes cleanly and its Secret equals secret.
//
// NOTE(review): bytes.Equal is not constant-time. Authenticate generates a
// fresh secret for every attempt, so one guess per secret limits the exposure.
func (a *AuthServer) validateChallenge(secret, resp []byte) error {
	var response ChallengeResponse
	err := json.Unmarshal(resp, &response)
	switch {
	case err != nil:
		return fmt.Errorf("unmarshalling challenge response: %w", err)
	case !bytes.Equal(secret, response.Secret):
		return fmt.Errorf("invalid challenge response")
	default:
		return nil
	}
}
// validHash resolves the MachineInventory that a TPM EK may authenticate as.
//
// The EK's public-key hash (from GetPubHash) identifies the machine.
//   - When registerNamespace is non-empty, the EK must pass verifyChain for
//     that namespace. A new MachineInventory (not persisted here) carrying the
//     hash in Spec.TPMHash is then returned.
//   - Otherwise exactly one existing machine must be indexed under the hash
//     (machineByHash), and the EK must pass verifyChain for that machine's
//     namespace.
//
// Cache errors other than NotFound are returned wrapped rather than being
// reported as a missing machine.
func (a *AuthServer) validHash(ek *attest.EK, registerNamespace string) (*v1.MachineInventory, error) {
	hashEncoded, err := GetPubHash(ek)
	if err != nil {
		return nil, fmt.Errorf("tpm: could not get public key hash: %w", err)
	}

	if registerNamespace != "" {
		if err := a.verifyChain(ek, registerNamespace); err != nil {
			return nil, fmt.Errorf("verifying chain: %w", err)
		}
		return &v1.MachineInventory{
			ObjectMeta: metav1.ObjectMeta{
				Namespace: registerNamespace,
			},
			Spec: v1.MachineInventorySpec{
				TPMHash: hashEncoded,
			},
		}, nil
	}

	machines, err := a.machineCache.GetByIndex(machineByHash, hashEncoded)
	switch {
	case apierrors.IsNotFound(err):
		machines = nil
	case err != nil:
		return nil, fmt.Errorf("looking up machine by TPM hash: %w", err)
	}

	if len(machines) != 1 {
		if len(machines) > 1 {
			// Log namespace/name pairs: %v on a slice of pointers prints only addresses.
			names := make([]string, 0, len(machines))
			for _, m := range machines {
				names = append(names, m.Namespace+"/"+m.Name)
			}
			logrus.Errorf("multiple machines for same hash %s found: %v", hashEncoded, names)
		}
		return nil, fmt.Errorf("failed to find machine")
	}

	if err := a.verifyChain(ek, machines[0].Namespace); err != nil {
		return nil, fmt.Errorf("verifying chain: %w", err)
	}
	return machines[0], nil
}
// writeRead performs one request/response round trip over conn. It sends input
// as a single binary websocket message, then reads the full payload of the
// next incoming message (of whatever type) and returns it.
func writeRead(conn *websocket.Conn, input []byte) ([]byte, error) {
	w, err := conn.NextWriter(websocket.BinaryMessage)
	if err != nil {
		return nil, err
	}
	if _, err = w.Write(input); err != nil {
		return nil, err
	}
	// Closing the writer is what completes the outgoing message.
	if err = w.Close(); err != nil {
		return nil, err
	}

	var msg io.Reader
	if _, msg, err = conn.NextReader(); err != nil {
		return nil, err
	}
	// NOTE(review): apart from connection deadlines, nothing in this file
	// bounds the size of the read; consider conn.SetReadLimit.
	return ioutil.ReadAll(msg)
}
// upgrade switches the HTTP request to a websocket connection for the TPM
// challenge exchange.
//
// The handshake must complete within 5 seconds. After a successful upgrade,
// read and write deadlines 10 seconds in the future are set on the connection.
// Errors from setting them are ignored.
//
// NOTE(review): nothing in this file extends these deadlines, so they also
// bound any later writes on the connection (e.g. the response written after
// Authenticate returns). Confirm 10s is enough for callers.
func upgrade(resp http.ResponseWriter, req *http.Request) (*websocket.Conn, error) {
	const ioTimeout = 10 * time.Second

	upgrader := websocket.Upgrader{HandshakeTimeout: 5 * time.Second}
	// Every origin is accepted. NOTE(review): callers presumably are machines
	// authenticating via the TPM Authorization header, not browsers; confirm.
	upgrader.CheckOrigin = func(*http.Request) bool { return true }

	conn, err := upgrader.Upgrade(resp, req, nil)
	if err != nil {
		return nil, err
	}

	_ = conn.SetWriteDeadline(time.Now().Add(ioTimeout))
	_ = conn.SetReadDeadline(time.Now().Add(ioTimeout))
	return conn, nil
}
// getAttestationData decodes the TPM attestation payload carried in an
// Authorization header of the form "Bearer TPM<base64 JSON>".
//
// It base64-decodes (standard encoding) the text after the "Bearer TPM"
// prefix and unmarshals it as AttestationData. It then decodes the embedded
// EK with DecodeEK and returns the EK together with the parsed attestation
// data. Each failure is wrapped with the step that produced it.
func (a *AuthServer) getAttestationData(header string) (*attest.EK, *AttestationData, error) {
	tpmBytes, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(header, "Bearer TPM"))
	if err != nil {
		return nil, nil, fmt.Errorf("decoding TPM attestation header: %w", err)
	}

	var attestationData AttestationData
	if err := json.Unmarshal(tpmBytes, &attestationData); err != nil {
		return nil, nil, fmt.Errorf("unmarshalling attestation data: %w", err)
	}

	ek, err := DecodeEK(attestationData.EK)
	if err != nil {
		return nil, nil, fmt.Errorf("decoding EK: %w", err)
	}
	return ek, &attestationData, nil
}
// Authenticate performs TPM-based authentication of req.
//
// If the Authorization header does not start with "Bearer TPM", the request is
// not handled here. In that case (nil, true, nil, nil) is returned, presumably
// so the caller can try another authentication method.
//
// For TPM requests it:
//   - decodes the attestation data from the header;
//   - resolves the MachineInventory for the EK via validHash (registration
//     when registerNamespace is non-empty);
//   - upgrades the connection to a websocket;
//   - sends a credential-activation challenge and validates the client's answer.
//
// On success it returns the machine, false, and a writer whose output is sent
// as one binary websocket message. Closing that writer also closes the
// websocket connection.
//
// On failure the machine and writer are nil. Any websocket connection opened
// during the attempt has already been closed.
func (a *AuthServer) Authenticate(resp http.ResponseWriter, req *http.Request, registerNamespace string) (*v1.MachineInventory, bool, io.WriteCloser, error) {
	header := req.Header.Get("Authorization")
	if !strings.HasPrefix(header, "Bearer TPM") {
		return nil, true, nil, nil
	}

	ek, attestationData, err := a.getAttestationData(header)
	if err != nil {
		return nil, false, nil, err
	}

	machine, err := a.validHash(ek, registerNamespace)
	if err != nil {
		return nil, false, nil, err
	}

	secret, challenge, err := a.generateChallenge(ek, attestationData)
	if err != nil {
		return nil, false, nil, err
	}

	conn, err := upgrade(resp, req)
	if err != nil {
		return nil, false, nil, err
	}
	// After the upgrade the connection has been taken over from net/http. Unless
	// it is handed to the caller inside the responseWriter, it must be closed
	// here; otherwise every failed challenge leaks a connection.
	handedOff := false
	defer func() {
		if !handedOff {
			_ = conn.Close()
		}
	}()

	challResp, err := writeRead(conn, challenge)
	if err != nil {
		return nil, false, nil, err
	}
	if err := a.validateChallenge(secret, challResp); err != nil {
		return nil, false, nil, err
	}

	writer, err := conn.NextWriter(websocket.BinaryMessage)
	if err != nil {
		// Do not hand out a responseWriter wrapping a nil writer.
		return nil, false, nil, err
	}
	handedOff = true
	return machine, false, &responseWriter{
		WriteCloser: writer,
		conn:        conn,
	}, nil
}
// responseWriter is the io.WriteCloser returned by Authenticate. Writes go to
// the embedded websocket message writer, and Close also closes the underlying
// websocket connection.
type responseWriter struct {
	io.WriteCloser
	conn *websocket.Conn
}

// Close first closes the message writer, completing the current websocket
// message, and then closes the connection. Both steps always run, and their
// errors are combined with merr.NewErrors.
func (r *responseWriter) Close() error {
	// Go evaluates call arguments left to right, so the writer closes before the connection.
	return merr.NewErrors(r.WriteCloser.Close(), r.conn.Close())
}