diff --git a/agent/pkg/api/socket_data_streamer.go b/agent/pkg/api/socket_data_streamer.go index 48a9a8fea..c93254464 100644 --- a/agent/pkg/api/socket_data_streamer.go +++ b/agent/pkg/api/socket_data_streamer.go @@ -14,7 +14,7 @@ import ( ) type EntryStreamer interface { - Get(socketId int, params *WebSocketParams) (context.CancelFunc, error) + Get(ctx context.Context, socketId int, params *WebSocketParams) error } type EntryStreamerSocketConnector interface { @@ -66,7 +66,7 @@ func (e *DefaultEntryStreamerSocketConnector) CleanupSocket(socketId int) { type BasenineEntryStreamer struct{} -func (e *BasenineEntryStreamer) Get(socketId int, params *WebSocketParams) (context.CancelFunc, error) { +func (e *BasenineEntryStreamer) Get(ctx context.Context, socketId int, params *WebSocketParams) error { var connection *basenine.Connection entryStreamerSocketConnector := dependency.GetInstance(dependency.EntryStreamerSocketConnector).(EntryStreamerSocketConnector) @@ -75,14 +75,12 @@ func (e *BasenineEntryStreamer) Get(socketId int, params *WebSocketParams) (cont if err != nil { logger.Log.Errorf("failed to establish a connection to Basenine: %v", err) entryStreamerSocketConnector.CleanupSocket(socketId) - return nil, err + return err } data := make(chan []byte) meta := make(chan []byte) - ctx, cancel := context.WithCancel(context.Background()) - query := params.Query err = basenine.Validate(shared.BasenineHost, shared.BaseninePort, query) if err != nil { @@ -139,5 +137,5 @@ func (e *BasenineEntryStreamer) Get(socketId int, params *WebSocketParams) (cont connection.Close() }() - return cancel, nil + return nil } diff --git a/agent/pkg/api/socket_server_handlers.go b/agent/pkg/api/socket_server_handlers.go index 7d58c59ea..18e1dded8 100644 --- a/agent/pkg/api/socket_server_handlers.go +++ b/agent/pkg/api/socket_server_handlers.go @@ -109,10 +109,12 @@ func (h *RoutesEventHandlers) WebSocketMessage(socketId int, isTapper bool, mess } entriesStreamer := dependency.GetInstance(dependency.EntriesSocketStreamer).(EntryStreamer) - cancelFunc, err := entriesStreamer.Get(socketId, ¶ms) + ctx, cancelFunc := context.WithCancel(context.Background()) + err := entriesStreamer.Get(ctx, socketId, ¶ms) if err != nil { logger.Log.Errorf("error initializing basenine stream for browser socket %d %+v", socketId, err) + cancelFunc() } else { browserClients[socketId].dataStreamCancelFunc = cancelFunc }