diff --git a/agent/main.go b/agent/main.go index c453fc673..1040bf888 100644 --- a/agent/main.go +++ b/agent/main.go @@ -373,4 +373,6 @@ func initializeDependencies() { dependency.RegisterGenerator(dependency.ServiceMapGeneratorDependency, func() interface{} { return servicemap.GetDefaultServiceMapInstance() }) dependency.RegisterGenerator(dependency.OasGeneratorDependency, func() interface{} { return oas.GetDefaultOasGeneratorInstance(nil) }) dependency.RegisterGenerator(dependency.EntriesProvider, func() interface{} { return &entries.BasenineEntriesProvider{} }) + dependency.RegisterGenerator(dependency.EntriesSocketStreamer, func() interface{} { return &api.BasenineEntryStreamer{} }) + dependency.RegisterGenerator(dependency.EntryStreamerSocketConnector, func() interface{} { return &api.DefaultEntryStreamerSocketConnector{} }) } diff --git a/agent/pkg/api/socket_data_streamer.go b/agent/pkg/api/socket_data_streamer.go new file mode 100644 index 000000000..48a9a8fea --- /dev/null +++ b/agent/pkg/api/socket_data_streamer.go @@ -0,0 +1,143 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + + basenine "github.com/up9inc/basenine/client/go" + "github.com/up9inc/mizu/agent/pkg/dependency" + "github.com/up9inc/mizu/agent/pkg/models" + "github.com/up9inc/mizu/shared" + "github.com/up9inc/mizu/shared/logger" + tapApi "github.com/up9inc/mizu/tap/api" +) + +type EntryStreamer interface { + Get(socketId int, params *WebSocketParams) (context.CancelFunc, error) +} + +type EntryStreamerSocketConnector interface { + SendEntry(socketId int, entry *tapApi.Entry, params *WebSocketParams) + SendMetadata(socketId int, metadata *basenine.Metadata) + SendToastError(socketId int, err error) + CleanupSocket(socketId int) +} + +type DefaultEntryStreamerSocketConnector struct{} + +func (e *DefaultEntryStreamerSocketConnector) SendEntry(socketId int, entry *tapApi.Entry, params *WebSocketParams) { + var message []byte + if params.EnableFullEntries { + message, _ = models.CreateFullEntryWebSocketMessage(entry) + } else { + extension := extensionsMap[entry.Protocol.Name] + base := extension.Dissector.Summarize(entry) + message, _ = models.CreateBaseEntryWebSocketMessage(base) + } + + if err := SendToSocket(socketId, message); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) SendMetadata(socketId int, metadata *basenine.Metadata) { + metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) + if err := SendToSocket(socketId, metadataBytes); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) SendToastError(socketId int, err error) { + toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ + Type: "error", + AutoClose: 5000, + Text: fmt.Sprintf("Syntax error: %s", err.Error()), + }) + if err := SendToSocket(socketId, toastBytes); err != nil { + logger.Log.Error(err) + } +} + +func (e *DefaultEntryStreamerSocketConnector) CleanupSocket(socketId int) { + socketObj := connectedWebsockets[socketId] + socketCleanup(socketId, socketObj) +} + +type BasenineEntryStreamer struct{} + +func (e *BasenineEntryStreamer) Get(socketId int, params *WebSocketParams) (context.CancelFunc, error) { + var connection *basenine.Connection + + entryStreamerSocketConnector := dependency.GetInstance(dependency.EntryStreamerSocketConnector).(EntryStreamerSocketConnector) + + connection, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) + if err != nil { + logger.Log.Errorf("failed to establish a connection to Basenine: %v", err) + entryStreamerSocketConnector.CleanupSocket(socketId) + return nil, err + } + + data := make(chan []byte) + meta := make(chan []byte) + + ctx, cancel := context.WithCancel(context.Background()) + + query := params.Query + err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) + if err != nil { + entryStreamerSocketConnector.SendToastError(socketId, err) + } + + handleDataChannel := func(c *basenine.Connection, data chan []byte) { + for { + bytes := <-data + + if string(bytes) == basenine.CloseChannel { + return + } + + var entry *tapApi.Entry + err = json.Unmarshal(bytes, &entry) + if err != nil { + logger.Log.Debugf("error unmarshalling entry: %v", err.Error()) + continue + } + + entryStreamerSocketConnector.SendEntry(socketId, entry, params) + } + } + + handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { + for { + bytes := <-meta + + if string(bytes) == basenine.CloseChannel { + return + } + + var metadata *basenine.Metadata + err = json.Unmarshal(bytes, &metadata) + if err != nil { + logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) + continue + } + + entryStreamerSocketConnector.SendMetadata(socketId, metadata) + } + } + + go handleDataChannel(connection, data) + go handleMetaChannel(connection, meta) + + connection.Query(query, data, meta) + + go func() { + <-ctx.Done() + data <- []byte(basenine.CloseChannel) + meta <- []byte(basenine.CloseChannel) + connection.Close() + }() + + return cancel, nil +} diff --git a/agent/pkg/api/socket_server_handlers.go b/agent/pkg/api/socket_server_handlers.go index 7f655310e..96805c7bc 100644 --- a/agent/pkg/api/socket_server_handlers.go +++ b/agent/pkg/api/socket_server_handlers.go @@ -3,11 +3,10 @@ package api import ( "context" "encoding/json" - "fmt" "sync" "github.com/gin-gonic/gin" - basenine "github.com/up9inc/basenine/client/go" + "github.com/up9inc/mizu/agent/pkg/dependency" "github.com/up9inc/mizu/agent/pkg/models" "github.com/up9inc/mizu/agent/pkg/providers/tappedPods" "github.com/up9inc/mizu/agent/pkg/providers/tappers" @@ -103,7 +102,16 @@ func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, mess } else { // we initiate the basenine stream after the first websocket message we receive (it contains the entry query), we then store a cancelfunc to later cancel this stream if browserClients[socketId] != nil && browserClients[socketId].dataStreamCancelFunc == nil { - cancelFunc, err := startStreamingBasenineEntriesToSocket(socketId, message) + + var params WebSocketParams + if err := json.Unmarshal(message, ¶ms); err != nil { + logger.Log.Errorf("Error: %v", socketId, err) + return + } + + entriesStreamer := dependency.GetInstance(dependency.EntriesSocketStreamer).(EntryStreamer) + cancelFunc, err := entriesStreamer.Get(socketId, ¶ms) + if err != nil { logger.Log.Errorf("error initializing basenine stream for browser socket %d %+v", socketId, err) } else { @@ -113,107 +121,6 @@ func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, mess } } -func startStreamingBasenineEntriesToSocket(socketId int, message []byte) (context.CancelFunc, error) { - var params WebSocketParams - if err := json.Unmarshal(message, ¶ms); err != nil { - logger.Log.Errorf("error unmarshalling parameters: %v", socketId, err) - return nil, err - } - - var connection *basenine.Connection - - connection, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) - if err != nil { - logger.Log.Errorf("failed to establish a connection to Basenine: %v", err) - socketCleanup(socketId, connectedWebsockets[socketId]) - return nil, err - } - - data := make(chan []byte) - meta := make(chan []byte) - - ctx, cancel := context.WithCancel(context.Background()) - - query := params.Query - err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) - if err != nil { - toastBytes, _ := models.CreateWebsocketToastMessage(&models.ToastMessage{ - Type: "error", - AutoClose: 5000, - Text: fmt.Sprintf("syntax error: %s", err.Error()), - }) - if err := SendToSocket(socketId, toastBytes); err != nil { - logger.Log.Error(err) - } - } - - handleDataChannel := func(c *basenine.Connection, data chan []byte) { - for { - bytes := <-data - - if string(bytes) == basenine.CloseChannel { - return - } - - var entry *tapApi.Entry - err = json.Unmarshal(bytes, &entry) - if err != nil { - logger.Log.Debugf("error unmarshalling entry: %v", err.Error()) - continue - } - - var message []byte - if params.EnableFullEntries { - message, _ = models.CreateFullEntryWebSocketMessage(entry) - } else { - extension := extensionsMap[entry.Protocol.Name] - base := extension.Dissector.Summarize(entry) - message, _ = models.CreateBaseEntryWebSocketMessage(base) - } - - if err := SendToSocket(socketId, message); err != nil { - logger.Log.Error(err) - } - } - } - - handleMetaChannel := func(c *basenine.Connection, meta chan []byte) { - for { - bytes := <-meta - - if string(bytes) == basenine.CloseChannel { - return - } - - var metadata *basenine.Metadata - err = json.Unmarshal(bytes, &metadata) - if err != nil { - logger.Log.Debugf("Error unmarshalling metadata: %v", err.Error()) - continue - } - - metadataBytes, _ := models.CreateWebsocketQueryMetadataMessage(metadata) - if err := SendToSocket(socketId, metadataBytes); err != nil { - logger.Log.Error(err) - } - } - } - - go handleDataChannel(connection, data) - go handleMetaChannel(connection, meta) - - connection.Query(query, data, meta) - - go func() { - <-ctx.Done() - data <- []byte(basenine.CloseChannel) - meta <- []byte(basenine.CloseChannel) - connection.Close() - }() - - return cancel, nil -} - func (h *RoutesEventHandlers) handleTapperIncomingMessage(message []byte) { var socketMessageBase shared.WebSocketMessageMetadata err := json.Unmarshal(message, &socketMessageBase) diff --git a/agent/pkg/dependency/type_names.go b/agent/pkg/dependency/type_names.go index d886ce24e..c7e91decb 100644 --- a/agent/pkg/dependency/type_names.go +++ b/agent/pkg/dependency/type_names.go @@ -6,4 +6,6 @@ const ( ServiceMapGeneratorDependency = "ServiceMapGeneratorDependency" OasGeneratorDependency = "OasGeneratorDependency" EntriesProvider = "EntriesProvider" + EntriesSocketStreamer = "EntriesSocketStreamer" + EntryStreamerSocketConnector = "EntryStreamerSocketConnector" )