tpm-helpers/get_test.go

172 lines
4.2 KiB
Go

package tpm_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
. "github.com/kairos-io/tpm-helpers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// upgrader turns the incoming HTTP requests handled by the mimic servers in
// this file into websocket connections. Both the read and the write buffer are
// 1024 bytes.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// WSServer mimics a WS server which accepts a TPM Bearer token.
//
// It serves "/test" on :8080. The handler upgrades the request to a websocket
// connection and then loops: every iteration runs AuthRequest and, when that
// succeeds, sends one binary message holding the JSON object
// {"foo": "bar", "header": <value of the request's "awesome" header>}.
// The first error ends the handler, which closes the connection.
//
// WSServer returns immediately: the server runs in the background and is shut
// down once ctx is done.
func WSServer(ctx context.Context) {
	s := http.Server{
		Addr:         ":8080",
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	m := http.NewServeMux()
	m.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Without a connection there is nothing to authenticate; bail out
			// instead of handing a nil conn to AuthRequest.
			fmt.Println("error", err.Error())
			return
		}
		// Release the socket whichever way the handler exits.
		defer conn.Close()

		for {
			if err := AuthRequest(r, conn); err != nil {
				fmt.Println("error", err.Error())
				return
			}

			awesome := r.Header.Get("awesome")

			writer, err := conn.NextWriter(websocket.BinaryMessage)
			if err != nil {
				fmt.Println("error", err.Error())
				return
			}
			if err := json.NewEncoder(writer).Encode(map[string]string{"foo": "bar", "header": awesome}); err != nil {
				fmt.Println("error", err.Error())
				return
			}
			// Closing the writer is what flushes the message to the peer;
			// do not rely on the next write to do it implicitly.
			if err := writer.Close(); err != nil {
				fmt.Println("error", err.Error())
				return
			}
		}
	})
	s.Handler = m

	go func() {
		// ErrServerClosed is the expected result of the Shutdown below.
		if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			fmt.Println("error", err.Error())
		}
	}()

	go func() {
		<-ctx.Done()
		// ctx is already done at this point, so Shutdown gets its own deadline
		// rather than a context that has expired before it starts.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := s.Shutdown(shutdownCtx); err != nil {
			fmt.Println("error", err.Error())
		}
	}()
}
// WSServerReceiver mimics a WS server which accepts a TPM Bearer token and
// receives data.
//
// It serves "/post" on :8080. The handler upgrades the request to a websocket
// connection and then loops: every iteration runs AuthRequest, reads one JSON
// message into a map[string]string and sends that map on c. The first error
// ends the handler, which closes the connection.
//
// The send on c blocks until the value is accepted, so callers should either
// drain c or give it enough buffer.
//
// WSServerReceiver returns immediately: the server runs in the background and
// is shut down once ctx is done.
func WSServerReceiver(ctx context.Context, c chan map[string]string) {
	s := http.Server{
		Addr:         ":8080",
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	m := http.NewServeMux()
	m.HandleFunc("/post", func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Without a connection there is nothing to authenticate; bail out
			// instead of handing a nil conn to AuthRequest.
			fmt.Println("error", err.Error())
			return
		}
		// Registered once, before the loop: the connection is closed even when
		// the very first AuthRequest fails, and no defer piles up per message.
		defer conn.Close()

		for {
			if err := AuthRequest(r, conn); err != nil {
				fmt.Println("error", err.Error())
				return
			}

			v := map[string]string{}
			if err := conn.ReadJSON(&v); err != nil {
				fmt.Println("error", err.Error())
				return
			}
			c <- v
		}
	})
	s.Handler = m

	go func() {
		// ErrServerClosed is the expected result of the Shutdown below.
		if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			fmt.Println("error", err.Error())
		}
	}()

	go func() {
		<-ctx.Done()
		// ctx is already done at this point, so Shutdown gets its own deadline
		// rather than a context that has expired before it starts.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := s.Shutdown(shutdownCtx); err != nil {
			fmt.Println("error", err.Error())
		}
	}()
}
// POST spec: opens an authenticated connection to the local mimic receiver,
// writes one JSON payload and checks that the server hands the same payload
// back through the channel.
var _ = Describe("POST", func() {
	Context("challenges", func() {
		It("posts pubhash", func() {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			rec := make(chan map[string]string, 10)
			WSServerReceiver(ctx, rec)

			conn, err := Connection("http://localhost:8080/post", Emulated, WithSeed(1))
			Expect(err).ToNot(HaveOccurred())
			defer conn.Close()

			err = conn.WriteJSON(map[string]string{"foo": "bar", "header": "foo"})
			Expect(err).ToNot(HaveOccurred())

			// Bounded wait: a server that never delivers fails the spec
			// instead of blocking it forever on a bare channel receive.
			var res map[string]string
			Eventually(rec, 10*time.Second).Should(Receive(&res))
			Expect(res).To(Equal(map[string]string{"foo": "bar", "header": "foo"}))
		})
	})
})
// GET specs against the local mimic server: one call without any options that
// must fail, and one emulated-TPM call that must return the JSON document
// produced by the "/test" handler, including the forwarded "awesome" header.
var _ = Describe("GET", func() {
	Context("challenges", func() {
		It("fails for permissions", func() {
			_, err := Get("http://localhost:8080/test")
			Expect(err).To(HaveOccurred())
		})

		It("gets pubhash", func() {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			WSServer(ctx)

			msg, err := Get("http://localhost:8080/test", Emulated, WithSeed(1), WithAdditionalHeader("awesome", "content"))
			// Assert on the request error first: msg is meaningless otherwise.
			Expect(err).ToNot(HaveOccurred())

			result := map[string]interface{}{}
			err = json.Unmarshal(msg, &result)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).To(Equal(map[string]interface{}{"foo": "bar", "header": "content"}))
		})
	})
})
// These specs are meant to be run manually against a reg. server with a valid
// cert. The server URL is read from the REG_URL environment variable; when it
// is empty every spec in this context is skipped.
var _ = Describe("GET", func() {
	Context("challenges with a remote endpoint", func() {
		regURL := os.Getenv("REG_URL")
		expectedMatches := ContainElement("ros-node-{{ trunc 4 .MachineID }}")

		BeforeEach(func() {
			if regURL == "" {
				Skip("No remote url passed, skipping suite")
			}
		})

		It("gets pubhash from remote with a public signed CA", func() {
			msg, err := Get(regURL, Emulated, WithSeed(1))
			// Assert on the request error first: msg is meaningless otherwise.
			Expect(err).ToNot(HaveOccurred())

			result := map[string]interface{}{}
			err = json.Unmarshal(msg, &result)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).To(expectedMatches)
		})

		It("it fails if we specify a custom CA (invalid)", func() {
			// Only the error matters here, so the payload is not decoded.
			_, err := Get(regURL, Emulated, WithSeed(1), WithCAs([]byte(`dddd`)))
			Expect(err).To(HaveOccurred())
		})

		It("it pass if appends to system CA", func() {
			msg, err := Get(regURL, Emulated, WithSeed(1), AppendCustomCAToSystemCA, WithCAs([]byte(`dddd`)))
			Expect(err).ToNot(HaveOccurred())

			result := map[string]interface{}{}
			err = json.Unmarshal(msg, &result)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).To(expectedMatches)
		})
	})
})