diff --git a/agent/pkg/api/main.go b/agent/pkg/api/main.go index ec92e8a00..2042909c0 100644 --- a/agent/pkg/api/main.go +++ b/agent/pkg/api/main.go @@ -110,8 +110,6 @@ func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem, extension panic("Channel of captured messages is nil") } - go Query("", "localhost", "8000") - for item := range outputItems { extension := extensionsMap[item.Protocol.Name] resolvedSource, resolvedDestionation := resolveIP(item.ConnectionInfo) diff --git a/agent/pkg/api/realtime_client.go b/agent/pkg/api/realtime_client.go index 0d50b643f..6101f9b4a 100644 --- a/agent/pkg/api/realtime_client.go +++ b/agent/pkg/api/realtime_client.go @@ -10,6 +10,8 @@ import ( "regexp" "sync" "time" + + "github.com/gorilla/websocket" ) func Connect(host string, port string) (conn net.Conn) { @@ -45,11 +47,9 @@ func Insert(entry interface{}, conn net.Conn) { conn.Write([]byte("\n")) } -func Query(query string, host string, port string) { - conn := Connect(host, port) - +func Query(query string, conn net.Conn, ws *websocket.Conn) { var wg sync.WaitGroup - go readConnection(&wg, conn) + go readConnection(&wg, conn, ws) wg.Add(1) conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) @@ -61,7 +61,7 @@ func Query(query string, host string, port string) { wg.Wait() } -func readConnection(wg *sync.WaitGroup, conn net.Conn) { +func readConnection(wg *sync.WaitGroup, conn net.Conn, ws *websocket.Conn) { defer wg.Done() for { scanner := bufio.NewScanner(conn) @@ -74,13 +74,17 @@ func readConnection(wg *sync.WaitGroup, conn net.Conn) { if !command { fmt.Printf("\b\b** %s\n> ", text) + if text == "" { + return + } + var data map[string]interface{} if err := json.Unmarshal([]byte(text), &data); err != nil { panic(err) } baseEntryBytes, _ := models.CreateBaseEntryWebSocketMessage(data["Summary"].(map[string]interface{})) - BroadcastToBrowserClients(baseEntryBytes) + ws.WriteMessage(1, baseEntryBytes) } if !ok { diff --git a/agent/pkg/api/socket_routes.go b/agent/pkg/api/socket_routes.go index c44c1e047..385d96252 100644 --- a/agent/pkg/api/socket_routes.go +++ b/agent/pkg/api/socket_routes.go @@ -2,13 +2,14 @@ package api import ( "errors" + "net/http" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/romana/rlog" "github.com/up9inc/mizu/shared/debounce" - "net/http" - "sync" - "time" ) type EventHandlers interface { @@ -40,13 +41,54 @@ func init() { func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { app.GET("/ws", func(c *gin.Context) { - websocketHandler(c.Writer, c.Request, eventHandlers, false) + queryMap := c.Request.URL.Query() + query := "" + if val, ok := queryMap["q"]; ok { + query = val[0] + } + websocketHandlerUI(c.Writer, c.Request, eventHandlers, false, query) }) app.GET("/wsTapper", func(c *gin.Context) { websocketHandler(c.Writer, c.Request, eventHandlers, true) }) } +func websocketHandlerUI(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool, query string) { + ws, err := websocketUpgrader.Upgrade(w, r, nil) + if err != nil { + rlog.Errorf("Failed to set websocket upgrade: %v", err) + return + } + + conn := Connect("localhost", "8000") + go Query(query, conn, ws) + + websocketIdsLock.Lock() + + connectedWebsocketIdCounter++ + socketId := connectedWebsocketIdCounter + connectedWebsockets[socketId] = &SocketConnection{connection: ws, lock: &sync.Mutex{}, eventHandlers: eventHandlers, isTapper: isTapper} + + websocketIdsLock.Unlock() + + defer func() { + socketCleanup(socketId, connectedWebsockets[socketId]) + }() + + eventHandlers.WebSocketConnect(socketId, isTapper) + + for { + _, msg, err := ws.ReadMessage() + if err != nil { + rlog.Errorf("Error reading message, socket id: %d, error: %v", socketId, err) + break + } + eventHandlers.WebSocketMessage(socketId, msg) + } + + conn.Close() +} + func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) { conn, err := websocketUpgrader.Upgrade(w, r, nil) if err != nil {