Allow to pass additional headers on Get()

This commit is contained in:
Ettore Di Giacinto
2022-10-07 21:27:48 +00:00
parent d273b298fc
commit fef86f595b
4 changed files with 25 additions and 8 deletions

View File

@@ -15,10 +15,15 @@ type config struct {
cacerts []byte cacerts []byte
header http.Header header http.Header
headers map[string]string
systemfallback bool systemfallback bool
} }
func newConfig() *config {
return &config{headers: make(map[string]string)}
}
// Option is a generic option for TPM configuration // Option is a generic option for TPM configuration
type Option func(c *config) error type Option func(c *config) error
@@ -53,6 +58,14 @@ func WithHeader(header http.Header) Option {
} }
} }
// WithAdditionalHeader adds a key to the request
func WithAdditionalHeader(k, v string) Option {
return func(c *config) error {
c.headers[k] = v
return nil
}
}
// WithSeed sets a permanent seed. Used with TPM emulated device. // WithSeed sets a permanent seed. Used with TPM emulated device.
func WithSeed(s int64) Option { func WithSeed(s int64) Option {
return func(c *config) error { return func(c *config) error {

10
get.go
View File

@@ -21,7 +21,7 @@ import (
// It will return the token as a string and the generated AK that should // It will return the token as a string and the generated AK that should
// be saved by the caller for later Authentication. // be saved by the caller for later Authentication.
func GetAuthToken(opts ...Option) (string, []byte, error) { func GetAuthToken(opts ...Option) (string, []byte, error) {
c := &config{} c := newConfig()
c.apply(opts...) c.apply(opts...)
attestationData, akBytes, err := getAttestationData(c) attestationData, akBytes, err := getAttestationData(c)
@@ -41,7 +41,7 @@ func GetAuthToken(opts ...Option) (string, []byte, error) {
// attestation server, will compute a challenge response via the TPM using the passed // attestation server, will compute a challenge response via the TPM using the passed
// Attestation Key (AK) and will send it back to the attestation server. // Attestation Key (AK) and will send it back to the attestation server.
func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error { func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error {
c := &config{} c := newConfig()
c.apply(opts...) c.apply(opts...)
var challenge Challenge var challenge Challenge
@@ -64,7 +64,7 @@ func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error {
// Get retrieves a message from a remote ws server after // Get retrieves a message from a remote ws server after
// a successfully process of the TPM challenge // a successfully process of the TPM challenge
func Get(url string, opts ...Option) ([]byte, error) { func Get(url string, opts ...Option) ([]byte, error) {
c := &config{} c := newConfig()
c.apply(opts...) c.apply(opts...)
header := c.header header := c.header
@@ -109,6 +109,10 @@ func Get(url string, opts ...Option) ([]byte, error) {
} }
header.Add("Authorization", token) header.Add("Authorization", token)
for k, v := range c.headers {
header.Add(k, v)
}
wsURL := strings.Replace(url, "http", "ws", 1) wsURL := strings.Replace(url, "http", "ws", 1)
logrus.Infof("Using TPMHash %s to dial %s", hash, wsURL) logrus.Infof("Using TPMHash %s to dial %s", hash, wsURL)
conn, resp, err := dialer.Dial(wsURL, header) conn, resp, err := dialer.Dial(wsURL, header)

View File

@@ -57,6 +57,7 @@ func WSServer(ctx context.Context) {
for { for {
token := r.Header.Get("Authorization") token := r.Header.Get("Authorization")
awesome := r.Header.Get("awesome")
ek, at, err := GetAttestationData(token) ek, at, err := GetAttestationData(token)
if err != nil { if err != nil {
fmt.Println("error", err.Error()) fmt.Println("error", err.Error())
@@ -77,7 +78,7 @@ func WSServer(ctx context.Context) {
} }
writer, _ := conn.NextWriter(websocket.BinaryMessage) writer, _ := conn.NextWriter(websocket.BinaryMessage)
json.NewEncoder(writer).Encode(map[string]string{"foo": "bar"}) json.NewEncoder(writer).Encode(map[string]string{"foo": "bar", "header": awesome})
} }
}) })
@@ -103,13 +104,12 @@ var _ = Describe("GET", func() {
WSServer(ctx) WSServer(ctx)
msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1)) msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1), WithAdditionalHeader("awesome", "content"))
result := map[string]interface{}{} result := map[string]interface{}{}
json.Unmarshal(msg, &result) json.Unmarshal(msg, &result)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal(map[string]interface{}{"foo": "bar"})) Expect(result).To(Equal(map[string]interface{}{"foo": "bar", "header": "content"}))
}) })
}) })
}) })

2
tpm.go
View File

@@ -52,7 +52,7 @@ func ResolveToken(token string, opts ...Option) (bool, string, error) {
// GetPubHash returns the EK's pub hash // GetPubHash returns the EK's pub hash
func GetPubHash(opts ...Option) (string, error) { func GetPubHash(opts ...Option) (string, error) {
c := &config{} c := newConfig()
c.apply(opts...) c.apply(opts...)
ek, err := getEK(c) ek, err := getEK(c)