Allow to pass additional headers on Get()

This commit is contained in:
Ettore Di Giacinto
2022-10-07 21:27:48 +00:00
parent d273b298fc
commit fef86f595b
4 changed files with 25 additions and 8 deletions

View File

@@ -15,10 +15,15 @@ type config struct {
cacerts []byte
header http.Header
headers map[string]string
systemfallback bool
}
func newConfig() *config {
return &config{headers: make(map[string]string)}
}
// Option is a generic option for TPM configuration
type Option func(c *config) error
@@ -53,6 +58,14 @@ func WithHeader(header http.Header) Option {
}
}
// WithAdditionalHeader adds a key to the request
func WithAdditionalHeader(k, v string) Option {
return func(c *config) error {
c.headers[k] = v
return nil
}
}
// WithSeed sets a permanent seed. Used with TPM emulated device.
func WithSeed(s int64) Option {
return func(c *config) error {

10
get.go
View File

@@ -21,7 +21,7 @@ import (
// It will return the token as a string and the generated AK that should
// be saved by the caller for later Authentication.
func GetAuthToken(opts ...Option) (string, []byte, error) {
c := &config{}
c := newConfig()
c.apply(opts...)
attestationData, akBytes, err := getAttestationData(c)
@@ -41,7 +41,7 @@ func GetAuthToken(opts ...Option) (string, []byte, error) {
// attestation server, will compute a challenge response via the TPM using the passed
// Attestation Key (AK) and will send it back to the attestation server.
func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error {
c := &config{}
c := newConfig()
c.apply(opts...)
var challenge Challenge
@@ -64,7 +64,7 @@ func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error {
// Get retrieves a message from a remote ws server after
// a successfully process of the TPM challenge
func Get(url string, opts ...Option) ([]byte, error) {
c := &config{}
c := newConfig()
c.apply(opts...)
header := c.header
@@ -109,6 +109,10 @@ func Get(url string, opts ...Option) ([]byte, error) {
}
header.Add("Authorization", token)
for k, v := range c.headers {
header.Add(k, v)
}
wsURL := strings.Replace(url, "http", "ws", 1)
logrus.Infof("Using TPMHash %s to dial %s", hash, wsURL)
conn, resp, err := dialer.Dial(wsURL, header)

View File

@@ -57,6 +57,7 @@ func WSServer(ctx context.Context) {
for {
token := r.Header.Get("Authorization")
awesome := r.Header.Get("awesome")
ek, at, err := GetAttestationData(token)
if err != nil {
fmt.Println("error", err.Error())
@@ -77,7 +78,7 @@ func WSServer(ctx context.Context) {
}
writer, _ := conn.NextWriter(websocket.BinaryMessage)
json.NewEncoder(writer).Encode(map[string]string{"foo": "bar"})
json.NewEncoder(writer).Encode(map[string]string{"foo": "bar", "header": awesome})
}
})
@@ -103,13 +104,12 @@ var _ = Describe("GET", func() {
WSServer(ctx)
msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1))
msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1), WithAdditionalHeader("awesome", "content"))
result := map[string]interface{}{}
json.Unmarshal(msg, &result)
Expect(err).ToNot(HaveOccurred())
Expect(result).To(Equal(map[string]interface{}{"foo": "bar"}))
Expect(result).To(Equal(map[string]interface{}{"foo": "bar", "header": "content"}))
})
})
})

2
tpm.go
View File

@@ -52,7 +52,7 @@ func ResolveToken(token string, opts ...Option) (bool, string, error) {
// GetPubHash returns the EK's pub hash
func GetPubHash(opts ...Option) (string, error) {
c := &config{}
c := newConfig()
c.apply(opts...)
ek, err := getEK(c)