From 76a6a77a140b96f89b2e656accbeeca70491d3b0 Mon Sep 17 00:00:00 2001 From: RamiBerm <54766858+RamiBerm@users.noreply.github.com> Date: Mon, 4 Apr 2022 17:33:53 +0300 Subject: [PATCH] Refactor ws (#961) * Separate socket and basenine logic * WIP * Update socket_server_handlers.go * Update socket_data_streamer.go and socket_server_handlers.go * Update socket_server_handlers.go * Merge branch 'develop' into refactor_ws # Please enter a commit message to explain why this merge is necessary, # especially if it merges an updated upstream into a topic branch. # # Lines starting with '#' will be ignored, and an empty message aborts # the commit. * empty commit for actions * empty commit for actions * commit for actions * Revert "commit for actions" This reverts commit 8ba2ecf7d3ed7af672ca8d6552683d00471483d9. Co-authored-by: RoyUP9 <87927115+RoyUP9@users.noreply.github.com> --- agent/main.go | 2 + .../api/entry_streamer_socket_connector.go | 57 ++++++ agent/pkg/api/socket_data_streamer.go | 92 ++++++++++ agent/pkg/api/socket_routes.go | 163 ++++-------------- agent/pkg/api/socket_server_handlers.go | 95 +++++----- agent/pkg/dependency/type_names.go | 2 + 6 files changed, 226 insertions(+), 185 deletions(-) create mode 100644 agent/pkg/api/entry_streamer_socket_connector.go create mode 100644 agent/pkg/api/socket_data_streamer.go diff --git a/agent/main.go b/agent/main.go index c453fc673..1040bf888 100644 --- a/agent/main.go +++ b/agent/main.go @@ -373,4 +373,6 @@ func initializeDependencies() { dependency.RegisterGenerator(dependency.ServiceMapGeneratorDependency, func() interface{} { return servicemap.GetDefaultServiceMapInstance() }) dependency.RegisterGenerator(dependency.OasGeneratorDependency, func() interface{} { return oas.GetDefaultOasGeneratorInstance(nil) }) dependency.RegisterGenerator(dependency.EntriesProvider, func() interface{} { return &entries.BasenineEntriesProvider{} }) + dependency.RegisterGenerator(dependency.EntriesSocketStreamer, func() interface{} { return &api.BasenineEntryStreamer{} }) + dependency.RegisterGenerator(dependency.EntryStreamerSocketConnector, func() interface{} { return &api.DefaultEntryStreamerSocketConnector{} }) } diff --git a/agent/pkg/api/entry_streamer_socket_connector.go b/agent/pkg/api/entry_streamer_socket_connector.go new file mode 100644 index 000000000..c5e2e7d0d --- /dev/null +++ b/agent/pkg/api/entry_streamer_socket_connector.go @@ -0,0 +1,57 @@ +package api + +import ( + "fmt" + + basenine "github.com/up9inc/basenine/client/go" + "github.com/up9inc/mizu/agent/pkg/models" + "github.com/up9inc/mizu/shared/logger" + tapApi "github.com/up9inc/mizu/tap/api" +) + +type EntryStreamerSocketConnector interface { + SendEntry(socketId int, entry *tapApi.Entry, params *WebSocketParams) + SendMetadata(socketId int, metadata *basenine.Metadata) + SendToastError(socketId int, err error) + CleanupSocket(socketId int) +} + +type DefaultEntryStreamerSocketConnector struct{} + +func (e *DefaultEntryStreamerSocketConnector) SendEntry(socketId int, entry *tapApi.Entry, params *WebSocketParams) { + var message []byte + if params.EnableFullEntries { + message, _ = models.CreateFullEntryWebSocketMessage(entry) + } else { + extension := extensionsMap[entry.Protocol.Name] + base := extension.Dissector.Summarize(entry) + message, _ = models.CreateBaseEntryWebSocketMessage(base) + } + + if err := SendToSocket(socketId, message); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) SendMetadata(socketId int, metadata *basenine.Metadata) { + metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) + if err := SendToSocket(socketId, metadataBytes); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) SendToastError(socketId int, err error) { + toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ + Type: "error", + AutoClose: 5000, + Text: fmt.Sprintf("Syntax error: %s", err.Error()), + }) + if err := SendToSocket(socketId, toastBytes); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) CleanupSocket(socketId int) { + socketObj := connectedWebsockets[socketId] + socketCleanup(socketId, socketObj) +} diff --git a/agent/pkg/api/socket_data_streamer.go b/agent/pkg/api/socket_data_streamer.go new file mode 100644 index 000000000..0fde336d5 --- /dev/null +++ b/agent/pkg/api/socket_data_streamer.go @@ -0,0 +1,92 @@ +package api + +import ( + "context" + "encoding/json" + + basenine "github.com/up9inc/basenine/client/go" + "github.com/up9inc/mizu/agent/pkg/dependency" + "github.com/up9inc/mizu/shared" + "github.com/up9inc/mizu/shared/logger" + tapApi "github.com/up9inc/mizu/tap/api" +) + +type EntryStreamer interface { + Get(ctx context.Context, socketId int, params *WebSocketParams) error +} + +type BasenineEntryStreamer struct{} + +func (e *BasenineEntryStreamer) Get(ctx context.Context, socketId int, params *WebSocketParams) error { + var connection *basenine.Connection + + entryStreamerSocketConnector := dependency.GetInstance(dependency.EntryStreamerSocketConnector).(EntryStreamerSocketConnector) + + connection, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) + if err != nil { + logger.Log.Errorf("failed to establish a connection to Basenine: %v", err) + entryStreamerSocketConnector.CleanupSocket(socketId) + return err + } + + data := make(chan []byte) + meta := make(chan []byte) + + query := params.Query + err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) + if err != nil { + entryStreamerSocketConnector.SendToastError(socketId, err) + } + + handleDataChannel := func(c *basenine.Connection, data chan []byte) { + for { + bytes := <-data + + if string(bytes) == basenine.CloseChannel { + return + } + + var entry *tapApi.Entry + err = json.Unmarshal(bytes, &entry) + if err != nil { + logger.Log.Debugf("error unmarshalling entry: %v", err.Error()) + continue + } + + entryStreamerSocketConnector.SendEntry(socketId, entry, params) + } + } + + handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { + for { + bytes := <-meta + + if string(bytes) == basenine.CloseChannel { + return + } + + var metadata *basenine.Metadata + err = json.Unmarshal(bytes, &metadata) + if err != nil { + logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) + continue + } + + entryStreamerSocketConnector.SendMetadata(socketId, metadata) + } + } + + go handleDataChannel(connection, data) + go handleMetaChannel(connection, meta) + + connection.Query(query, data, meta) + + go func() { + <-ctx.Done() + data <- []byte(basenine.CloseChannel) + meta <- []byte(basenine.CloseChannel) + connection.Close() + }() + + return nil +} diff --git a/agent/pkg/api/socket_routes.go b/agent/pkg/api/socket_routes.go index 964ac231e..eddf160cd 100644 --- a/agent/pkg/api/socket_routes.go +++ b/agent/pkg/api/socket_routes.go @@ -1,19 +1,15 @@ package api import ( - "encoding/json" "fmt" "net/http" "sync" "time" - "github.com/up9inc/mizu/agent/pkg/models" - "github.com/up9inc/mizu/agent/pkg/utils" - "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - basenine "github.com/up9inc/basenine/client/go" - "github.com/up9inc/mizu/shared" + "github.com/up9inc/mizu/agent/pkg/models" + "github.com/up9inc/mizu/agent/pkg/utils" "github.com/up9inc/mizu/shared/logger" tapApi "github.com/up9inc/mizu/tap/api" ) @@ -25,9 +21,9 @@ func InitExtensionsMap(ref map[string]*tapApi.Extension) { } type EventHandlers interface { - WebSocketConnect(socketId int, isTapper bool) + WebSocketConnect(c *gin.Context, socketId int, isTapper bool) WebSocketDisconnect(socketId int, isTapper bool) - WebSocketMessage(socketId int, message []byte) + WebSocketMessage(socketId int, isTapper bool, message []byte) } type SocketConnection struct { @@ -62,11 +58,11 @@ func init() { func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { SocketGetBrowserHandler = func(c *gin.Context) { - websocketHandler(c.Writer, c.Request, eventHandlers, false) + websocketHandler(c, eventHandlers, false) } SocketGetTapperHandler = func(c *gin.Context) { - websocketHandler(c.Writer, c.Request, eventHandlers, true) + websocketHandler(c, eventHandlers, true) } app.GET("/ws", func(c *gin.Context) { @@ -78,10 +74,10 @@ func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { }) } -func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) { - ws, err := websocketUpgrader.Upgrade(w, r, nil) +func websocketHandler(c *gin.Context, eventHandlers EventHandlers, isTapper bool) { + ws, err := websocketUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - logger.Log.Errorf("Failed to set websocket upgrade: %v", err) + logger.Log.Errorf("failed to set websocket upgrade: %v", err) return } @@ -93,30 +89,11 @@ func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers Even websocketIdsLock.Unlock() - var connection *basenine.Connection - var isQuerySet bool - - // `!isTapper` means it's a connection from the web UI - if !isTapper { - connection, err = basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) - if err != nil { - logger.Log.Errorf("Failed to establish a connection to Basenine: %v", err) - socketCleanup(socketId, connectedWebsockets[socketId]) - return - } - } - - data := make(chan []byte) - meta := make(chan []byte) - defer func() { socketCleanup(socketId, connectedWebsockets[socketId]) - data <- []byte(basenine.CloseChannel) - meta <- []byte(basenine.CloseChannel) - connection.Close() }() - eventHandlers.WebSocketConnect(socketId, isTapper) + eventHandlers.WebSocketConnect(c, socketId, isTapper) startTimeBytes, _ := models.CreateWebsocketStartTimeMessage(utils.StartTime) @@ -124,127 +101,32 @@ func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers Even logger.Log.Error(err) } - var params WebSocketParams - for { _, msg, err := ws.ReadMessage() if err != nil { if _, ok := err.(*websocket.CloseError); ok { - logger.Log.Debugf("Received websocket close message, socket id: %d", socketId) + logger.Log.Debugf("received websocket close message, socket id: %d", socketId) } else { - logger.Log.Errorf("Error reading message, socket id: %d, error: %v", socketId, err) + logger.Log.Errorf("error reading message, socket id: %d, error: %v", socketId, err) } break } - if !isTapper && !isQuerySet { - if err := json.Unmarshal(msg, ¶ms); err != nil { - logger.Log.Errorf("Error unmarshalling parameters: %v", socketId, err) - continue - } - - query := params.Query - err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) - if err != nil { - toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ - Type: "error", - AutoClose: 5000, - Text: fmt.Sprintf("Syntax error: %s", err.Error()), - }) - if err := SendToSocket(socketId, toastBytes); err != nil { - logger.Log.Error(err) - } - break - } - - isQuerySet = true - - handleDataChannel := func(c *basenine.Connection, data chan []byte) { - for { - bytes := <-data - - if string(bytes) == basenine.CloseChannel { - return - } - - var entry *tapApi.Entry - err = json.Unmarshal(bytes, &entry) - if err != nil { - logger.Log.Debugf("Error unmarshalling entry: %v", err.Error()) - continue - } - - var message []byte - if params.EnableFullEntries { - message, _ = models.CreateFullEntryWebSocketMessage(entry) - } else { - extension := extensionsMap[entry.Protocol.Name] - base := extension.Dissector.Summarize(entry) - message, _ = models.CreateBaseEntryWebSocketMessage(base) - } - - if err := SendToSocket(socketId, message); err != nil { - logger.Log.Error(err) - } - } - } - - handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { - for { - bytes := <-meta - - if string(bytes) == basenine.CloseChannel { - return - } - - var metadata *basenine.Metadata - err = json.Unmarshal(bytes, &metadata) - if err != nil { - logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) - continue - } - - metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) - if err := SendToSocket(socketId, metadataBytes); err != nil { - logger.Log.Error(err) - } - } - } - - go handleDataChannel(connection, data) - go handleMetaChannel(connection, meta) - - connection.Query(query, data, meta) - } else { - eventHandlers.WebSocketMessage(socketId, msg) - } + eventHandlers.WebSocketMessage(socketId, isTapper, msg) } } -func socketCleanup(socketId int, socketConnection *SocketConnection) { - err := socketConnection.connection.Close() - if err != nil { - logger.Log.Errorf("Error closing socket connection for socket id %d: %v", socketId, err) - } - - websocketIdsLock.Lock() - connectedWebsockets[socketId] = nil - websocketIdsLock.Unlock() - - socketConnection.eventHandlers.WebSocketDisconnect(socketId, socketConnection.isTapper) -} - func SendToSocket(socketId int, message []byte) error { socketObj := connectedWebsockets[socketId] if socketObj == nil { - return fmt.Errorf("Socket %v is disconnected", socketId) + return fmt.Errorf("socket %v is disconnected", socketId) } var sent = false time.AfterFunc(time.Second*5, func() { if !sent { - logger.Log.Error("Socket timed out") + logger.Log.Error("socket timed out") socketCleanup(socketId, socketObj) } }) @@ -255,7 +137,20 @@ func SendToSocket(socketId int, message []byte) error { sent = true if err != nil { - return fmt.Errorf("Failed to write message to socket %v, err: %w", socketId, err) + return fmt.Errorf("failed to write message to socket %v, err: %w", socketId, err) } return nil } + +func socketCleanup(socketId int, socketConnection *SocketConnection) { + err := socketConnection.connection.Close() + if err != nil { + logger.Log.Errorf("error closing socket connection for socket id %d: %v", socketId, err) + } + + websocketIdsLock.Lock() + connectedWebsockets[socketId] = nil + websocketIdsLock.Unlock() + + socketConnection.eventHandlers.WebSocketDisconnect(socketId, socketConnection.isTapper) +} diff --git a/agent/pkg/api/socket_server_handlers.go b/agent/pkg/api/socket_server_handlers.go index d3b604e32..a3d5808db 100644 --- a/agent/pkg/api/socket_server_handlers.go +++ b/agent/pkg/api/socket_server_handlers.go @@ -1,12 +1,13 @@ package api import ( + "context" "encoding/json" - "fmt" "sync" + "github.com/gin-gonic/gin" + "github.com/up9inc/mizu/agent/pkg/dependency" "github.com/up9inc/mizu/agent/pkg/models" - "github.com/up9inc/mizu/agent/pkg/providers" "github.com/up9inc/mizu/agent/pkg/providers/tappedPods" "github.com/up9inc/mizu/agent/pkg/providers/tappers" "github.com/up9inc/mizu/agent/pkg/up9" @@ -17,7 +18,11 @@ import ( "github.com/up9inc/mizu/shared/logger" ) -var browserClientSocketUUIDs = make([]int, 0) +type BrowserClient struct { + dataStreamCancelFunc context.CancelFunc +} + +var browserClients = make(map[int]*BrowserClient, 0) var tapperClientSocketUUIDs = make([]int, 0) var socketListLock = sync.Mutex{} @@ -30,7 +35,7 @@ func init() { go up9.UpdateAnalyzeStatus(BroadcastToBrowserClients) } -func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) { +func (h *RoutesEventHandlers) WebSocketConnect(_ *gin.Context, socketId int, isTapper bool) { if isTapper { logger.Log.Infof("Websocket event - Tapper connected, socket ID: %d", socketId) tappers.Connected() @@ -45,7 +50,7 @@ func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) { logger.Log.Infof("Websocket event - Browser socket connected, socket ID: %d", socketId) socketListLock.Lock() - browserClientSocketUUIDs = append(browserClientSocketUUIDs, socketId) + browserClients[socketId] = &BrowserClient{} socketListLock.Unlock() BroadcastTappedPodsStatus() @@ -63,13 +68,16 @@ func (h *RoutesEventHandlers) WebSocketDisconnect(socketId int, isTapper bool) { } else { logger.Log.Infof("Websocket event - Browser socket disconnected, socket ID: %d", socketId) socketListLock.Lock() - removeSocketUUIDFromBrowserSlice(socketId) + if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc != nil { + browserClients[socketId].dataStreamCancelFunc() + } + delete(browserClients, socketId) socketListLock.Unlock() } } func BroadcastToBrowserClients(message []byte) { - for _, socketId := range browserClientSocketUUIDs { + for socketId := range browserClients { go func(socketId int) { if err := SendToSocket(socketId, message); err != nil { logger.Log.Error(err) @@ -88,7 +96,33 @@ func BroadcastToTapperClients(message []byte) { } } -func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { +func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, message []byte) { + if isTapper { + HandleTapperIncomingMessage(message, h.SocketOutChannel, BroadcastToBrowserClients) + } else { + // we initiate the basenine stream after the first websocket message we receive (it contains the entry query), we then store a cancelfunc to later cancel this stream + if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc == nil { + var params WebSocketParams + if err := json.Unmarshal(message, ¶ms); err != nil { + logger.Log.Errorf("Error: %v", socketId, err) + return + } + + entriesStreamer := dependency.GetInstance(dependency.EntriesSocketStreamer).(EntryStreamer) + ctx, cancelFunc := context.WithCancel(context.Background()) + err := entriesStreamer.Get(ctx, socketId, ¶ms) + + if err != nil { + logger.Log.Errorf("error initializing basenine stream for browser socket %d %+v", socketId, err) + cancelFunc() + } else { + browserClients[socketId].dataStreamCancelFunc = cancelFunc + } + } + } +} + +func HandleTapperIncomingMessage(message []byte, socketOutChannel chan<- *tapApi.OutputChannelItem, broadcastMessageFunc func([]byte)) { var socketMessageBase shared.WebSocketMessageMetadata err := json.Unmarshal(message, &socketMessageBase) if err != nil { @@ -102,7 +136,7 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err) } else { // NOTE: This is where the message comes back from the intermediate WebSocket to code. - h.SocketOutChannel <- tappedEntryMessage.Data + socketOutChannel <- tappedEntryMessage.Data } case shared.WebSocketMessageTypeUpdateStatus: var statusMessage shared.WebSocketStatusMessage @@ -110,15 +144,7 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { if err != nil { logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err) } else { - BroadcastToBrowserClients(message) - } - case shared.WebsocketMessageTypeOutboundLink: - var outboundLinkMessage models.WebsocketOutboundLinkMessage - err := json.Unmarshal(message, &outboundLinkMessage) - if err != nil { - logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err) - } else { - handleTLSLink(outboundLinkMessage) + broadcastMessageFunc(message) } default: logger.Log.Infof("Received socket message of type %s for which no handlers are defined", socketMessageBase.MessageType) @@ -126,39 +152,6 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { } } -func handleTLSLink(outboundLinkMessage models.WebsocketOutboundLinkMessage) { - resolvedNameObject := k8sResolver.Resolve(outboundLinkMessage.Data.DstIP) - if resolvedNameObject != nil { - outboundLinkMessage.Data.DstIP = resolvedNameObject.FullAddress - } else if outboundLinkMessage.Data.SuggestedResolvedName != "" { - outboundLinkMessage.Data.DstIP = outboundLinkMessage.Data.SuggestedResolvedName - } - cacheKey := fmt.Sprintf("%s -> %s:%d", outboundLinkMessage.Data.Src, outboundLinkMessage.Data.DstIP, outboundLinkMessage.Data.DstPort) - _, isInCache := providers.RecentTLSLinks.Get(cacheKey) - if isInCache { - return - } else { - providers.RecentTLSLinks.SetDefault(cacheKey, outboundLinkMessage.Data) - } - marshaledMessage, err := json.Marshal(outboundLinkMessage) - if err != nil { - logger.Log.Errorf("Error marshaling outbound link message for broadcasting: %v", err) - } else { - logger.Log.Errorf("Broadcasting outboundlink message %s", string(marshaledMessage)) - BroadcastToBrowserClients(marshaledMessage) - } -} - -func removeSocketUUIDFromBrowserSlice(uuidToRemove int) { - newUUIDSlice := make([]int, 0, len(browserClientSocketUUIDs)) - for _, uuid := range browserClientSocketUUIDs { - if uuid != uuidToRemove { - newUUIDSlice = append(newUUIDSlice, uuid) - } - } - browserClientSocketUUIDs = newUUIDSlice -} - func removeSocketUUIDFromTapperSlice(uuidToRemove int) { newUUIDSlice := make([]int, 0, len(tapperClientSocketUUIDs)) for _, uuid := range tapperClientSocketUUIDs { diff --git a/agent/pkg/dependency/type_names.go b/agent/pkg/dependency/type_names.go index d886ce24e..c7e91decb 100644 --- a/agent/pkg/dependency/type_names.go +++ b/agent/pkg/dependency/type_names.go @@ -6,4 +6,6 @@ const ( ServiceMapGeneratorDependency = "ServiceMapGeneratorDependency" OasGeneratorDependency = "OasGeneratorDependency" EntriesProvider = "EntriesProvider" + EntriesSocketStreamer = "EntriesSocketStreamer" + EntryStreamerSocketConnector = "EntryStreamerSocketConnector" )