Refactor ws (#961)

* Separate socket and basenine logic

* WIP

* Update socket_server_handlers.go

* Update socket_data_streamer.go and socket_server_handlers.go

* Update socket_server_handlers.go

* Merge branch 'develop' into refactor_ws
# Please enter a commit message to explain why this merge is necessary,
# especially if it merges an updated upstream into a topic branch.
#
# Lines starting with '#' will be ignored, and an empty message aborts
# the commit.

* empty commit for actions

* empty commit for actions

* commit for actions

* Revert "commit for actions"

This reverts commit 8ba2ecf7d3.

Co-authored-by: RoyUP9 <87927115+RoyUP9@users.noreply.github.com>
This commit is contained in:
RamiBerm
2022-04-04 17:33:53 +03:00
committed by GitHub
parent 2bfc523bbc
commit 76a6a77a14
6 changed files with 226 additions and 185 deletions

View File

@@ -1,12 +1,13 @@
package api
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/gin-gonic/gin"
"github.com/up9inc/mizu/agent/pkg/dependency"
"github.com/up9inc/mizu/agent/pkg/models"
"github.com/up9inc/mizu/agent/pkg/providers"
"github.com/up9inc/mizu/agent/pkg/providers/tappedPods"
"github.com/up9inc/mizu/agent/pkg/providers/tappers"
"github.com/up9inc/mizu/agent/pkg/up9"
@@ -17,7 +18,11 @@ import (
"github.com/up9inc/mizu/shared/logger"
)
var browserClientSocketUUIDs = make([]int, 0)
type BrowserClient struct {
dataStreamCancelFunc context.CancelFunc
}
var browserClients = make(map[int]*BrowserClient, 0)
var tapperClientSocketUUIDs = make([]int, 0)
var socketListLock = sync.Mutex{}
@@ -30,7 +35,7 @@ func init() {
go up9.UpdateAnalyzeStatus(BroadcastToBrowserClients)
}
func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) {
func (h *RoutesEventHandlers) WebSocketConnect(_ *gin.Context, socketId int, isTapper bool) {
if isTapper {
logger.Log.Infof("Websocket event - Tapper connected, socket ID: %d", socketId)
tappers.Connected()
@@ -45,7 +50,7 @@ func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) {
logger.Log.Infof("Websocket event - Browser socket connected, socket ID: %d", socketId)
socketListLock.Lock()
browserClientSocketUUIDs = append(browserClientSocketUUIDs, socketId)
browserClients[socketId] = &BrowserClient{}
socketListLock.Unlock()
BroadcastTappedPodsStatus()
@@ -63,13 +68,16 @@ func (h *RoutesEventHandlers) WebSocketDisconnect(socketId int, isTapper bool) {
} else {
logger.Log.Infof("Websocket event - Browser socket disconnected, socket ID: %d", socketId)
socketListLock.Lock()
removeSocketUUIDFromBrowserSlice(socketId)
if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc != nil {
browserClients[socketId].dataStreamCancelFunc()
}
delete(browserClients, socketId)
socketListLock.Unlock()
}
}
func BroadcastToBrowserClients(message []byte) {
for _, socketId := range browserClientSocketUUIDs {
for socketId := range browserClients {
go func(socketId int) {
if err := SendToSocket(socketId, message); err != nil {
logger.Log.Error(err)
@@ -88,7 +96,33 @@ func BroadcastToTapperClients(message []byte) {
}
}
func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) {
func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, message []byte) {
if isTapper {
HandleTapperIncomingMessage(message, h.SocketOutChannel, BroadcastToBrowserClients)
} else {
// we initiate the basenine stream after the first websocket message we receive (it contains the entry query), we then store a cancelfunc to later cancel this stream
if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc == nil {
var params WebSocketParams
if err := json.Unmarshal(message, &params); err != nil {
logger.Log.Errorf("Error: %v", socketId, err)
return
}
entriesStreamer := dependency.GetInstance(dependency.EntriesSocketStreamer).(EntryStreamer)
ctx, cancelFunc := context.WithCancel(context.Background())
err := entriesStreamer.Get(ctx, socketId, &params)
if err != nil {
logger.Log.Errorf("error initializing basenine stream for browser socket %d %+v", socketId, err)
cancelFunc()
} else {
browserClients[socketId].dataStreamCancelFunc = cancelFunc
}
}
}
}
func HandleTapperIncomingMessage(message []byte, socketOutChannel chan<- *tapApi.OutputChannelItem, broadcastMessageFunc func([]byte)) {
var socketMessageBase shared.WebSocketMessageMetadata
err := json.Unmarshal(message, &socketMessageBase)
if err != nil {
@@ -102,7 +136,7 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) {
logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err)
} else {
// NOTE: This is where the message comes back from the intermediate WebSocket to code.
h.SocketOutChannel <- tappedEntryMessage.Data
socketOutChannel <- tappedEntryMessage.Data
}
case shared.WebSocketMessageTypeUpdateStatus:
var statusMessage shared.WebSocketStatusMessage
@@ -110,15 +144,7 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) {
if err != nil {
logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err)
} else {
BroadcastToBrowserClients(message)
}
case shared.WebsocketMessageTypeOutboundLink:
var outboundLinkMessage models.WebsocketOutboundLinkMessage
err := json.Unmarshal(message, &outboundLinkMessage)
if err != nil {
logger.Log.Infof("Could not unmarshal message of message type %s %v", socketMessageBase.MessageType, err)
} else {
handleTLSLink(outboundLinkMessage)
broadcastMessageFunc(message)
}
default:
logger.Log.Infof("Received socket message of type %s for which no handlers are defined", socketMessageBase.MessageType)
@@ -126,39 +152,6 @@ func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) {
}
}
func handleTLSLink(outboundLinkMessage models.WebsocketOutboundLinkMessage) {
resolvedNameObject := k8sResolver.Resolve(outboundLinkMessage.Data.DstIP)
if resolvedNameObject != nil {
outboundLinkMessage.Data.DstIP = resolvedNameObject.FullAddress
} else if outboundLinkMessage.Data.SuggestedResolvedName != "" {
outboundLinkMessage.Data.DstIP = outboundLinkMessage.Data.SuggestedResolvedName
}
cacheKey := fmt.Sprintf("%s -> %s:%d", outboundLinkMessage.Data.Src, outboundLinkMessage.Data.DstIP, outboundLinkMessage.Data.DstPort)
_, isInCache := providers.RecentTLSLinks.Get(cacheKey)
if isInCache {
return
} else {
providers.RecentTLSLinks.SetDefault(cacheKey, outboundLinkMessage.Data)
}
marshaledMessage, err := json.Marshal(outboundLinkMessage)
if err != nil {
logger.Log.Errorf("Error marshaling outbound link message for broadcasting: %v", err)
} else {
logger.Log.Errorf("Broadcasting outboundlink message %s", string(marshaledMessage))
BroadcastToBrowserClients(marshaledMessage)
}
}
func removeSocketUUIDFromBrowserSlice(uuidToRemove int) {
newUUIDSlice := make([]int, 0, len(browserClientSocketUUIDs))
for _, uuid := range browserClientSocketUUIDs {
if uuid != uuidToRemove {
newUUIDSlice = append(newUUIDSlice, uuid)
}
}
browserClientSocketUUIDs = newUUIDSlice
}
func removeSocketUUIDFromTapperSlice(uuidToRemove int) {
newUUIDSlice := make([]int, 0, len(tapperClientSocketUUIDs))
for _, uuid := range tapperClientSocketUUIDs {