Refactor the WebSocket implementaiton for /ws

This commit is contained in:
M. Mert Yildiran 2021-09-16 19:31:08 +03:00
parent 4f74be47d3
commit 252563fa44
No known key found for this signature in database
GPG Key ID: D42ADB236521BF7A
3 changed files with 56 additions and 12 deletions

View File

@ -110,8 +110,6 @@ func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem, extension
panic("Channel of captured messages is nil") panic("Channel of captured messages is nil")
} }
go Query("", "localhost", "8000")
for item := range outputItems { for item := range outputItems {
extension := extensionsMap[item.Protocol.Name] extension := extensionsMap[item.Protocol.Name]
resolvedSource, resolvedDestionation := resolveIP(item.ConnectionInfo) resolvedSource, resolvedDestionation := resolveIP(item.ConnectionInfo)

View File

@ -10,6 +10,8 @@ import (
"regexp" "regexp"
"sync" "sync"
"time" "time"
"github.com/gorilla/websocket"
) )
func Connect(host string, port string) (conn net.Conn) { func Connect(host string, port string) (conn net.Conn) {
@ -45,11 +47,9 @@ func Insert(entry interface{}, conn net.Conn) {
conn.Write([]byte("\n")) conn.Write([]byte("\n"))
} }
func Query(query string, host string, port string) { func Query(query string, conn net.Conn, ws *websocket.Conn) {
conn := Connect(host, port)
var wg sync.WaitGroup var wg sync.WaitGroup
go readConnection(&wg, conn) go readConnection(&wg, conn, ws)
wg.Add(1) wg.Add(1)
conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) conn.SetWriteDeadline(time.Now().Add(1 * time.Second))
@ -61,7 +61,7 @@ func Query(query string, host string, port string) {
wg.Wait() wg.Wait()
} }
func readConnection(wg *sync.WaitGroup, conn net.Conn) { func readConnection(wg *sync.WaitGroup, conn net.Conn, ws *websocket.Conn) {
defer wg.Done() defer wg.Done()
for { for {
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(conn)
@ -74,13 +74,17 @@ func readConnection(wg *sync.WaitGroup, conn net.Conn) {
if !command { if !command {
fmt.Printf("\b\b** %s\n> ", text) fmt.Printf("\b\b** %s\n> ", text)
if text == "" {
return
}
var data map[string]interface{} var data map[string]interface{}
if err := json.Unmarshal([]byte(text), &data); err != nil { if err := json.Unmarshal([]byte(text), &data); err != nil {
panic(err) panic(err)
} }
baseEntryBytes, _ := models.CreateBaseEntryWebSocketMessage(data["Summary"].(map[string]interface{})) baseEntryBytes, _ := models.CreateBaseEntryWebSocketMessage(data["Summary"].(map[string]interface{}))
BroadcastToBrowserClients(baseEntryBytes) ws.WriteMessage(1, baseEntryBytes)
} }
if !ok { if !ok {

View File

@ -2,13 +2,14 @@ package api
import ( import (
"errors" "errors"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/romana/rlog" "github.com/romana/rlog"
"github.com/up9inc/mizu/shared/debounce" "github.com/up9inc/mizu/shared/debounce"
"net/http"
"sync"
"time"
) )
type EventHandlers interface { type EventHandlers interface {
@ -40,13 +41,54 @@ func init() {
func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) { func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) {
app.GET("/ws", func(c *gin.Context) { app.GET("/ws", func(c *gin.Context) {
websocketHandler(c.Writer, c.Request, eventHandlers, false) queryMap := c.Request.URL.Query()
query := ""
if val, ok := queryMap["q"]; ok {
query = val[0]
}
websocketHandlerUI(c.Writer, c.Request, eventHandlers, false, query)
}) })
app.GET("/wsTapper", func(c *gin.Context) { app.GET("/wsTapper", func(c *gin.Context) {
websocketHandler(c.Writer, c.Request, eventHandlers, true) websocketHandler(c.Writer, c.Request, eventHandlers, true)
}) })
} }
func websocketHandlerUI(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool, query string) {
ws, err := websocketUpgrader.Upgrade(w, r, nil)
if err != nil {
rlog.Errorf("Failed to set websocket upgrade: %v", err)
return
}
conn := Connect("localhost", "8000")
go Query(query, conn, ws)
websocketIdsLock.Lock()
connectedWebsocketIdCounter++
socketId := connectedWebsocketIdCounter
connectedWebsockets[socketId] = &SocketConnection{connection: ws, lock: &sync.Mutex{}, eventHandlers: eventHandlers, isTapper: isTapper}
websocketIdsLock.Unlock()
defer func() {
socketCleanup(socketId, connectedWebsockets[socketId])
}()
eventHandlers.WebSocketConnect(socketId, isTapper)
for {
_, msg, err := ws.ReadMessage()
if err != nil {
rlog.Errorf("Error reading message, socket id: %d, error: %v", socketId, err)
break
}
eventHandlers.WebSocketMessage(socketId, msg)
}
conn.Close()
}
func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) { func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) {
conn, err := websocketUpgrader.Upgrade(w, r, nil) conn, err := websocketUpgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {