mirror of
				https://github.com/kubeshark/kubeshark.git
				synced 2025-10-22 15:58:44 +00:00 
			
		
		
		
	* close gopacket conn immediately * increase last ack threshold to 3 seconds * remove empty condition * add periodic stats * protect connections from concurrent updates * implement basic throttling base on live streams count * remove assembler mutex * pr fixes * change max conns default to 500 * create connectionId type * fix linter
		
			
				
	
	
		
			255 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			255 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package tap
 | |
| 
 | |
| import (
 | |
| 	"encoding/hex"
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 	"os/signal"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/google/gopacket"
 | |
| 	"github.com/google/gopacket/layers"
 | |
| 	"github.com/google/gopacket/reassembly"
 | |
| 	"github.com/hashicorp/golang-lru/simplelru"
 | |
| 	"github.com/up9inc/mizu/logger"
 | |
| 	"github.com/up9inc/mizu/tap/api"
 | |
| 	"github.com/up9inc/mizu/tap/dbgctl"
 | |
| 	"github.com/up9inc/mizu/tap/diagnose"
 | |
| 	"github.com/up9inc/mizu/tap/source"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 	lastClosedConnectionsMaxItems = 1000
 | |
| 	packetsSeenLogThreshold       = 1000
 | |
| 	lastAckThreshold              = time.Duration(3) * time.Second
 | |
| )
 | |
| 
 | |
| type connectionId string
 | |
| 
 | |
| func NewConnectionId(c string) connectionId {
 | |
| 	return connectionId(c)
 | |
| }
 | |
| 
 | |
| type AssemblerStats struct {
 | |
| 	flushedConnections int
 | |
| 	closedConnections  int
 | |
| }
 | |
| 
 | |
| type tcpAssembler struct {
 | |
| 	*reassembly.Assembler
 | |
| 	streamPool             *reassembly.StreamPool
 | |
| 	streamFactory          *tcpStreamFactory
 | |
| 	ignoredPorts           []uint16
 | |
| 	lastClosedConnections  *simplelru.LRU // Actual type is map[string]int64 which is "connId -> lastSeen"
 | |
| 	liveConnections        map[connectionId]bool
 | |
| 	maxLiveStreams         int
 | |
| 	staleConnectionTimeout time.Duration
 | |
| 	stats                  AssemblerStats
 | |
| }
 | |
| 
 | |
| // Context
 | |
| // The assembler context
 | |
| type context struct {
 | |
| 	CaptureInfo gopacket.CaptureInfo
 | |
| 	Origin      api.Capture
 | |
| }
 | |
| 
 | |
| func (c *context) GetCaptureInfo() gopacket.CaptureInfo {
 | |
| 	return c.CaptureInfo
 | |
| }
 | |
| 
 | |
| func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap, opts *TapOpts) (*tcpAssembler, error) {
 | |
| 	var emitter api.Emitter = &api.Emitting{
 | |
| 		AppStats:      &diagnose.AppStats,
 | |
| 		OutputChannel: outputItems,
 | |
| 	}
 | |
| 
 | |
| 	lastClosedConnections, err := simplelru.NewLRU(lastClosedConnectionsMaxItems, func(key interface{}, value interface{}) {})
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	a := &tcpAssembler{
 | |
| 		ignoredPorts:           opts.IgnoredPorts,
 | |
| 		lastClosedConnections:  lastClosedConnections,
 | |
| 		liveConnections:        make(map[connectionId]bool),
 | |
| 		maxLiveStreams:         opts.maxLiveStreams,
 | |
| 		staleConnectionTimeout: opts.staleConnectionTimeout,
 | |
| 		stats:                  AssemblerStats{},
 | |
| 	}
 | |
| 
 | |
| 	a.streamFactory = NewTcpStreamFactory(emitter, streamsMap, opts, a)
 | |
| 	a.streamPool = reassembly.NewStreamPool(a.streamFactory)
 | |
| 	a.Assembler = reassembly.NewAssembler(a.streamPool)
 | |
| 
 | |
| 	maxBufferedPagesTotal := GetMaxBufferedPagesPerConnection()
 | |
| 	maxBufferedPagesPerConnection := GetMaxBufferedPagesTotal()
 | |
| 	logger.Log.Infof("Assembler options: maxBufferedPagesTotal=%d, maxBufferedPagesPerConnection=%d, opts=%v",
 | |
| 		maxBufferedPagesTotal, maxBufferedPagesPerConnection, opts)
 | |
| 	a.Assembler.AssemblerOptions.MaxBufferedPagesTotal = maxBufferedPagesTotal
 | |
| 	a.Assembler.AssemblerOptions.MaxBufferedPagesPerConnection = maxBufferedPagesPerConnection
 | |
| 
 | |
| 	return a, nil
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.TcpPacketInfo) {
 | |
| 	signalChan := make(chan os.Signal, 1)
 | |
| 	signal.Notify(signalChan, os.Interrupt)
 | |
| 	ticker := time.NewTicker(a.staleConnectionTimeout)
 | |
| 
 | |
| out:
 | |
| 	for {
 | |
| 		select {
 | |
| 		case packetInfo, ok := <-packets:
 | |
| 			if !ok {
 | |
| 				break out
 | |
| 			}
 | |
| 			if a.processPacket(packetInfo, dumpPacket) {
 | |
| 				break out
 | |
| 			}
 | |
| 		case <-signalChan:
 | |
| 			logger.Log.Infof("Caught SIGINT: aborting")
 | |
| 			break out
 | |
| 		case <-ticker.C:
 | |
| 			a.periodicClean()
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	closed := a.FlushAll()
 | |
| 	logger.Log.Debugf("Final flush: %d closed", closed)
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) processPacket(packetInfo source.TcpPacketInfo, dumpPacket bool) bool {
 | |
| 	packetsCount := diagnose.AppStats.IncPacketsCount()
 | |
| 
 | |
| 	if packetsCount%packetsSeenLogThreshold == 0 {
 | |
| 		logger.Log.Debugf("Packets seen: #%d", packetsCount)
 | |
| 	}
 | |
| 
 | |
| 	packet := packetInfo.Packet
 | |
| 	data := packet.Data()
 | |
| 	diagnose.AppStats.UpdateProcessedBytes(uint64(len(data)))
 | |
| 	if dumpPacket {
 | |
| 		logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data))
 | |
| 	}
 | |
| 
 | |
| 	tcp := packet.Layer(layers.LayerTypeTCP)
 | |
| 	if tcp != nil {
 | |
| 		a.processTcpPacket(packetInfo.Source.Origin, packet, tcp.(*layers.TCP))
 | |
| 	}
 | |
| 
 | |
| 	done := *maxcount > 0 && int64(diagnose.AppStats.PacketsCount) >= *maxcount
 | |
| 	if done {
 | |
| 		errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary()
 | |
| 		logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)",
 | |
| 			diagnose.AppStats.PacketsCount,
 | |
| 			diagnose.AppStats.ProcessedBytes,
 | |
| 			time.Since(diagnose.AppStats.StartTime),
 | |
| 			diagnose.TapErrors.ErrorsCount,
 | |
| 			errorMapLen)
 | |
