mirror of
https://github.com/kubeshark/kubeshark.git
synced 2025-06-26 16:24:54 +00:00
Stop the hanging Goroutines by dropping the old, unidentified TCP streams (#260)
* Close the hanging TCP message channels after a dynamically aligned timeout (base `10000` milliseconds) * Bring back `source.Lazy` * Add a one more `sync.Map.Delete` call * Improve the formula by taking base Goroutine count into account * Reduce duplication * Include the dropped TCP streams count into the stats tracker and print a debug log whenever it happens * Add `superIdentifier` field to `tcpStream` to check if it has identified Also stop the other protocol dissectors if a TCP stream identified by a protocol. * Take one step forward in fixing the channel closing issue (WIP) Add `sync.Mutex` to `tcpReader` and make the loops reference based. * Fix the channel closing issue * Improve the accuracy of the formula, log better and multiply `baseStreamChannelTimeoutMs` by 100 * Remove `fmt.Printf` * Replace `runtime.Gosched()` with `time.Sleep(1 * time.Millisecond)` * Close the channels of other protocols in case of an identification * Simplify the logic * Replace the formula with hard timeout 5000 milliseconds and 4000 maximum number of Goroutines
This commit is contained in:
parent
819ccf54cd
commit
858a64687d
@ -21,7 +21,7 @@ type Protocol struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Extension struct {
|
type Extension struct {
|
||||||
Protocol Protocol
|
Protocol *Protocol
|
||||||
Path string
|
Path string
|
||||||
Plug *plugin.Plugin
|
Plug *plugin.Plugin
|
||||||
Dissector Dissector
|
Dissector Dissector
|
||||||
@ -72,10 +72,15 @@ type SuperTimer struct {
|
|||||||
CaptureTime time.Time
|
CaptureTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SuperIdentifier struct {
|
||||||
|
Protocol *Protocol
|
||||||
|
IsClosedOthers bool
|
||||||
|
}
|
||||||
|
|
||||||
type Dissector interface {
|
type Dissector interface {
|
||||||
Register(*Extension)
|
Register(*Extension)
|
||||||
Ping()
|
Ping()
|
||||||
Dissect(b *bufio.Reader, isClient bool, tcpID *TcpID, counterPair *CounterPair, superTimer *SuperTimer, emitter Emitter) error
|
Dissect(b *bufio.Reader, isClient bool, tcpID *TcpID, counterPair *CounterPair, superTimer *SuperTimer, superIdentifier *SuperIdentifier, emitter Emitter) error
|
||||||
Analyze(item *OutputChannelItem, entryId string, resolvedSource string, resolvedDestination string) *MizuEntry
|
Analyze(item *OutputChannelItem, entryId string, resolvedSource string, resolvedDestination string) *MizuEntry
|
||||||
Summarize(entry *MizuEntry) *BaseEntryDetails
|
Summarize(entry *MizuEntry) *BaseEntryDetails
|
||||||
Represent(entry *MizuEntry) (protocol Protocol, object []byte, bodySize int64, err error)
|
Represent(entry *MizuEntry) (protocol Protocol, object []byte, bodySize int64, err error)
|
||||||
|
@ -32,7 +32,7 @@ func init() {
|
|||||||
type dissecting string
|
type dissecting string
|
||||||
|
|
||||||
func (d dissecting) Register(extension *api.Extension) {
|
func (d dissecting) Register(extension *api.Extension) {
|
||||||
extension.Protocol = protocol
|
extension.Protocol = &protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dissecting) Ping() {
|
func (d dissecting) Ping() {
|
||||||
@ -41,7 +41,7 @@ func (d dissecting) Ping() {
|
|||||||
|
|
||||||
const amqpRequest string = "amqp_request"
|
const amqpRequest string = "amqp_request"
|
||||||
|
|
||||||
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter) error {
|
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter) error {
|
||||||
r := AmqpReader{b}
|
r := AmqpReader{b}
|
||||||
|
|
||||||
var remaining int
|
var remaining int
|
||||||
@ -78,6 +78,10 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
var lastMethodFrameMessage Message
|
var lastMethodFrameMessage Message
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &protocol {
|
||||||
|
return errors.New("Identified by another protocol")
|
||||||
|
}
|
||||||
|
|
||||||
frame, err := r.ReadFrame()
|
frame, err := r.ReadFrame()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// We must read until we see an EOF... very important!
|
// We must read until we see an EOF... very important!
|
||||||
@ -108,9 +112,11 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
switch lastMethodFrameMessage.(type) {
|
switch lastMethodFrameMessage.(type) {
|
||||||
case *BasicPublish:
|
case *BasicPublish:
|
||||||
eventBasicPublish.Body = f.Body
|
eventBasicPublish.Body = f.Body
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
case *BasicDeliver:
|
case *BasicDeliver:
|
||||||
eventBasicDeliver.Body = f.Body
|
eventBasicDeliver.Body = f.Body
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
default:
|
default:
|
||||||
body = nil
|
body = nil
|
||||||
@ -134,6 +140,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
NoWait: m.NoWait,
|
NoWait: m.NoWait,
|
||||||
Arguments: m.Arguments,
|
Arguments: m.Arguments,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
case *BasicConsume:
|
case *BasicConsume:
|
||||||
@ -146,6 +153,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
NoWait: m.NoWait,
|
NoWait: m.NoWait,
|
||||||
Arguments: m.Arguments,
|
Arguments: m.Arguments,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
case *BasicDeliver:
|
case *BasicDeliver:
|
||||||
@ -165,6 +173,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
NoWait: m.NoWait,
|
NoWait: m.NoWait,
|
||||||
Arguments: m.Arguments,
|
Arguments: m.Arguments,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
case *ExchangeDeclare:
|
case *ExchangeDeclare:
|
||||||
@ -178,6 +187,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
NoWait: m.NoWait,
|
NoWait: m.NoWait,
|
||||||
Arguments: m.Arguments,
|
Arguments: m.Arguments,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
case *ConnectionStart:
|
case *ConnectionStart:
|
||||||
@ -188,6 +198,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
Mechanisms: m.Mechanisms,
|
Mechanisms: m.Mechanisms,
|
||||||
Locales: m.Locales,
|
Locales: m.Locales,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
case *ConnectionClose:
|
case *ConnectionClose:
|
||||||
@ -197,6 +208,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
ClassId: m.ClassId,
|
ClassId: m.ClassId,
|
||||||
MethodId: m.MethodId,
|
MethodId: m.MethodId,
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, superTimer.CaptureTime, emitter)
|
emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, superTimer.CaptureTime, emitter)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -52,7 +53,7 @@ func init() {
|
|||||||
type dissecting string
|
type dissecting string
|
||||||
|
|
||||||
func (d dissecting) Register(extension *api.Extension) {
|
func (d dissecting) Register(extension *api.Extension) {
|
||||||
extension.Protocol = protocol
|
extension.Protocol = &protocol
|
||||||
extension.MatcherMap = reqResMatcher.openMessagesMap
|
extension.MatcherMap = reqResMatcher.openMessagesMap
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +61,7 @@ func (d dissecting) Ping() {
|
|||||||
log.Printf("pong %s\n", protocol.Name)
|
log.Printf("pong %s\n", protocol.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter) error {
|
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter) error {
|
||||||
ident := fmt.Sprintf("%s->%s:%s->%s", tcpID.SrcIP, tcpID.DstIP, tcpID.SrcPort, tcpID.DstPort)
|
ident := fmt.Sprintf("%s->%s:%s->%s", tcpID.SrcIP, tcpID.DstIP, tcpID.SrcPort, tcpID.DstPort)
|
||||||
isHTTP2, err := checkIsHTTP2Connection(b, isClient)
|
isHTTP2, err := checkIsHTTP2Connection(b, isClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -77,8 +78,12 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
grpcAssembler = createGrpcAssembler(b)
|
grpcAssembler = createGrpcAssembler(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
success := false
|
dissected := false
|
||||||
for {
|
for {
|
||||||
|
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &protocol {
|
||||||
|
return errors.New("Identified by another protocol")
|
||||||
|
}
|
||||||
|
|
||||||
if isHTTP2 {
|
if isHTTP2 {
|
||||||
err = handleHTTP2Stream(grpcAssembler, tcpID, superTimer, emitter)
|
err = handleHTTP2Stream(grpcAssembler, tcpID, superTimer, emitter)
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
@ -87,7 +92,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
rlog.Debugf("[HTTP/2] stream %s error: %s (%v,%+v)", ident, err, err, err)
|
rlog.Debugf("[HTTP/2] stream %s error: %s (%v,%+v)", ident, err, err, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
success = true
|
dissected = true
|
||||||
} else if isClient {
|
} else if isClient {
|
||||||
err = handleHTTP1ClientStream(b, tcpID, counterPair, superTimer, emitter)
|
err = handleHTTP1ClientStream(b, tcpID, counterPair, superTimer, emitter)
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
@ -96,7 +101,7 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
rlog.Debugf("[HTTP-request] stream %s Request error: %s (%v,%+v)", ident, err, err, err)
|
rlog.Debugf("[HTTP-request] stream %s Request error: %s (%v,%+v)", ident, err, err, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
success = true
|
dissected = true
|
||||||
} else {
|
} else {
|
||||||
err = handleHTTP1ServerStream(b, tcpID, counterPair, superTimer, emitter)
|
err = handleHTTP1ServerStream(b, tcpID, counterPair, superTimer, emitter)
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
@ -105,13 +110,14 @@ func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, co
|
|||||||
rlog.Debugf("[HTTP-response], stream %s Response error: %s (%v,%+v)", ident, err, err, err)
|
rlog.Debugf("[HTTP-response], stream %s Response error: %s (%v,%+v)", ident, err, err, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
success = true
|
dissected = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !success {
|
if !dissected {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &protocol
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
@ -30,7 +31,7 @@ func init() {
|
|||||||
type dissecting string
|
type dissecting string
|
||||||
|
|
||||||
func (d dissecting) Register(extension *api.Extension) {
|
func (d dissecting) Register(extension *api.Extension) {
|
||||||
extension.Protocol = _protocol
|
extension.Protocol = &_protocol
|
||||||
extension.MatcherMap = reqResMatcher.openMessagesMap
|
extension.MatcherMap = reqResMatcher.openMessagesMap
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,18 +39,24 @@ func (d dissecting) Ping() {
|
|||||||
log.Printf("pong %s\n", _protocol.Name)
|
log.Printf("pong %s\n", _protocol.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter) error {
|
func (d dissecting) Dissect(b *bufio.Reader, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter) error {
|
||||||
for {
|
for {
|
||||||
|
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &_protocol {
|
||||||
|
return errors.New("Identified by another protocol")
|
||||||
|
}
|
||||||
|
|
||||||
if isClient {
|
if isClient {
|
||||||
_, _, err := ReadRequest(b, tcpID, superTimer)
|
_, _, err := ReadRequest(b, tcpID, superTimer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &_protocol
|
||||||
} else {
|
} else {
|
||||||
err := ReadResponse(b, tcpID, superTimer, emitter)
|
err := ReadResponse(b, tcpID, superTimer, emitter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
superIdentifier.Protocol = &_protocol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@ -95,6 +96,8 @@ var ownIps []string // global
|
|||||||
var hostMode bool // global
|
var hostMode bool // global
|
||||||
var extensions []*api.Extension // global
|
var extensions []*api.Extension // global
|
||||||
|
|
||||||
|
const baseStreamChannelTimeoutMs int = 5000 * 100
|
||||||
|
|
||||||
/* minOutputLevel: Error will be printed only if outputLevel is above this value
|
/* minOutputLevel: Error will be printed only if outputLevel is above this value
|
||||||
* t: key for errorsMap (counting errors)
|
* t: key for errorsMap (counting errors)
|
||||||
* s, a: arguments log.Printf
|
* s, a: arguments log.Printf
|
||||||
@ -211,8 +214,45 @@ func startMemoryProfiler() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeTimedoutTcpStreamChannels() {
|
||||||
|
maxNumberOfGoroutines = GetMaxNumberOfGoroutines()
|
||||||
|
TcpStreamChannelTimeoutMs := GetTcpChannelTimeoutMs()
|
||||||
|
for {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
streams.Range(func(key interface{}, value interface{}) bool {
|
||||||
|
streamWrapper := value.(*tcpStreamWrapper)
|
||||||
|
stream := streamWrapper.stream
|
||||||
|
if stream.superIdentifier.Protocol == nil {
|
||||||
|
if !stream.isClosed && time.Now().After(streamWrapper.createdAt.Add(TcpStreamChannelTimeoutMs)) {
|
||||||
|
stream.Close()
|
||||||
|
statsTracker.incDroppedTcpStreams()
|
||||||
|
rlog.Debugf("Dropped an unidentified TCP stream because of timeout. Total dropped: %d Total Goroutines: %d Timeout (ms): %d\n", statsTracker.appStats.DroppedTcpStreams, runtime.NumGoroutine(), TcpStreamChannelTimeoutMs/1000000)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !stream.superIdentifier.IsClosedOthers {
|
||||||
|
for i := range stream.clients {
|
||||||
|
reader := &stream.clients[i]
|
||||||
|
if reader.extension.Protocol != stream.superIdentifier.Protocol {
|
||||||
|
reader.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := range stream.servers {
|
||||||
|
reader := &stream.servers[i]
|
||||||
|
if reader.extension.Protocol != stream.superIdentifier.Protocol {
|
||||||
|
reader.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream.superIdentifier.IsClosedOthers = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func startPassiveTapper(outputItems chan *api.OutputChannelItem) {
|
func startPassiveTapper(outputItems chan *api.OutputChannelItem) {
|
||||||
log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile)
|
log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile)
|
||||||
|
go closeTimedoutTcpStreamChannels()
|
||||||
|
|
||||||
defer util.Run()()
|
defer util.Run()()
|
||||||
if *debug {
|
if *debug {
|
||||||
@ -367,7 +407,14 @@ func startPassiveTapper(outputItems chan *api.OutputChannelItem) {
|
|||||||
startMemoryProfiler()
|
startMemoryProfiler()
|
||||||
}
|
}
|
||||||
|
|
||||||
for packet := range source.Packets() {
|
for {
|
||||||
|
packet, err := source.NextPacket()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
rlog.Debugf("Error:", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
packetsCount := statsTracker.incPacketsCount()
|
packetsCount := statsTracker.incPacketsCount()
|
||||||
rlog.Debugf("PACKET #%d", packetsCount)
|
rlog.Debugf("PACKET #%d", packetsCount)
|
||||||
data := packet.Data()
|
data := packet.Data()
|
||||||
|
@ -3,6 +3,7 @@ package tap
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -11,8 +12,12 @@ const (
|
|||||||
MemoryProfilingTimeIntervalSeconds = "MEMORY_PROFILING_TIME_INTERVAL"
|
MemoryProfilingTimeIntervalSeconds = "MEMORY_PROFILING_TIME_INTERVAL"
|
||||||
MaxBufferedPagesTotalEnvVarName = "MAX_BUFFERED_PAGES_TOTAL"
|
MaxBufferedPagesTotalEnvVarName = "MAX_BUFFERED_PAGES_TOTAL"
|
||||||
MaxBufferedPagesPerConnectionEnvVarName = "MAX_BUFFERED_PAGES_PER_CONNECTION"
|
MaxBufferedPagesPerConnectionEnvVarName = "MAX_BUFFERED_PAGES_PER_CONNECTION"
|
||||||
|
TcpStreamChannelTimeoutMsEnvVarName = "TCP_STREAM_CHANNEL_TIMEOUT_MS"
|
||||||
|
MaxNumberOfGoroutinesEnvVarName = "MAX_NUMBER_OF_GOROUTINES"
|
||||||
MaxBufferedPagesTotalDefaultValue = 5000
|
MaxBufferedPagesTotalDefaultValue = 5000
|
||||||
MaxBufferedPagesPerConnectionDefaultValue = 5000
|
MaxBufferedPagesPerConnectionDefaultValue = 5000
|
||||||
|
TcpStreamChannelTimeoutMsDefaultValue = 5000
|
||||||
|
MaxNumberOfGoroutinesDefaultValue = 4000
|
||||||
)
|
)
|
||||||
|
|
||||||
type globalSettings struct {
|
type globalSettings struct {
|
||||||
@ -49,6 +54,22 @@ func GetMaxBufferedPagesPerConnection() int {
|
|||||||
return valueFromEnv
|
return valueFromEnv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTcpChannelTimeoutMs() time.Duration {
|
||||||
|
valueFromEnv, err := strconv.Atoi(os.Getenv(TcpStreamChannelTimeoutMsEnvVarName))
|
||||||
|
if err != nil {
|
||||||
|
return TcpStreamChannelTimeoutMsDefaultValue * time.Millisecond
|
||||||
|
}
|
||||||
|
return time.Duration(valueFromEnv) * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMaxNumberOfGoroutines() int {
|
||||||
|
valueFromEnv, err := strconv.Atoi(os.Getenv(MaxNumberOfGoroutinesEnvVarName))
|
||||||
|
if err != nil {
|
||||||
|
return MaxNumberOfGoroutinesDefaultValue
|
||||||
|
}
|
||||||
|
return valueFromEnv
|
||||||
|
}
|
||||||
|
|
||||||
func GetMemoryProfilingEnabled() bool {
|
func GetMemoryProfilingEnabled() bool {
|
||||||
return os.Getenv(MemoryProfilingEnabledEnvVarName) == "1"
|
return os.Getenv(MemoryProfilingEnabledEnvVarName) == "1"
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ type AppStats struct {
|
|||||||
ReassembledTcpPayloadsCount int64 `json:"reassembledTcpPayloadsCount"`
|
ReassembledTcpPayloadsCount int64 `json:"reassembledTcpPayloadsCount"`
|
||||||
TlsConnectionsCount int64 `json:"tlsConnectionsCount"`
|
TlsConnectionsCount int64 `json:"tlsConnectionsCount"`
|
||||||
MatchedPairs int64 `json:"matchedPairs"`
|
MatchedPairs int64 `json:"matchedPairs"`
|
||||||
|
DroppedTcpStreams int64 `json:"droppedTcpStreams"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StatsTracker struct {
|
type StatsTracker struct {
|
||||||
@ -23,6 +24,7 @@ type StatsTracker struct {
|
|||||||
reassembledTcpPayloadsCountMutex sync.Mutex
|
reassembledTcpPayloadsCountMutex sync.Mutex
|
||||||
tlsConnectionsCountMutex sync.Mutex
|
tlsConnectionsCountMutex sync.Mutex
|
||||||
matchedPairsMutex sync.Mutex
|
matchedPairsMutex sync.Mutex
|
||||||
|
droppedTcpStreamsMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *StatsTracker) incMatchedPairs() {
|
func (st *StatsTracker) incMatchedPairs() {
|
||||||
@ -31,6 +33,12 @@ func (st *StatsTracker) incMatchedPairs() {
|
|||||||
st.matchedPairsMutex.Unlock()
|
st.matchedPairsMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (st *StatsTracker) incDroppedTcpStreams() {
|
||||||
|
st.droppedTcpStreamsMutex.Lock()
|
||||||
|
st.appStats.DroppedTcpStreams++
|
||||||
|
st.droppedTcpStreamsMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (st *StatsTracker) incPacketsCount() int64 {
|
func (st *StatsTracker) incPacketsCount() int64 {
|
||||||
st.packetsCountMutex.Lock()
|
st.packetsCountMutex.Lock()
|
||||||
st.appStats.PacketsCount++
|
st.appStats.PacketsCount++
|
||||||
@ -100,5 +108,10 @@ func (st *StatsTracker) dumpStats() *AppStats {
|
|||||||
st.appStats.MatchedPairs = 0
|
st.appStats.MatchedPairs = 0
|
||||||
st.matchedPairsMutex.Unlock()
|
st.matchedPairsMutex.Unlock()
|
||||||
|
|
||||||
|
st.droppedTcpStreamsMutex.Lock()
|
||||||
|
currentAppStats.DroppedTcpStreams = st.appStats.DroppedTcpStreams
|
||||||
|
st.appStats.DroppedTcpStreams = 0
|
||||||
|
st.droppedTcpStreamsMutex.Unlock()
|
||||||
|
|
||||||
return currentAppStats
|
return currentAppStats
|
||||||
}
|
}
|
||||||
|
@ -47,6 +47,7 @@ func (tid *tcpID) String() string {
|
|||||||
type tcpReader struct {
|
type tcpReader struct {
|
||||||
ident string
|
ident string
|
||||||
tcpID *api.TcpID
|
tcpID *api.TcpID
|
||||||
|
isClosed bool
|
||||||
isClient bool
|
isClient bool
|
||||||
isOutgoing bool
|
isOutgoing bool
|
||||||
msgQueue chan tcpReaderDataMsg // Channel of captured reassembled tcp payload
|
msgQueue chan tcpReaderDataMsg // Channel of captured reassembled tcp payload
|
||||||
@ -59,6 +60,7 @@ type tcpReader struct {
|
|||||||
extension *api.Extension
|
extension *api.Extension
|
||||||
emitter api.Emitter
|
emitter api.Emitter
|
||||||
counterPair *api.CounterPair
|
counterPair *api.CounterPair
|
||||||
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *tcpReader) Read(p []byte) (int, error) {
|
func (h *tcpReader) Read(p []byte) (int, error) {
|
||||||
@ -93,10 +95,19 @@ func (h *tcpReader) Read(p []byte) (int, error) {
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *tcpReader) Close() {
|
||||||
|
h.Lock()
|
||||||
|
if !h.isClosed {
|
||||||
|
h.isClosed = true
|
||||||
|
close(h.msgQueue)
|
||||||
|
}
|
||||||
|
h.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (h *tcpReader) run(wg *sync.WaitGroup) {
|
func (h *tcpReader) run(wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
b := bufio.NewReader(h)
|
b := bufio.NewReader(h)
|
||||||
err := h.extension.Dissector.Dissect(b, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.emitter)
|
err := h.extension.Dissector.Dissect(b, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
io.Copy(ioutil.Discard, b)
|
io.Copy(ioutil.Discard, b)
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers" // pulls in all layers decoders
|
"github.com/google/gopacket/layers" // pulls in all layers decoders
|
||||||
"github.com/google/gopacket/reassembly"
|
"github.com/google/gopacket/reassembly"
|
||||||
|
"github.com/up9inc/mizu/tap/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* It's a connection (bidirectional)
|
/* It's a connection (bidirectional)
|
||||||
@ -16,16 +17,19 @@ import (
|
|||||||
* In our implementation, we pass information from ReassembledSG to the tcpReader through a shared channel.
|
* In our implementation, we pass information from ReassembledSG to the tcpReader through a shared channel.
|
||||||
*/
|
*/
|
||||||
type tcpStream struct {
|
type tcpStream struct {
|
||||||
tcpstate *reassembly.TCPSimpleFSM
|
id int64
|
||||||
fsmerr bool
|
isClosed bool
|
||||||
optchecker reassembly.TCPOptionCheck
|
superIdentifier *api.SuperIdentifier
|
||||||
net, transport gopacket.Flow
|
tcpstate *reassembly.TCPSimpleFSM
|
||||||
isDNS bool
|
fsmerr bool
|
||||||
isTapTarget bool
|
optchecker reassembly.TCPOptionCheck
|
||||||
clients []tcpReader
|
net, transport gopacket.Flow
|
||||||
servers []tcpReader
|
isDNS bool
|
||||||
urls []string
|
isTapTarget bool
|
||||||
ident string
|
clients []tcpReader
|
||||||
|
servers []tcpReader
|
||||||
|
urls []string
|
||||||
|
ident string
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,12 +150,22 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
|
|||||||
statsTracker.incReassembledTcpPayloadsCount()
|
statsTracker.incReassembledTcpPayloadsCount()
|
||||||
timestamp := ac.GetCaptureInfo().Timestamp
|
timestamp := ac.GetCaptureInfo().Timestamp
|
||||||
if dir == reassembly.TCPDirClientToServer {
|
if dir == reassembly.TCPDirClientToServer {
|
||||||
for _, reader := range t.clients {
|
for i := range t.clients {
|
||||||
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
|
reader := &t.clients[i]
|
||||||
|
reader.Lock()
|
||||||
|
if !reader.isClosed {
|
||||||
|
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
|
||||||
|
}
|
||||||
|
reader.Unlock()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, reader := range t.servers {
|
for i := range t.servers {
|
||||||
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
|
reader := &t.servers[i]
|
||||||
|
reader.Lock()
|
||||||
|
if !reader.isClosed {
|
||||||
|
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
|
||||||
|
}
|
||||||
|
reader.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,14 +174,33 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
|
|||||||
|
|
||||||
func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
|
func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
|
||||||
Trace("%s: Connection closed", t.ident)
|
Trace("%s: Connection closed", t.ident)
|
||||||
if t.isTapTarget {
|
if t.isTapTarget && !t.isClosed {
|
||||||
for _, reader := range t.clients {
|
t.Close()
|
||||||
close(reader.msgQueue)
|
|
||||||
}
|
|
||||||
for _, reader := range t.servers {
|
|
||||||
close(reader.msgQueue)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// do not remove the connection to allow last ACK
|
// do not remove the connection to allow last ACK
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tcpStream) Close() {
|
||||||
|
shouldReturn := false
|
||||||
|
t.Lock()
|
||||||
|
if t.isClosed {
|
||||||
|
shouldReturn = true
|
||||||
|
} else {
|
||||||
|
t.isClosed = true
|
||||||
|
}
|
||||||
|
t.Unlock()
|
||||||
|
if shouldReturn {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
streams.Delete(t.id)
|
||||||
|
|
||||||
|
for i := range t.clients {
|
||||||
|
reader := &t.clients[i]
|
||||||
|
reader.Close()
|
||||||
|
}
|
||||||
|
for i := range t.servers {
|
||||||
|
reader := &t.servers[i]
|
||||||
|
reader.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,7 +2,9 @@ package tap
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/romana/rlog"
|
"github.com/romana/rlog"
|
||||||
"github.com/up9inc/mizu/tap/api"
|
"github.com/up9inc/mizu/tap/api"
|
||||||
@ -23,6 +25,16 @@ type tcpStreamFactory struct {
|
|||||||
Emitter api.Emitter
|
Emitter api.Emitter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tcpStreamWrapper struct {
|
||||||
|
stream *tcpStream
|
||||||
|
createdAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var streams *sync.Map = &sync.Map{} // global
|
||||||
|
var streamId int64 = 0
|
||||||
|
|
||||||
|
var maxNumberOfGoroutines int
|
||||||
|
|
||||||
func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
|
func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
|
||||||
rlog.Debugf("* NEW: %s %s", net, transport)
|
rlog.Debugf("* NEW: %s %s", net, transport)
|
||||||
fsmOptions := reassembly.TCPSimpleFSMOptions{
|
fsmOptions := reassembly.TCPSimpleFSMOptions{
|
||||||
@ -39,15 +51,23 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T
|
|||||||
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
|
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
|
||||||
isTapTarget := props.isTapTarget
|
isTapTarget := props.isTapTarget
|
||||||
stream := &tcpStream{
|
stream := &tcpStream{
|
||||||
net: net,
|
net: net,
|
||||||
transport: transport,
|
transport: transport,
|
||||||
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
|
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
|
||||||
isTapTarget: isTapTarget,
|
isTapTarget: isTapTarget,
|
||||||
tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions),
|
tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions),
|
||||||
ident: fmt.Sprintf("%s:%s", net, transport),
|
ident: fmt.Sprintf("%s:%s", net, transport),
|
||||||
optchecker: reassembly.NewTCPOptionCheck(),
|
optchecker: reassembly.NewTCPOptionCheck(),
|
||||||
|
superIdentifier: &api.SuperIdentifier{},
|
||||||
}
|
}
|
||||||
if stream.isTapTarget {
|
if stream.isTapTarget {
|
||||||
|
if runtime.NumGoroutine() > maxNumberOfGoroutines {
|
||||||
|
statsTracker.incDroppedTcpStreams()
|
||||||
|
rlog.Debugf("Dropped a TCP stream because of load. Total dropped: %d Total Goroutines: %d\n", statsTracker.appStats.DroppedTcpStreams, runtime.NumGoroutine())
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
streamId++
|
||||||
|
stream.id = streamId
|
||||||
for i, extension := range extensions {
|
for i, extension := range extensions {
|
||||||
counterPair := &api.CounterPair{
|
counterPair := &api.CounterPair{
|
||||||
Request: 0,
|
Request: 0,
|
||||||
@ -89,6 +109,12 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T
|
|||||||
emitter: factory.Emitter,
|
emitter: factory.Emitter,
|
||||||
counterPair: counterPair,
|
counterPair: counterPair,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
streams.Store(stream.id, &tcpStreamWrapper{
|
||||||
|
stream: stream,
|
||||||
|
createdAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
factory.wg.Add(2)
|
factory.wg.Add(2)
|
||||||
// Start reading from channel stream.reader.bytes
|
// Start reading from channel stream.reader.bytes
|
||||||
go stream.clients[i].run(&factory.wg)
|
go stream.clients[i].run(&factory.wg)
|
||||||
@ -119,7 +145,7 @@ func (factory *tcpStreamFactory) getStreamProps(srcIP string, srcPort string, ds
|
|||||||
}
|
}
|
||||||
return &streamProps{isTapTarget: false, isOutgoing: false}
|
return &streamProps{isTapTarget: false, isOutgoing: false}
|
||||||
} else {
|
} else {
|
||||||
rlog.Debugf("getStreamProps %s", fmt.Sprintf("+ notHost3 %s -> %s:%s", srcIP, dstIP, dstPort))
|
rlog.Debugf("getStreamProps %s", fmt.Sprintf("+ notHost3 %s:%s -> %s:%s", srcIP, srcPort, dstIP, dstPort))
|
||||||
return &streamProps{isTapTarget: true}
|
return &streamProps{isTapTarget: true}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user