From cbd07a72e5e31ee1976a2af96274051a6dd7be78 Mon Sep 17 00:00:00 2001 From: Rami Berman Date: Thu, 31 Mar 2022 15:55:44 +0300 Subject: [PATCH] Separate socket and basenine logic --- agent/pkg/api/socket_routes.go | 163 ++++----------------- agent/pkg/api/socket_server_handlers.go | 181 +++++++++++++++++------- 2 files changed, 162 insertions(+), 182 deletions(-) diff --git a/agent/pkg/api/socket_routes.go b/agent/pkg/api/socket_routes.go index 964ac231e..eddf160cd 100644 --- a/agent/pkg/api/socket_routes.go +++ b/agent/pkg/api/socket_routes.go @@ -1,19 +1,15 @@ package api import ( - "encoding/json" "fmt" "net/http" "sync" "time" - "github.com/up9inc/mizu/agent/pkg/models" - "github.com/up9inc/mizu/agent/pkg/utils" - "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - basenine "github.com/up9inc/basenine/client/go" - "github.com/up9inc/mizu/shared" + "github.com/up9inc/mizu/agent/pkg/models" + "github.com/up9inc/mizu/agent/pkg/utils" "github.com/up9inc/mizu/shared/logger" tapApi "github.com/up9inc/mizu/tap/api" ) @@ -25,9 +21,9 @@ func InitExtensionsMap(ref map[string]*tapApi.Extension) { } type EventHandlers interface { - WebSocketConnect(socketId int, isTapper bool) + WebSocketConnect(c *gin.Context, socketId int, isTapper bool) WebSocketDisconnect(socketId int, isTapper bool) - WebSocketMessage(socketId int, message []byte) + WebSocketMessage(socketId int, isTapper bool, message []byte) } type SocketConnection struct { @@ -62,11 +58,11 @@ func init() { func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { SocketGetBrowserHandler = func(c *gin.Context) { - websocketHandler(c.Writer, c.Request, eventHandlers, false) + websocketHandler(c, eventHandlers, false) } SocketGetTapperHandler = func(c *gin.Context) { - websocketHandler(c.Writer, c.Request, eventHandlers, true) + websocketHandler(c, eventHandlers, true) } app.GET("/ws", func(c *gin.Context) { @@ -78,10 +74,10 @@ func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { }) } -func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) { - ws, err := websocketUpgrader.Upgrade(w, r, nil) +func websocketHandler(c *gin.Context, eventHandlers EventHandlers, isTapper bool) { + ws, err := websocketUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - logger.Log.Errorf("Failed to set websocket upgrade: %v", err) + logger.Log.Errorf("failed to set websocket upgrade: %v", err) return } @@ -93,30 +89,11 @@ func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers Even websocketIdsLock.Unlock() - var connection *basenine.Connection - var isQuerySet bool - - // `!isTapper` means it's a connection from the web UI - if !isTapper { - connection, err = basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) - if err != nil { - logger.Log.Errorf("Failed to establish a connection to Basenine: %v", err) - socketCleanup(socketId, connectedWebsockets[socketId]) - return - } - } - - data := make(chan []byte) - meta := make(chan []byte) - defer func() { socketCleanup(socketId, connectedWebsockets[socketId]) - data <- []byte(basenine.CloseChannel) - meta <- []byte(basenine.CloseChannel) - connection.Close() }() - eventHandlers.WebSocketConnect(socketId, isTapper) + eventHandlers.WebSocketConnect(c, socketId, isTapper) startTimeBytes, _ := models.CreateWebsocketStartTimeMessage(utils.StartTime) @@ -124,127 +101,32 @@ func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers Even logger.Log.Error(err) } - var params WebSocketParams - for { _, msg, err := ws.ReadMessage() if err != nil { if _, ok := err.(*websocket.CloseError); ok { - logger.Log.Debugf("Received websocket close message, socket id: %d", socketId) + logger.Log.Debugf("received websocket close message, socket id: %d", socketId) } else { - logger.Log.Errorf("Error reading message, socket id: %d, error: %v", socketId, err) + logger.Log.Errorf("error reading message, socket id: %d, error: %v", socketId, err) } break } - if !isTapper && !isQuerySet { - if err := json.Unmarshal(msg, ¶ms); err != nil { - logger.Log.Errorf("Error unmarshalling parameters: %v", socketId, err) - continue - } - - query := params.Query - err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) - if err != nil { - toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ - Type: "error", - AutoClose: 5000, - Text: fmt.Sprintf("Syntax error: %s", err.Error()), - }) - if err := SendToSocket(socketId, toastBytes); err != nil { - logger.Log.Error(err) - } - break - } - - isQuerySet = true - - handleDataChannel := func(c *basenine.Connection, data chan []byte) { - for { - bytes := <-data - - if string(bytes) == basenine.CloseChannel { - return - } - - var entry *tapApi.Entry - err = json.Unmarshal(bytes, &entry) - if err != nil { - logger.Log.Debugf("Error unmarshalling entry: %v", err.Error()) - continue - } - - var message []byte - if params.EnableFullEntries { - message, _ = models.CreateFullEntryWebSocketMessage(entry) - } else { - extension := extensionsMap[entry.Protocol.Name] - base := extension.Dissector.Summarize(entry) - message, _ = models.CreateBaseEntryWebSocketMessage(base) - } - - if err := SendToSocket(socketId, message); err != nil { - logger.Log.Error(err) - } - } - } - - handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { - for { - bytes := <-meta - - if string(bytes) == basenine.CloseChannel { - return - } - - var metadata *basenine.Metadata - err = json.Unmarshal(bytes, &metadata) - if err != nil { - logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) - continue - } - - metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) - if err := SendToSocket(socketId, metadataBytes); err != nil { - logger.Log.Error(err) - } - } - } - - go handleDataChannel(connection, data) - go handleMetaChannel(connection, meta) - - connection.Query(query, data, meta) - } else { - eventHandlers.WebSocketMessage(socketId, msg) - } + eventHandlers.WebSocketMessage(socketId, isTapper, msg) } } -func socketCleanup(socketId int, socketConnection *SocketConnection) { - err := socketConnection.connection.Close() - if err != nil { - logger.Log.Errorf("Error closing socket connection for socket id %d: %v", socketId, err) - } - - websocketIdsLock.Lock() - connectedWebsockets[socketId] = nil - websocketIdsLock.Unlock() - - socketConnection.eventHandlers.WebSocketDisconnect(socketId, socketConnection.isTapper) -} - func SendToSocket(socketId int, message []byte) error { socketObj := connectedWebsockets[socketId] if socketObj == nil { - return fmt.Errorf("Socket %v is disconnected", socketId) + return fmt.Errorf("socket %v is disconnected", socketId) } var sent = false time.AfterFunc(time.Second*5, func() { if !sent { - logger.Log.Error("Socket timed out") + logger.Log.Error("socket timed out") socketCleanup(socketId, socketObj) } }) @@ -255,7 +137,20 @@ func SendToSocket(socketId int, message []byte) error { sent = true if err != nil { - return fmt.Errorf("Failed to write message to socket %v, err: %w", socketId, err) + return fmt.Errorf("failed to write message to socket %v, err: %w", socketId, err) } return nil } + +func socketCleanup(socketId int, socketConnection *SocketConnection) { + err := socketConnection.connection.Close() + if err != nil { + logger.Log.Errorf("error closing socket connection for socket id %d: %v", socketId, err) + } + + websocketIdsLock.Lock() + connectedWebsockets[socketId] = nil + websocketIdsLock.Unlock() + + socketConnection.eventHandlers.WebSocketDisconnect(socketId, socketConnection.isTapper) +} diff --git a/agent/pkg/api/socket_server_handlers.go b/agent/pkg/api/socket_server_handlers.go index d3b604e32..7f655310e 100644 --- a/agent/pkg/api/socket_server_handlers.go +++ b/agent/pkg/api/socket_server_handlers.go @@ -1,12 +1,14 @@ package api import ( + "context" "encoding/json" "fmt" "sync" + "github.com/gin-gonic/gin" + basenine "github.com/up9inc/basenine/client/go" "github.com/up9inc/mizu/agent/pkg/models" - "github.com/up9inc/mizu/agent/pkg/providers" "github.com/up9inc/mizu/agent/pkg/providers/tappedPods" "github.com/up9inc/mizu/agent/pkg/providers/tappers" "github.com/up9inc/mizu/agent/pkg/up9" @@ -17,7 +19,11 @@ import ( "github.com/up9inc/mizu/shared/logger" ) -var browserClientSocketUUIDs = make([]int, 0) +type BrowserClient struct { + dataStreamCancelFunc context.CancelFunc +} + +var browserClients = make(map[int]*BrowserClient, 0) var tapperClientSocketUUIDs = make([]int, 0) var socketListLock = sync.Mutex{} @@ -30,7 +36,7 @@ func init() { go up9.UpdateAnalyzeStatus(BroadcastToBrowserClients) } -func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) { +func (h *RoutesEventHandlers) WebSocketConnect(_ *gin.Context, socketId int, isTapper bool) { if isTapper { logger.Log.Infof("Websocket event - Tapper connected, socket ID: %d", socketId) tappers.Connected() @@ -45,7 +51,7 @@ func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) { logger.Log.Infof("Websocket event - Browser socket connected, socket ID: %d", socketId) socketListLock.Lock() - browserClientSocketUUIDs = append(browserClientSocketUUIDs, socketId) + browserClients[socketId] = &BrowserClient{} socketListLock.Unlock() BroadcastTappedPodsStatus() @@ -63,13 +69,16 @@ func (h *RoutesEventHandlers) WebSocketDisconnect(socketId int, isTapper bool) { } else { logger.Log.Infof("Websocket event - Browser socket disconnected, socket ID: %d", socketId) socketListLock.Lock() - removeSocketUUIDFromBrowserSlice(socketId) + if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc != nil { + browserClients[socketId].dataStreamCancelFunc() + } + delete(browserClients, socketId) socketListLock.Unlock() } } func BroadcastToBrowserClients(message []byte) { - for _, socketId := range browserClientSocketUUIDs { + for socketId, _ := range browserClients { go func(socketId int) { if err := SendToSocket(socketId, message); err != nil { logger.Log.Error(err) @@ -88,7 +97,124 @@ func BroadcastToTapperClients(message []byte) { } } -func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { +func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, message []byte) { + if isTapper { + h.handleTapperIncomingMessage(message) + } else { + // we initiate the basenine stream after the first websocket message we receive (it contains the entry query), we then store a cancelfunc to later cancel this stream + if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc == nil { + cancelFunc, err := startStreamingBasenineEntriesToSocket(socketId, message) + if err != nil { + logger.Log.Errorf("error initializing basenine stream for browser socket %d %+v", socketId, err) + } else { + browserClients[socketId].dataStreamCancelFunc = cancelFunc + } + } + } +} + +func startStreamingBasenineEntriesToSocket(socketId int, message []byte) (context.CancelFunc, error) { + var params WebSocketParams + if err := json.Unmarshal(message, ¶ms); err != nil { + logger.Log.Errorf("error unmarshalling parameters: %v", socketId, err) + return nil, err + } + + var connection *basenine.Connection + + connection, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) + if err != nil { + logger.Log.Errorf("failed to establish a connection to Basenine: %v", err) + socketCleanup(socketId, connectedWebsockets[socketId]) + return nil, err + } + + data := make(chan []byte) + meta := make(chan []byte) + + ctx, cancel := context.WithCancel(context.Background()) + + query := params.Query + err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) + if err != nil { + toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ + Type: "error", + AutoClose: 5000, + Text: fmt.Sprintf("syntax error: %s", err.Error()), + }) + if err := SendToSocket(socketId, toastBytes); err != nil { + logger.Log.Error(err) + } + } + + handleDataChannel := func(c *basenine.Connection, data chan []byte) { + for { + bytes := <-data + + if string(bytes) == basenine.CloseChannel { + return + } + + var entry *tapApi.Entry + err = json.Unmarshal(bytes, &entry) + if err != nil { + logger.Log.Debugf("error unmarshalling entry: %v", err.Error()) + continue + } + + var message []byte + if params.EnableFullEntries { + message, _ = models.CreateFullEntryWebSocketMessage(entry) + } else { + extension := extensionsMap[entry.Protocol.Name] + base := extension.Dissector.Summarize(entry) + message, _ = models.CreateBaseEntryWebSocketMessage(base) + } + + if err := SendToSocket(socketId, message); err != nil { + logger.Log.Error(err) + } + } + } + + handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { + for { + bytes := <-meta + + if string(bytes) == basenine.CloseChannel { + return + } + + var metadata *basenine.Metadata + err = json.Unmarshal(bytes, &metadata) + if err != nil { + logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) + continue + } + + metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) + if err := SendToSocket(socketId, metadataBytes); err != nil { + logger.Log.Error(err) + } + } + } + + go handleDataChannel(connection, data) + go handleMetaChannel(connection, meta) + + connection.Query(query, data, meta) + + go func() { + <-ctx.Done() + data <- []byte(basenine.CloseChannel) + meta <- []byte(basenine.CloseChannel) + connection.Close() + }() + + return cancel, nil +} + +func (h *RoutesEventHandlers) handleTapperIncomingMessage(message []byte) { var socketMessageBase shared.WebSocketMessageMetadata err := json.Unmarshal(message, &socketMessageBase) if err != nil { @@ -112,53 +238,12 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) { } else { BroadcastToBrowserClients(message) } - case shared.WebsocketMessageTypeOutboundLink: - var outboundLinkMessage models.WebsocketOutboundLinkMessage - err := json.Unmarshal(message, &outboundLinkMessage) - if err != nil { - logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err) - } else { - handleTLSLink(outboundLinkMessage) - } default: logger.Log.Infof("Received socket message of type %s for which no handlers are defined", socketMessageBase.MessageType) } } } -func handleTLSLink(outboundLinkMessage models.WebsocketOutboundLinkMessage) { - resolvedNameObject := k8sResolver.Resolve(outboundLinkMessage.Data.DstIP) - if resolvedNameObject != nil { - outboundLinkMessage.Data.DstIP = resolvedNameObject.FullAddress - } else if outboundLinkMessage.Data.SuggestedResolvedName != "" { - outboundLinkMessage.Data.DstIP = outboundLinkMessage.Data.SuggestedResolvedName - } - cacheKey := fmt.Sprintf("%s -> %s:%d", outboundLinkMessage.Data.Src, outboundLinkMessage.Data.DstIP, outboundLinkMessage.Data.DstPort) - _, isInCache := providers.RecentTLSLinks.Get(cacheKey) - if isInCache { - return - } else { - providers.RecentTLSLinks.SetDefault(cacheKey, outboundLinkMessage.Data) - } - marshaledMessage, err := json.Marshal(outboundLinkMessage) - if err != nil { - logger.Log.Errorf("Error marshaling outbound link message for broadcasting: %v", err) - } else { - logger.Log.Errorf("Broadcasting outboundlink message %s", string(marshaledMessage)) - BroadcastToBrowserClients(marshaledMessage) - } -} - -func removeSocketUUIDFromBrowserSlice(uuidToRemove int) { - newUUIDSlice := make([]int, 0, len(browserClientSocketUUIDs)) - for _, uuid := range browserClientSocketUUIDs { - if uuid != uuidToRemove { - newUUIDSlice = append(newUUIDSlice, uuid) - } - } - browserClientSocketUUIDs = newUUIDSlice -} - func removeSocketUUIDFromTapperSlice(uuidToRemove int) { newUUIDSlice := make([]int, 0, len(tapperClientSocketUUIDs)) for _, uuid := range tapperClientSocketUUIDs {