| 	}
 | |
| 	return done
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) processTcpPacket(origin api.Capture, packet gopacket.Packet, tcp *layers.TCP) {
 | |
| 	diagnose.AppStats.IncTcpPacketsCount()
 | |
| 	if a.shouldIgnorePort(uint16(tcp.DstPort)) || a.shouldIgnorePort(uint16(tcp.SrcPort)) {
 | |
| 		diagnose.AppStats.IncIgnoredPacketsCount()
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	id := getConnectionId(packet.NetworkLayer().NetworkFlow().Src().String(),
 | |
| 		packet.TransportLayer().TransportFlow().Src().String(),
 | |
| 		packet.NetworkLayer().NetworkFlow().Dst().String(),
 | |
| 		packet.TransportLayer().TransportFlow().Dst().String())
 | |
| 
 | |
| 	if a.isRecentlyClosed(id) {
 | |
| 		diagnose.AppStats.IncIgnoredLastAckCount()
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if a.shouldThrottle(id) {
 | |
| 		diagnose.AppStats.IncThrottledPackets()
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	c := context{
 | |
| 		CaptureInfo: packet.Metadata().CaptureInfo,
 | |
| 		Origin:      origin,
 | |
| 	}
 | |
| 	diagnose.InternalStats.Totalsz += len(tcp.Payload)
 | |
| 	if !dbgctl.MizuTapperDisableTcpReassembly {
 | |
| 		a.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) tcpStreamCreated(stream *tcpStream) {
 | |
| 	a.liveConnections[stream.connectionId] = true
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) tcpStreamClosed(stream *tcpStream) {
 | |
| 	a.lastClosedConnections.Add(stream.connectionId, time.Now().UnixMilli())
 | |
| 	delete(a.liveConnections, stream.connectionId)
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) isRecentlyClosed(c connectionId) bool {
 | |
| 	if closedTimeMillis, ok := a.lastClosedConnections.Get(c); ok {
 | |
| 		timeSinceClosed := time.Since(time.UnixMilli(closedTimeMillis.(int64)))
 | |
| 		if timeSinceClosed < lastAckThreshold {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) shouldThrottle(c connectionId) bool {
 | |
| 	if _, ok := a.liveConnections[c]; ok {
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	return len(a.liveConnections) > a.maxLiveStreams
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) dumpStreamPool() {
 | |
| 	a.streamPool.Dump()
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) waitAndDump() {
 | |
| 	a.streamFactory.WaitGoRoutines()
 | |
| 	logger.Log.Debugf("%s", a.Dump())
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) shouldIgnorePort(port uint16) bool {
 | |
| 	for _, p := range a.ignoredPorts {
 | |
| 		if port == p {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) periodicClean() {
 | |
| 	flushed, closed := a.FlushCloseOlderThan(time.Now().Add(-a.staleConnectionTimeout))
 | |
| 	stats := a.stats
 | |
| 	stats.closedConnections += closed
 | |
| 	stats.flushedConnections += flushed
 | |
| }
 | |
| 
 | |
| func (a *tcpAssembler) DumpStats() AssemblerStats {
 | |
| 	result := a.stats
 | |
| 	a.stats = AssemblerStats{}
 | |
| 	return result
 | |
| }
 | |
| 
 | |
| func getConnectionId(saddr string, sport string, daddr string, dport string) connectionId {
 | |
| 	s := fmt.Sprintf("%s:%s", saddr, sport)
 | |
| 	d := fmt.Sprintf("%s:%s", daddr, dport)
 | |
| 	if s > d {
 | |
| 		return NewConnectionId(fmt.Sprintf("%s#%s", s, d))
 | |
| 	} else {
 | |
| 		return NewConnectionId(fmt.Sprintf("%s#%s", d, s))
 | |
| 	}
 | |
| }
 |