diff --git a/acceptanceTests/Makefile b/acceptanceTests/Makefile index c6d411544..9e7c7f0e0 100644 --- a/acceptanceTests/Makefile +++ b/acceptanceTests/Makefile @@ -1,2 +1,2 @@ test: ## Run acceptance tests. - @go test ./... -timeout 1h + @go test ./... -timeout 1h -v diff --git a/acceptanceTests/config_test.go b/acceptanceTests/config_test.go index 248e56e9e..b0abed868 100644 --- a/acceptanceTests/config_test.go +++ b/acceptanceTests/config_test.go @@ -2,11 +2,12 @@ package acceptanceTests import ( "fmt" - "gopkg.in/yaml.v3" "io/ioutil" "os" "os/exec" "testing" + + "gopkg.in/yaml.v3" ) type tapConfig struct { diff --git a/acceptanceTests/go.mod b/acceptanceTests/go.mod index 0ced4f361..3bba434f2 100644 --- a/acceptanceTests/go.mod +++ b/acceptanceTests/go.mod @@ -3,6 +3,7 @@ module github.com/up9inc/mizu/tests go 1.16 require ( + github.com/gorilla/websocket v1.4.2 github.com/up9inc/mizu/shared v0.0.0 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) diff --git a/acceptanceTests/go.sum b/acceptanceTests/go.sum index 54a1e835d..429c993df 100644 --- a/acceptanceTests/go.sum +++ b/acceptanceTests/go.sum @@ -211,6 +211,7 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= diff --git a/acceptanceTests/tap_test.go b/acceptanceTests/tap_test.go index ebee856de..ee805aa4f 100644 --- a/acceptanceTests/tap_test.go +++ b/acceptanceTests/tap_test.go @@ -66,21 +66,18 @@ func TestTap(t *testing.T) { entriesCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, entriesCount, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, entriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } - - entry := entries[0].(map[string]interface{}) + entry := entries[0] entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, entry["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } @@ -441,38 +438,26 @@ func TestTapRedact(t *testing.T) { redactCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, defaultEntriesCount, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, defaultEntriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } - - firstEntry := entries[0].(map[string]interface{}) + firstEntry := entries[0] entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, firstEntry["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } - data := requestResult.(map[string]interface{})["data"].(map[string]interface{}) - entryJson := data["entry"].(string) + entry := requestResult.(map[string]interface{})["data"].(map[string]interface{}) + request := entry["request"].(map[string]interface{}) - var entry map[string]interface{} - if parseErr := json.Unmarshal([]byte(entryJson), &entry); parseErr != nil { - return fmt.Errorf("failed to parse entry, err: %v", parseErr) - } - - entryRequest := entry["request"].(map[string]interface{}) - entryPayload := entryRequest["payload"].(map[string]interface{}) - entryDetails := entryPayload["details"].(map[string]interface{}) - - headers := entryDetails["_headers"].([]interface{}) + headers := request["_headers"].([]interface{}) for _, headerInterface := range headers { header := headerInterface.(map[string]interface{}) if header["name"].(string) != "User-Agent" { @@ -485,7 +470,7 @@ func TestTapRedact(t *testing.T) { } } - postData := entryDetails["postData"].(map[string]interface{}) + postData := request["postData"].(map[string]interface{}) textDataStr := postData["text"].(string) var textData map[string]string @@ -556,38 +541,26 @@ func TestTapNoRedact(t *testing.T) { redactCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, defaultEntriesCount, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, defaultEntriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } - - firstEntry := entries[0].(map[string]interface{}) + firstEntry := entries[0] entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, firstEntry["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } - data := requestResult.(map[string]interface{})["data"].(map[string]interface{}) - entryJson := data["entry"].(string) + entry := requestResult.(map[string]interface{})["data"].(map[string]interface{}) + request := entry["request"].(map[string]interface{}) - var entry map[string]interface{} - if parseErr := json.Unmarshal([]byte(entryJson), &entry); parseErr != nil { - return fmt.Errorf("failed to parse entry, err: %v", parseErr) - } - - entryRequest := entry["request"].(map[string]interface{}) - entryPayload := entryRequest["payload"].(map[string]interface{}) - entryDetails := entryPayload["details"].(map[string]interface{}) - - headers := entryDetails["_headers"].([]interface{}) + headers := request["_headers"].([]interface{}) for _, headerInterface := range headers { header := headerInterface.(map[string]interface{}) if header["name"].(string) != "User-Agent" { @@ -600,7 +573,7 @@ func TestTapNoRedact(t *testing.T) { } } - postData := entryDetails["postData"].(map[string]interface{}) + postData := request["postData"].(map[string]interface{}) textDataStr := postData["text"].(string) var textData map[string]string @@ -671,38 +644,26 @@ func TestTapRegexMasking(t *testing.T) { redactCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, defaultEntriesCount, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, defaultEntriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } - - firstEntry := entries[0].(map[string]interface{}) + firstEntry := entries[0] entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, firstEntry["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } - data := requestResult.(map[string]interface{})["data"].(map[string]interface{}) - entryJson := data["entry"].(string) + entry := requestResult.(map[string]interface{})["data"].(map[string]interface{}) + request := entry["request"].(map[string]interface{}) - var entry map[string]interface{} - if parseErr := json.Unmarshal([]byte(entryJson), &entry); parseErr != nil { - return fmt.Errorf("failed to parse entry, err: %v", parseErr) - } - - entryRequest := entry["request"].(map[string]interface{}) - entryPayload := entryRequest["payload"].(map[string]interface{}) - entryDetails := entryPayload["details"].(map[string]interface{}) - - postData := entryDetails["postData"].(map[string]interface{}) + postData := request["postData"].(map[string]interface{}) textData := postData["text"].(string) if textData != "[REDACTED]" { @@ -778,38 +739,27 @@ func TestTapIgnoredUserAgents(t *testing.T) { ignoredUserAgentsCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, defaultEntriesCount*2, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, defaultEntriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } for _, entryInterface := range entries { - entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, entryInterface.(map[string]interface{})["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, entryInterface["id"]) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } - data := requestResult.(map[string]interface{})["data"].(map[string]interface{}) - entryJson := data["entry"].(string) + entry := requestResult.(map[string]interface{})["data"].(map[string]interface{}) + request := entry["request"].(map[string]interface{}) - var entry map[string]interface{} - if parseErr := json.Unmarshal([]byte(entryJson), &entry); parseErr != nil { - return fmt.Errorf("failed to parse entry, err: %v", parseErr) - } - - entryRequest := entry["request"].(map[string]interface{}) - entryPayload := entryRequest["payload"].(map[string]interface{}) - entryDetails := entryPayload["details"].(map[string]interface{}) - - entryHeaders := entryDetails["_headers"].([]interface{}) - for _, headerInterface := range entryHeaders { + headers := request["_headers"].([]interface{}) + for _, headerInterface := range headers { header := headerInterface.(map[string]interface{}) if header["name"].(string) != ignoredUserAgentCustomHeader { continue @@ -986,21 +936,18 @@ func TestDaemonSeeTraffic(t *testing.T) { entriesCheckFunc := func() error { timestamp := time.Now().UnixNano() / int64(time.Millisecond) - entriesUrl := fmt.Sprintf("%v/entries?limit=%v&operator=lt×tamp=%v", apiServerUrl, entriesCount, timestamp) - requestResult, requestErr := executeHttpGetRequest(entriesUrl) - if requestErr != nil { - return fmt.Errorf("failed to get entries, err: %v", requestErr) + entries, err := getDBEntries(timestamp, entriesCount, 1*time.Second) + if err != nil { + return err } - - entries := requestResult.([]interface{}) - if len(entries) == 0 { - return fmt.Errorf("unexpected entries result - Expected more than 0 entries") + err = checkEntriesAtLeast(entries, 1) + if err != nil { + return err } - - entry := entries[0].(map[string]interface{}) + entry := entries[0] entryUrl := fmt.Sprintf("%v/entries/%v", apiServerUrl, entry["id"]) - requestResult, requestErr = executeHttpGetRequest(entryUrl) + requestResult, requestErr := executeHttpGetRequest(entryUrl) if requestErr != nil { return fmt.Errorf("failed to get entry, err: %v", requestErr) } diff --git a/acceptanceTests/testsUtils.go b/acceptanceTests/testsUtils.go index 6b6681603..95cbf8077 100644 --- a/acceptanceTests/testsUtils.go +++ b/acceptanceTests/testsUtils.go @@ -11,22 +11,24 @@ import ( "os/exec" "path" "strings" + "sync" "syscall" "testing" "time" + "github.com/gorilla/websocket" "github.com/up9inc/mizu/shared" ) const ( - longRetriesCount = 100 - shortRetriesCount = 10 - defaultApiServerPort = shared.DefaultApiServerPort - defaultNamespaceName = "mizu-tests" - defaultServiceName = "httpbin" - defaultEntriesCount = 50 + longRetriesCount = 100 + shortRetriesCount = 10 + defaultApiServerPort = shared.DefaultApiServerPort + defaultNamespaceName = "mizu-tests" + defaultServiceName = "httpbin" + defaultEntriesCount = 50 waitAfterTapPodsReady = 3 * time.Second - cleanCommandTimeout = 1 * time.Minute + cleanCommandTimeout = 1 * time.Minute ) type PodDescriptor struct { @@ -36,7 +38,7 @@ type PodDescriptor struct { func isPodDescriptorInPodArray(pods []map[string]interface{}, podDescriptor PodDescriptor) bool { for _, pod := range pods { - podNamespace := pod["namespace"].(string) + podNamespace := pod["namespace"].(string) podName := pod["name"].(string) if podDescriptor.Namespace == podNamespace && strings.Contains(podName, podDescriptor.Name) { @@ -82,6 +84,10 @@ func getApiServerUrl(port uint16) string { return fmt.Sprintf("http://localhost:%v", port) } +func getWebSocketUrl(port uint16) string { + return fmt.Sprintf("ws://localhost:%v/ws", port) +} + func getDefaultCommandArgs() []string { setFlag := "--set" telemetry := "telemetry=false" @@ -92,10 +98,11 @@ func getDefaultCommandArgs() []string { } func getDefaultTapCommandArgs() []string { + headless := "--headless" tapCommand := "tap" defaultCmdArgs := getDefaultCommandArgs() - return append([]string{tapCommand}, defaultCmdArgs...) + return append([]string{tapCommand, headless}, defaultCmdArgs...) } func getDefaultTapCommandArgsWithDaemonMode() []string { @@ -256,11 +263,11 @@ func runMizuClean() error { }() select { - case err = <- commandDone: + case err = <-commandDone: if err != nil { return err } - case <- time.After(cleanCommandTimeout): + case <-time.After(cleanCommandTimeout): return errors.New("clean command timed out") } @@ -311,6 +318,77 @@ func daemonCleanup(t *testing.T, viewCmd *exec.Cmd) { } } +// waitTimeout waits for the waitgroup for the specified max timeout. +// Returns true if waiting timed out. +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + channel := make(chan struct{}) + go func() { + defer close(channel) + wg.Wait() + }() + select { + case <-channel: + return false // completed normally + case <-time.After(timeout): + return true // timed out + } +} + +// checkEntriesAtLeast checks whether the number of entries greater than or equal to n +func checkEntriesAtLeast(entries []map[string]interface{}, n int) error { + if len(entries) < n { + return fmt.Errorf("Unexpected entries result - Expected more than %d entries", n-1) + } + return nil +} + +// getDBEntries retrieves the entries from the database before the given timestamp. +// Also limits the results according to the limit parameter. +// Timeout for the WebSocket connection is defined by the timeout parameter. +func getDBEntries(timestamp int64, limit int, timeout time.Duration) (entries []map[string]interface{}, err error) { + query := fmt.Sprintf("timestamp < %d and limit(%d)", timestamp, limit) + webSocketUrl := getWebSocketUrl(defaultApiServerPort) + + var connection *websocket.Conn + connection, _, err = websocket.DefaultDialer.Dial(webSocketUrl, nil) + if err != nil { + return + } + defer connection.Close() + + handleWSConnection := func(wg *sync.WaitGroup) { + defer wg.Done() + for { + _, message, err := connection.ReadMessage() + if err != nil { + return + } + + var data map[string]interface{} + if err = json.Unmarshal([]byte(message), &data); err != nil { + return + } + + if data["messageType"] == "entry" { + entries = append(entries, data) + } + } + } + + err = connection.WriteMessage(websocket.TextMessage, []byte(query)) + if err != nil { + return + } + + var wg sync.WaitGroup + go handleWSConnection(&wg) + wg.Add(1) + + waitTimeout(&wg, timeout) + + return +} + func Contains(slice []string, containsValue string) bool { for _, sliceValue := range slice { if sliceValue == containsValue { diff --git a/cli/cmd/tap.go b/cli/cmd/tap.go index f4ed44605..6ec7105ca 100644 --- a/cli/cmd/tap.go +++ b/cli/cmd/tap.go @@ -113,4 +113,5 @@ func init() { tapCmd.Flags().String(configStructs.EnforcePolicyFile, defaultTapConfig.EnforcePolicyFile, "Yaml file path with policy rules") tapCmd.Flags().String(configStructs.ContractFile, defaultTapConfig.ContractFile, "OAS/Swagger file to validate to monitor the contracts") tapCmd.Flags().Bool(configStructs.DaemonModeTapName, defaultTapConfig.DaemonMode, "Run mizu in daemon mode, detached from the cli") + tapCmd.Flags().Bool(configStructs.HeadlessMode, defaultTapConfig.HeadlessMode, "Enable headless mode.") } diff --git a/cli/cmd/tapRunner.go b/cli/cmd/tapRunner.go index f0238d733..2dbd5f92e 100644 --- a/cli/cmd/tapRunner.go +++ b/cli/cmd/tapRunner.go @@ -4,16 +4,17 @@ import ( "context" "errors" "fmt" - "github.com/up9inc/mizu/cli/cmd/goUtils" "io/ioutil" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/wait" "path" "regexp" "strings" "time" + "github.com/up9inc/mizu/cli/cmd/goUtils" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "github.com/getkin/kin-openapi/openapi3" "github.com/up9inc/mizu/cli/apiserver" "github.com/up9inc/mizu/cli/config" @@ -625,7 +626,9 @@ func watchApiServerPod(ctx context.Context, kubernetesProvider *kubernetes.Provi } logger.Log.Infof("Mizu is available at %s\n", url) - uiUtils.OpenBrowser(url) + if !config.Config.Tap.HeadlessMode { + uiUtils.OpenBrowser(url) + } if err := apiProvider.ReportTappedPods(state.tapperSyncer.CurrentlyTappedPods); err != nil { logger.Log.Debugf("[Error] failed update tapped pods %v", err) } diff --git a/cli/config/configStructs/tapConfig.go b/cli/config/configStructs/tapConfig.go index 38fb187db..c79040e0c 100644 --- a/cli/config/configStructs/tapConfig.go +++ b/cli/config/configStructs/tapConfig.go @@ -3,9 +3,10 @@ package configStructs import ( "errors" "fmt" - "github.com/up9inc/mizu/shared" "regexp" + "github.com/up9inc/mizu/shared" + "github.com/up9inc/mizu/shared/units" ) @@ -22,6 +23,7 @@ const ( EnforcePolicyFile = "traffic-validation-file" ContractFile = "contract" DaemonModeTapName = "daemon" + HeadlessMode = "headless" ) type TapConfig struct { @@ -44,6 +46,7 @@ type TapConfig struct { ApiServerResources shared.Resources `yaml:"api-server-resources"` TapperResources shared.Resources `yaml:"tapper-resources"` DaemonMode bool `yaml:"daemon" default:"false"` + HeadlessMode bool `yaml:"headless" default:"false"` } func (config *TapConfig) PodRegex() *regexp.Regexp {