diff --git a/config.go b/config.go index caa1d5a..fa0b78e 100644 --- a/config.go +++ b/config.go @@ -15,10 +15,15 @@ type config struct { cacerts []byte header http.Header + headers map[string]string systemfallback bool } +func newConfig() *config { + return &config{headers: make(map[string]string)} +} + // Option is a generic option for TPM configuration type Option func(c *config) error @@ -53,6 +58,14 @@ func WithHeader(header http.Header) Option { } } +// WithAdditionalHeader adds a key to the request +func WithAdditionalHeader(k, v string) Option { + return func(c *config) error { + c.headers[k] = v + return nil + } +} + // WithSeed sets a permanent seed. Used with TPM emulated device. func WithSeed(s int64) Option { return func(c *config) error { diff --git a/get.go b/get.go index d3c0464..052c496 100644 --- a/get.go +++ b/get.go @@ -21,7 +21,7 @@ import ( // It will return the token as a string and the generated AK that should // be saved by the caller for later Authentication. func GetAuthToken(opts ...Option) (string, []byte, error) { - c := &config{} + c := newConfig() c.apply(opts...) attestationData, akBytes, err := getAttestationData(c) @@ -41,7 +41,7 @@ func GetAuthToken(opts ...Option) (string, []byte, error) { // attestation server, will compute a challenge response via the TPM using the passed // Attestation Key (AK) and will send it back to the attestation server. func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error { - c := &config{} + c := newConfig() c.apply(opts...) var challenge Challenge @@ -64,7 +64,7 @@ func Authenticate(akBytes []byte, channel io.ReadWriter, opts ...Option) error { // Get retrieves a message from a remote ws server after // a successfully process of the TPM challenge func Get(url string, opts ...Option) ([]byte, error) { - c := &config{} + c := newConfig() c.apply(opts...) header := c.header @@ -109,6 +109,10 @@ func Get(url string, opts ...Option) ([]byte, error) { } header.Add("Authorization", token) + for k, v := range c.headers { + header.Add(k, v) + } + wsURL := strings.Replace(url, "http", "ws", 1) logrus.Infof("Using TPMHash %s to dial %s", hash, wsURL) conn, resp, err := dialer.Dial(wsURL, header) diff --git a/get_test.go b/get_test.go index f66be24..8541c72 100644 --- a/get_test.go +++ b/get_test.go @@ -57,6 +57,7 @@ func WSServer(ctx context.Context) { for { token := r.Header.Get("Authorization") + awesome := r.Header.Get("awesome") ek, at, err := GetAttestationData(token) if err != nil { fmt.Println("error", err.Error()) @@ -77,7 +78,7 @@ func WSServer(ctx context.Context) { } writer, _ := conn.NextWriter(websocket.BinaryMessage) - json.NewEncoder(writer).Encode(map[string]string{"foo": "bar"}) + json.NewEncoder(writer).Encode(map[string]string{"foo": "bar", "header": awesome}) } }) @@ -103,13 +104,12 @@ var _ = Describe("GET", func() { WSServer(ctx) - msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1)) + msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1), WithAdditionalHeader("awesome", "content")) result := map[string]interface{}{} json.Unmarshal(msg, &result) Expect(err).ToNot(HaveOccurred()) - Expect(result).To(Equal(map[string]interface{}{"foo": "bar"})) + Expect(result).To(Equal(map[string]interface{}{"foo": "bar", "header": "content"})) }) - }) }) diff --git a/tpm.go b/tpm.go index 4833112..81a3527 100644 --- a/tpm.go +++ b/tpm.go @@ -52,7 +52,7 @@ func ResolveToken(token string, opts ...Option) (bool, string, error) { // GetPubHash returns the EK's pub hash func GetPubHash(opts ...Option) (string, error) { - c := &config{} + c := newConfig() c.apply(opts...) ek, err := getEK(c)