TRA-3437 switch fiber and ikisocket with gin-gonic and gorilla websocket (#136)

* WIP

* WIP

* WIP

* Update socket_server_handlers.go and socket_routes.go

* Fix stuck sockets

* Update go.mod, go.sum, and 5 more files...

* Update socket_routes.go

* Update Dockerfile, go.sum, and fiber_middleware.go

* fix analyze

Co-authored-by: RamiBerm <rami.berman@up9.com>
This commit is contained in:
RamiBerm
2021-07-25 13:08:29 +03:00
committed by GitHub
parent f64ee23c74
commit 6dd2bf705b
16 changed files with 284 additions and 314 deletions

View File

@@ -2,7 +2,7 @@ package api
import (
"encoding/json"
"github.com/antoniodipinto/ikisocket"
"fmt"
"github.com/romana/rlog"
"github.com/up9inc/mizu/shared"
"github.com/up9inc/mizu/tap"
@@ -10,9 +10,11 @@ import (
"mizuserver/pkg/models"
"mizuserver/pkg/routes"
"mizuserver/pkg/up9"
"sync"
)
var browserClientSocketUUIDs = make([]string, 0)
var browserClientSocketUUIDs = make([]int, 0)
var socketListLock = sync.Mutex{}
type RoutesEventHandlers struct {
routes.EventHandlers
@@ -23,51 +25,50 @@ func init() {
go up9.UpdateAnalyzeStatus(broadcastToBrowserClients)
}
func (h *RoutesEventHandlers) WebSocketConnect(ep *ikisocket.EventPayload) {
if ep.Kws.GetAttribute("is_tapper") == true {
rlog.Infof("Websocket Connection event - Tapper connected: %s", ep.SocketUUID)
func (h *RoutesEventHandlers) WebSocketConnect(socketId int, isTapper bool) {
if isTapper {
rlog.Infof("Websocket Connection event - Tapper connected: %s", socketId)
} else {
rlog.Infof("Websocket Connection event - Browser socket connected: %s", ep.SocketUUID)
browserClientSocketUUIDs = append(browserClientSocketUUIDs, ep.SocketUUID)
rlog.Infof("Websocket Connection event - Browser socket connected: %s", socketId)
socketListLock.Lock()
browserClientSocketUUIDs = append(browserClientSocketUUIDs, socketId)
socketListLock.Unlock()
}
}
func (h *RoutesEventHandlers) WebSocketDisconnect(ep *ikisocket.EventPayload) {
if ep.Kws.GetAttribute("is_tapper") == true {
rlog.Infof("Disconnection event - Tapper connected: %s", ep.SocketUUID)
func (h *RoutesEventHandlers) WebSocketDisconnect(socketId int, isTapper bool) {
if isTapper {
rlog.Infof("Disconnection event - Tapper connected: %s", socketId)
} else {
rlog.Infof("Disconnection event - Browser socket connected: %s", ep.SocketUUID)
removeSocketUUIDFromBrowserSlice(ep.SocketUUID)
rlog.Infof("Disconnection event - Browser socket connected: %s", socketId)
socketListLock.Lock()
removeSocketUUIDFromBrowserSlice(socketId)
socketListLock.Unlock()
}
}
func broadcastToBrowserClients(message []byte) {
ikisocket.EmitToList(browserClientSocketUUIDs, message)
}
for _, socketId := range browserClientSocketUUIDs {
go func(socketId int) {
err := routes.SendToSocket(socketId, message)
if err != nil {
fmt.Printf("error sending message to socket id %d: %v", socketId, err)
}
}(socketId)
func (h *RoutesEventHandlers) WebSocketClose(ep *ikisocket.EventPayload) {
if ep.Kws.GetAttribute("is_tapper") == true {
rlog.Infof("Websocket Close event - Tapper connected: %s", ep.SocketUUID)
} else {
rlog.Infof("Websocket Close event - Browser socket connected: %s", ep.SocketUUID)
removeSocketUUIDFromBrowserSlice(ep.SocketUUID)
}
}
func (h *RoutesEventHandlers) WebSocketError(ep *ikisocket.EventPayload) {
rlog.Infof("Socket error - Socket uuid : %s %v", ep.SocketUUID, ep.Error)
}
func (h *RoutesEventHandlers) WebSocketMessage(ep *ikisocket.EventPayload) {
func (h *RoutesEventHandlers) WebSocketMessage(_ int, message []byte) {
var socketMessageBase shared.WebSocketMessageMetadata
err := json.Unmarshal(ep.Data, &socketMessageBase)
err := json.Unmarshal(message, &socketMessageBase)
if err != nil {
rlog.Infof("Could not unmarshal websocket message %v\n", err)
} else {
switch socketMessageBase.MessageType {
case shared.WebSocketMessageTypeTappedEntry:
var tappedEntryMessage models.WebSocketTappedEntryMessage
err := json.Unmarshal(ep.Data, &tappedEntryMessage)
err := json.Unmarshal(message, &tappedEntryMessage)
if err != nil {
rlog.Infof("Could not unmarshal message of message type %s %v\n", socketMessageBase.MessageType, err)
} else {
@@ -75,12 +76,12 @@ func (h *RoutesEventHandlers) WebSocketMessage(ep *ikisocket.EventPayload) {
}
case shared.WebSocketMessageTypeUpdateStatus:
var statusMessage shared.WebSocketStatusMessage
err := json.Unmarshal(ep.Data, &statusMessage)
err := json.Unmarshal(message, &statusMessage)
if err != nil {
rlog.Infof("Could not unmarshal message of message type %s %v\n", socketMessageBase.MessageType, err)
} else {
controllers.TapStatus = statusMessage.TappingStatus
broadcastToBrowserClients(ep.Data)
broadcastToBrowserClients(message)
}
default:
rlog.Infof("Received socket message of type %s for which no handlers are defined", socketMessageBase.MessageType)
@@ -88,8 +89,8 @@ func (h *RoutesEventHandlers) WebSocketMessage(ep *ikisocket.EventPayload) {
}
}
func removeSocketUUIDFromBrowserSlice(uuidToRemove string) {
newUUIDSlice := make([]string, 0, len(browserClientSocketUUIDs))
func removeSocketUUIDFromBrowserSlice(uuidToRemove int) {
newUUIDSlice := make([]int, 0, len(browserClientSocketUUIDs))
for _, uuid := range browserClientSocketUUIDs {
if uuid != uuidToRemove {
newUUIDSlice = append(newUUIDSlice, uuid)

View File

@@ -3,7 +3,7 @@ package controllers
import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"github.com/google/martian/har"
"github.com/romana/rlog"
"mizuserver/pkg/database"
@@ -11,19 +11,20 @@ import (
"mizuserver/pkg/up9"
"mizuserver/pkg/utils"
"mizuserver/pkg/validation"
"net/http"
"strings"
"time"
)
func GetEntries(c *fiber.Ctx) error {
func GetEntries(c *gin.Context) {
entriesFilter := &models.EntriesFilter{}
if err := c.QueryParser(entriesFilter); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
if err := c.BindQuery(entriesFilter); err != nil {
c.JSON(http.StatusBadRequest, err)
}
err := validation.Validate(entriesFilter)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
c.JSON(http.StatusBadRequest, err)
}
order := database.OperatorToOrderMapping[entriesFilter.Operator]
@@ -50,18 +51,18 @@ func GetEntries(c *fiber.Ctx) error {
baseEntries = append(baseEntries, harEntry)
}
return c.Status(fiber.StatusOK).JSON(baseEntries)
c.JSON(http.StatusOK, baseEntries)
}
func GetHARs(c *fiber.Ctx) error {
func GetHARs(c *gin.Context) {
entriesFilter := &models.HarFetchRequestBody{}
order := database.OrderDesc
if err := c.QueryParser(entriesFilter); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
if err := c.BindQuery(entriesFilter); err != nil {
c.JSON(http.StatusBadRequest, err)
}
err := validation.Validate(entriesFilter)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
c.JSON(http.StatusBadRequest, err)
}
var timestampFrom, timestampTo int64
@@ -137,40 +138,45 @@ func GetHARs(c *fiber.Ctx) error {
retObj[k] = bytesData
}
buffer := utils.ZipData(retObj)
return c.Status(fiber.StatusOK).SendStream(buffer)
c.Data(http.StatusOK, "application/octet-stream", buffer.Bytes())
}
func UploadEntries(c *fiber.Ctx) error {
func UploadEntries(c *gin.Context) {
rlog.Infof("Upload entries - started\n")
uploadRequestBody := &models.UploadEntriesRequestBody{}
if err := c.QueryParser(uploadRequestBody); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
if err := c.BindQuery(uploadRequestBody); err != nil {
c.JSON(http.StatusBadRequest, err)
return
}
if err := validation.Validate(uploadRequestBody); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
c.JSON(http.StatusBadRequest, err)
return
}
if up9.GetAnalyzeInfo().IsAnalyzing {
return c.Status(fiber.StatusBadRequest).SendString("Cannot analyze, mizu is already analyzing")
c.String(http.StatusBadRequest, "Cannot analyze, mizu is already analyzing")
return
}
rlog.Infof("Upload entries - creating token. dest %s\n", uploadRequestBody.Dest)
token, err := up9.CreateAnonymousToken(uploadRequestBody.Dest)
if err != nil {
return c.Status(fiber.StatusServiceUnavailable).SendString("Can't get token")
c.String(http.StatusServiceUnavailable, "Cannot analyze, mizu is already analyzing")
return
}
rlog.Infof("Upload entries - uploading. token: %s model: %s\n", token.Token, token.Model)
go up9.UploadEntriesImpl(token.Token, token.Model, uploadRequestBody.Dest, uploadRequestBody.SleepIntervalSec)
return c.Status(fiber.StatusOK).SendString("OK")
c.String(http.StatusOK, "OK")
}
func GetFullEntries(c *fiber.Ctx) error {
func GetFullEntries(c *gin.Context) {
entriesFilter := &models.HarFetchRequestBody{}
if err := c.QueryParser(entriesFilter); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
if err := c.BindQuery(entriesFilter); err != nil {
c.JSON(http.StatusBadRequest, err)
}
err := validation.Validate(entriesFilter)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(err)
c.JSON(http.StatusBadRequest, err)
}
var timestampFrom, timestampTo int64
@@ -195,38 +201,37 @@ func GetFullEntries(c *fiber.Ctx) error {
}
result = append(result, harEntry)
}
return c.Status(fiber.StatusOK).JSON(result)
c.JSON(http.StatusOK, result)
}
func GetEntry(c *fiber.Ctx) error {
func GetEntry(c *gin.Context) {
var entryData models.MizuEntry
database.GetEntriesTable().
Where(map[string]string{"entryId": c.Params("entryId")}).
Where(map[string]string{"entryId": c.Param("entryId")}).
First(&entryData)
fullEntry := models.FullEntryDetails{}
if err := models.GetEntry(&entryData, &fullEntry); err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": true,
"msg": "Can't get entry details",
})
}
return c.Status(fiber.StatusOK).JSON(fullEntry)
c.JSON(http.StatusOK, fullEntry)
}
func DeleteAllEntries(c *fiber.Ctx) error {
func DeleteAllEntries(c *gin.Context) {
database.GetEntriesTable().
Where("1 = 1").
Delete(&models.MizuEntry{})
return c.Status(fiber.StatusOK).JSON(fiber.Map{
c.JSON(http.StatusOK, map[string]string{
"msg": "Success",
})
}
func GetGeneralStats(c *fiber.Ctx) error {
func GetGeneralStats(c *gin.Context) {
sqlQuery := "SELECT count(*) as count, min(timestamp) as min, max(timestamp) as max from mizu_entries"
var result struct {
Count int
@@ -234,5 +239,5 @@ func GetGeneralStats(c *fiber.Ctx) error {
Max int
}
database.GetEntriesTable().Raw(sqlQuery).Scan(&result)
return c.Status(fiber.StatusOK).JSON(&result)
c.JSON(http.StatusOK, result)
}

View File

@@ -1,12 +1,13 @@
package controllers
import (
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"github.com/up9inc/mizu/shared"
"mizuserver/pkg/version"
"net/http"
)
func GetVersion(c *fiber.Ctx) error {
func GetVersion(c *gin.Context) {
resp := shared.VersionResponse{SemVer: version.SemVer}
return c.Status(fiber.StatusOK).JSON(resp)
c.JSON(http.StatusOK, resp)
}

View File

@@ -1,11 +1,12 @@
package controllers
import (
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"mizuserver/pkg/holder"
"net/http"
)
func GetCurrentResolvingInformation(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(holder.GetResolver().GetMap())
func GetCurrentResolvingInformation(c *gin.Context) {
c.JSON(http.StatusOK, holder.GetResolver().GetMap())
}

View File

@@ -1,17 +1,18 @@
package controllers
import (
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"github.com/up9inc/mizu/shared"
"mizuserver/pkg/up9"
"net/http"
)
var TapStatus shared.TapStatus
func GetTappingStatus(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(TapStatus)
func GetTappingStatus(c *gin.Context) {
c.JSON(http.StatusOK, TapStatus)
}
func AnalyzeInformation(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).JSON(up9.GetAnalyzeInfo())
func AnalyzeInformation(c *gin.Context) {
c.JSON(http.StatusOK, up9.GetAnalyzeInfo())
}

View File

@@ -1,18 +0,0 @@
package middleware
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
)
// FiberMiddleware provide Fiber's built-in middlewares.
// See: https://docs.gofiber.io/api/middleware
func FiberMiddleware(a *fiber.App) {
a.Use(
// Add CORS to each route.
cors.New(),
// Add simple logger.
logger.New(),
)
}

View File

@@ -113,8 +113,8 @@ type EntriesFilter struct {
}
type UploadEntriesRequestBody struct {
Dest string `query:"dest"`
SleepIntervalSec int `query:"interval"`
Dest string `form:"dest"`
SleepIntervalSec int `form:"interval"`
}
type HarFetchRequestBody struct {

View File

@@ -1,25 +1,25 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"mizuserver/pkg/controllers"
)
// EntriesRoutes defines the group of har entries routes.
func EntriesRoutes(fiberApp *fiber.App) {
routeGroup := fiberApp.Group("/api")
func EntriesRoutes(ginApp *gin.Engine) {
routeGroup := ginApp.Group("/api")
routeGroup.Get("/entries", controllers.GetEntries) // get entries (base/thin entries)
routeGroup.Get("/entries/:entryId", controllers.GetEntry) // get single (full) entry
routeGroup.Get("/exportEntries", controllers.GetFullEntries)
routeGroup.Get("/uploadEntries", controllers.UploadEntries)
routeGroup.Get("/resolving", controllers.GetCurrentResolvingInformation)
routeGroup.GET("/entries", controllers.GetEntries) // get entries (base/thin entries)
routeGroup.GET("/entries/:entryId", controllers.GetEntry) // get single (full) entry
routeGroup.GET("/exportEntries", controllers.GetFullEntries)
routeGroup.GET("/uploadEntries", controllers.UploadEntries)
routeGroup.GET("/resolving", controllers.GetCurrentResolvingInformation)
routeGroup.Get("/har", controllers.GetHARs)
routeGroup.GET("/har", controllers.GetHARs)
routeGroup.Get("/resetDB", controllers.DeleteAllEntries) // get single (full) entry
routeGroup.Get("/generalStats", controllers.GetGeneralStats) // get general stats about entries in DB
routeGroup.GET("/resetDB", controllers.DeleteAllEntries) // get single (full) entry
routeGroup.GET("/generalStats", controllers.GetGeneralStats) // get general stats about entries in DB
routeGroup.Get("/tapStatus", controllers.GetTappingStatus) // get tapping status
routeGroup.Get("/analyzeStatus", controllers.AnalyzeInformation)
routeGroup.GET("/tapStatus", controllers.GetTappingStatus) // get tapping status
routeGroup.GET("/analyzeStatus", controllers.AnalyzeInformation)
}

View File

@@ -1,13 +1,13 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/gin-gonic/gin"
"mizuserver/pkg/controllers"
)
// MetadataRoutes defines the group of metadata routes.
func MetadataRoutes(fiberApp *fiber.App) {
routeGroup := fiberApp.Group("/metadata")
func MetadataRoutes(app *gin.Engine) {
routeGroup := app.Group("/metadata")
routeGroup.Get("/version", controllers.GetVersion)
routeGroup.GET("/version", controllers.GetVersion)
}

View File

@@ -1,12 +1,15 @@
package routes
import "github.com/gofiber/fiber/v2"
import (
"github.com/gin-gonic/gin"
"net/http"
)
// NotFoundRoute defines the 404 Error route.
func NotFoundRoute(fiberApp *fiber.App) {
fiberApp.Use(
func(c *fiber.Ctx) error {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
func NotFoundRoute(app *gin.Engine) {
app.Use(
func(c *gin.Context) {
c.JSON(http.StatusNotFound, map[string]interface{}{
"error": true,
"msg": "sorry, endpoint is not found",
})

View File

@@ -1,31 +1,118 @@
package routes
import (
"github.com/antoniodipinto/ikisocket"
"github.com/gofiber/fiber/v2"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/up9inc/mizu/shared/debounce"
"net/http"
"sync"
"time"
)
type EventHandlers interface {
WebSocketConnect(ep *ikisocket.EventPayload)
WebSocketDisconnect(ep *ikisocket.EventPayload)
WebSocketClose(ep *ikisocket.EventPayload)
WebSocketError(ep *ikisocket.EventPayload)
WebSocketMessage(ep *ikisocket.EventPayload)
WebSocketConnect(socketId int, isTapper bool)
WebSocketDisconnect(socketId int, isTapper bool)
WebSocketMessage(socketId int, message []byte)
}
func WebSocketRoutes(app *fiber.App, eventHandlers EventHandlers) {
app.Get("/ws", ikisocket.New(func(kws *ikisocket.Websocket) {
kws.SetAttribute("is_tapper", false)
}))
app.Get("/wsTapper", ikisocket.New(func(kws *ikisocket.Websocket) {
// Tapper clients are handled differently, they don't need to receive new message broadcasts.
kws.SetAttribute("is_tapper", true)
}))
ikisocket.On(ikisocket.EventMessage, eventHandlers.WebSocketMessage)
ikisocket.On(ikisocket.EventConnect, eventHandlers.WebSocketConnect)
ikisocket.On(ikisocket.EventDisconnect, eventHandlers.WebSocketDisconnect)
ikisocket.On(ikisocket.EventClose, eventHandlers.WebSocketClose) // This event is called when the server disconnects the user actively with .Close() method
ikisocket.On(ikisocket.EventError, eventHandlers.WebSocketError) // On error event
type SocketConnection struct {
connection *websocket.Conn
lock *sync.Mutex
eventHandlers EventHandlers
isTapper bool
}
var websocketUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
var websocketIdsLock = sync.Mutex{}
var connectedWebsockets map[int]*SocketConnection
var connectedWebsocketIdCounter = 0
func init() {
websocketUpgrader.CheckOrigin = func(r *http.Request) bool { return true } // like cors for web socket
connectedWebsockets = make(map[int]*SocketConnection, 0)
}
func WebSocketRoutes(app *gin.Engine, eventHandlers EventHandlers) {
app.GET("/ws", func(c *gin.Context) {
websocketHandler(c.Writer, c.Request, eventHandlers, false)
})
app.GET("/wsTapper", func(c *gin.Context) {
websocketHandler(c.Writer, c.Request, eventHandlers, true)
})
}
func websocketHandler(w http.ResponseWriter, r *http.Request, eventHandlers EventHandlers, isTapper bool) {
conn, err := websocketUpgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println("Failed to set websocket upgrade: %+v", err)
return
}
websocketIdsLock.Lock()
connectedWebsocketIdCounter++
socketId := connectedWebsocketIdCounter
connectedWebsockets[socketId] = &SocketConnection{connection: conn, lock: &sync.Mutex{}, eventHandlers: eventHandlers, isTapper: isTapper}
websocketIdsLock.Unlock()
defer func() {
socketCleanup(socketId, connectedWebsockets[socketId])
}()
eventHandlers.WebSocketConnect(socketId, isTapper)
for {
_, msg, err := conn.ReadMessage()
if err != nil {
fmt.Printf("Conn err: %v\n", err)
break
}
eventHandlers.WebSocketMessage(socketId, msg)
}
}
func socketCleanup(socketId int, socketConnection *SocketConnection) {
err := socketConnection.connection.Close()
if err != nil {
fmt.Printf("Error closing socket connection for socket id %d: %v\n", socketId, err)
}
websocketIdsLock.Lock()
connectedWebsockets[socketId] = nil
websocketIdsLock.Unlock()
socketConnection.eventHandlers.WebSocketDisconnect(socketId, socketConnection.isTapper)
}
var db = debounce.NewDebouncer(time.Second * 5, func() {
fmt.Println("Successfully sent to socket")
})
func SendToSocket(socketId int, message []byte) error {
socketObj := connectedWebsockets[socketId]
if socketObj == nil {
return errors.New("Socket is disconnected")
}
var sent = false
time.AfterFunc(time.Second * 5, func() {
if !sent {
fmt.Println("Socket timed out")
socketCleanup(socketId, socketObj)
}
})
socketObj.lock.Lock() // gorilla socket panics from concurrent writes to a single socket
err := socketObj.connection.WriteMessage(1, message)
socketObj.lock.Unlock()
sent = true
return err
}

View File

@@ -1,32 +1,42 @@
package utils
import (
"github.com/gofiber/fiber/v2"
"context"
"github.com/gin-gonic/gin"
"github.com/romana/rlog"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"reflect"
"syscall"
"time"
)
// StartServer starts the server with a graceful shutdown
func StartServer(app *fiber.App) {
func StartServer(app *gin.Engine) {
signals := make(chan os.Signal, 2)
signal.Notify(signals,
os.Interrupt, // this catch ctrl + c
syscall.SIGTSTP, // this catch ctrl + z
)
srv := &http.Server{
Addr: ":8080",
Handler: app,
}
go func() {
_ = <-signals
rlog.Infof("Shutting down...")
_ = app.Shutdown()
ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
_ = srv.Shutdown(ctx)
os.Exit(0)
}()
// Run server.
if err := app.Listen(":8899"); err != nil {
if err := app.Run(":8899"); err != nil {
log.Printf("Oops... Server is not running! Reason: %v", err)
}
}