Close gopacket connection immediately, basic throttling and assembler mutex removal

* close gopacket conn immediately

* increase last ack threshold to 3 seconds

* remove empty condition

* add periodic stats

* protect connections from concurrent updates

* implement basic throttling base on live streams count

* remove assembler mutex

* pr fixes

* change max conns default to 500

* create connectionId type

* fix linter
This commit is contained in:
David Levanon
2022-06-19 16:47:03 +03:00
committed by GitHub
parent 167bbe3741
commit 99cb0b4f44
11 changed files with 273 additions and 124 deletions

View File

@@ -57,7 +57,7 @@ def extract_samples(f: typing.IO) -> typing.Tuple[pd.Series, pd.Series, pd.Serie
append_sample('"matchedPairs"', line, matched_samples) append_sample('"matchedPairs"', line, matched_samples)
append_sample('"liveTcpStreams"', line, live_samples) append_sample('"liveTcpStreams"', line, live_samples)
append_sample('"processedBytes"', line, processed_samples) append_sample('"processedBytes"', line, processed_samples)
append_sample('mem', line, heap_samples) append_sample('heap-alloc', line, heap_samples)
append_sample('goroutines', line, goroutines_samples) append_sample('goroutines', line, goroutines_samples)
cpu_samples = pd.Series(cpu_samples) cpu_samples = pd.Series(cpu_samples)

View File

@@ -16,6 +16,8 @@ type AppStats struct {
MatchedPairs uint64 `json:"matchedPairs"` MatchedPairs uint64 `json:"matchedPairs"`
DroppedTcpStreams uint64 `json:"droppedTcpStreams"` DroppedTcpStreams uint64 `json:"droppedTcpStreams"`
LiveTcpStreams uint64 `json:"liveTcpStreams"` LiveTcpStreams uint64 `json:"liveTcpStreams"`
IgnoredLastAckCount uint64 `json:"ignoredLastAckCount"`
ThrottledPackets uint64 `json:"throttledPackets"`
} }
func (as *AppStats) IncMatchedPairs() { func (as *AppStats) IncMatchedPairs() {
@@ -39,6 +41,14 @@ func (as *AppStats) IncIgnoredPacketsCount() {
atomic.AddUint64(&as.IgnoredPacketsCount, 1) atomic.AddUint64(&as.IgnoredPacketsCount, 1)
} }
func (as *AppStats) IncIgnoredLastAckCount() {
atomic.AddUint64(&as.IgnoredLastAckCount, 1)
}
func (as *AppStats) IncThrottledPackets() {
atomic.AddUint64(&as.ThrottledPackets, 1)
}
func (as *AppStats) IncReassembledTcpPayloadsCount() { func (as *AppStats) IncReassembledTcpPayloadsCount() {
atomic.AddUint64(&as.ReassembledTcpPayloadsCount, 1) atomic.AddUint64(&as.ReassembledTcpPayloadsCount, 1)
} }
@@ -74,6 +84,8 @@ func (as *AppStats) DumpStats() *AppStats {
currentAppStats.TlsConnectionsCount = resetUint64(&as.TlsConnectionsCount) currentAppStats.TlsConnectionsCount = resetUint64(&as.TlsConnectionsCount)
currentAppStats.MatchedPairs = resetUint64(&as.MatchedPairs) currentAppStats.MatchedPairs = resetUint64(&as.MatchedPairs)
currentAppStats.DroppedTcpStreams = resetUint64(&as.DroppedTcpStreams) currentAppStats.DroppedTcpStreams = resetUint64(&as.DroppedTcpStreams)
currentAppStats.IgnoredLastAckCount = resetUint64(&as.IgnoredLastAckCount)
currentAppStats.ThrottledPackets = resetUint64(&as.ThrottledPackets)
currentAppStats.LiveTcpStreams = as.LiveTcpStreams currentAppStats.LiveTcpStreams = as.LiveTcpStreams
return currentAppStats return currentAppStats

View File

@@ -10,14 +10,11 @@ import (
) )
type CleanerStats struct { type CleanerStats struct {
flushed int
closed int
deleted int deleted int
} }
type Cleaner struct { type Cleaner struct {
assembler *reassembly.Assembler assembler *reassembly.Assembler
assemblerMutex *sync.Mutex
cleanPeriod time.Duration cleanPeriod time.Duration
connectionTimeout time.Duration connectionTimeout time.Duration
stats CleanerStats stats CleanerStats
@@ -28,11 +25,6 @@ type Cleaner struct {
func (cl *Cleaner) clean() { func (cl *Cleaner) clean() {
startCleanTime := time.Now() startCleanTime := time.Now()
cl.assemblerMutex.Lock()
logger.Log.Debugf("Assembler Stats before cleaning %s", cl.assembler.Dump())
flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout))
cl.assemblerMutex.Unlock()
cl.streamsMap.Range(func(k, v interface{}) bool { cl.streamsMap.Range(func(k, v interface{}) bool {
reqResMatchers := v.(api.TcpStream).GetReqResMatchers() reqResMatchers := v.(api.TcpStream).GetReqResMatchers()
for _, reqResMatcher := range reqResMatchers { for _, reqResMatcher := range reqResMatchers {
@@ -47,8 +39,6 @@ func (cl *Cleaner) clean() {
cl.statsMutex.Lock() cl.statsMutex.Lock()
logger.Log.Debugf("Assembler Stats after cleaning %s", cl.assembler.Dump()) logger.Log.Debugf("Assembler Stats after cleaning %s", cl.assembler.Dump())
cl.stats.flushed += flushed
cl.stats.closed += closed
cl.statsMutex.Unlock() cl.statsMutex.Unlock()
} }
@@ -67,17 +57,12 @@ func (cl *Cleaner) dumpStats() CleanerStats {
cl.statsMutex.Lock() cl.statsMutex.Lock()
stats := CleanerStats{ stats := CleanerStats{
flushed: cl.stats.flushed,
closed: cl.stats.closed,
deleted: cl.stats.deleted, deleted: cl.stats.deleted,
} }
cl.stats.flushed = 0
cl.stats.closed = 0
cl.stats.deleted = 0 cl.stats.deleted = 0
cl.statsMutex.Unlock() cl.statsMutex.Unlock()
return stats return stats
} }

View File

@@ -45,6 +45,7 @@ var quiet = flag.Bool("quiet", false, "Be quiet regarding errors")
var hexdumppkt = flag.Bool("dumppkt", false, "Dump packet as hex") var hexdumppkt = flag.Bool("dumppkt", false, "Dump packet as hex")
var procfs = flag.String("procfs", "/proc", "The procfs directory, used when mapping host volumes into a container") var procfs = flag.String("procfs", "/proc", "The procfs directory, used when mapping host volumes into a container")
var ignoredPorts = flag.String("ignore-ports", "", "A comma separated list of ports to ignore") var ignoredPorts = flag.String("ignore-ports", "", "A comma separated list of ports to ignore")
var maxLiveStreams = flag.Int("max-live-streams", 500, "Maximum live streams to handle concurrently")
// capture // capture
var iface = flag.String("i", "en0", "Interface to read packets from") var iface = flag.String("i", "en0", "Interface to read packets from")
@@ -59,8 +60,10 @@ var tls = flag.Bool("tls", false, "Enable TLS tapper")
var memprofile = flag.String("memprofile", "", "Write memory profile") var memprofile = flag.String("memprofile", "", "Write memory profile")
type TapOpts struct { type TapOpts struct {
HostMode bool HostMode bool
IgnoredPorts []uint16 IgnoredPorts []uint16
maxLiveStreams int
staleConnectionTimeout time.Duration
} }
var extensions []*api.Extension // global var extensions []*api.Extension // global
@@ -89,7 +92,13 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem,
diagnose.StartMemoryProfiler(os.Getenv(MemoryProfilingDumpPath), os.Getenv(MemoryProfilingTimeIntervalSeconds)) diagnose.StartMemoryProfiler(os.Getenv(MemoryProfilingDumpPath), os.Getenv(MemoryProfilingTimeIntervalSeconds))
} }
assembler := initializePassiveTapper(opts, outputItems, streamsMap) assembler, err := initializePassiveTapper(opts, outputItems, streamsMap)
if err != nil {
logger.Log.Errorf("Error initializing tapper %w", err)
return
}
go startPassiveTapper(streamsMap, assembler) go startPassiveTapper(streamsMap, assembler)
} }
@@ -124,7 +133,7 @@ func printNewTapTargets(success bool) {
} }
} }
func printPeriodicStats(cleaner *Cleaner) { func printPeriodicStats(cleaner *Cleaner, assembler *tcpAssembler) {
statsPeriod := time.Second * time.Duration(*statsevery) statsPeriod := time.Second * time.Duration(*statsevery)
ticker := time.NewTicker(statsPeriod) ticker := time.NewTicker(statsPeriod)
@@ -162,8 +171,10 @@ func printPeriodicStats(cleaner *Cleaner) {
} }
} }
logger.Log.Infof( logger.Log.Infof(
"mem: %d, goroutines: %d, cpu: %f, cores: %d/%d, rss: %f", "heap-alloc: %d, heap-idle: %d, heap-objects: %d, goroutines: %d, cpu: %f, cores: %d/%d, rss: %f",
memStats.HeapAlloc, memStats.HeapAlloc,
memStats.HeapIdle,
memStats.HeapObjects,
runtime.NumGoroutine(), runtime.NumGoroutine(),
sysInfo.CPU, sysInfo.CPU,
logicalCoreCount, logicalCoreCount,
@@ -172,15 +183,19 @@ func printPeriodicStats(cleaner *Cleaner) {
// Since the last print // Since the last print
cleanStats := cleaner.dumpStats() cleanStats := cleaner.dumpStats()
assemblerStats := assembler.DumpStats()
logger.Log.Infof( logger.Log.Infof(
"cleaner - flushed connections: %d, closed connections: %d, deleted messages: %d", "cleaner - flushed connections: %d, closed connections: %d, deleted messages: %d",
cleanStats.flushed, assemblerStats.flushedConnections,
cleanStats.closed, assemblerStats.closedConnections,
cleanStats.deleted, cleanStats.deleted,
) )
currentAppStats := diagnose.AppStats.DumpStats() currentAppStats := diagnose.AppStats.DumpStats()
appStatsJSON, _ := json.Marshal(currentAppStats) appStatsJSON, _ := json.Marshal(currentAppStats)
logger.Log.Infof("app stats - %v", string(appStatsJSON)) logger.Log.Infof("app stats - %v", string(appStatsJSON))
// At the moment
logger.Log.Infof("assembler-stats: %s, packet-source-stats: %s", assembler.Dump(), packetSourceManager.Stats())
} }
} }
@@ -208,7 +223,7 @@ func initializePacketSources() error {
return err return err
} }
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap) *tcpAssembler { func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap) (*tcpAssembler, error) {
diagnose.InitializeErrorsMap(*debug, *verbose, *quiet) diagnose.InitializeErrorsMap(*debug, *verbose, *quiet)
diagnose.InitializeTapperInternalStats() diagnose.InitializeTapperInternalStats()
@@ -219,10 +234,10 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI
} }
opts.IgnoredPorts = append(opts.IgnoredPorts, buildIgnoredPortsList(*ignoredPorts)...) opts.IgnoredPorts = append(opts.IgnoredPorts, buildIgnoredPortsList(*ignoredPorts)...)
opts.maxLiveStreams = *maxLiveStreams
opts.staleConnectionTimeout = time.Duration(*staleTimeoutSeconds) * time.Second
assembler := NewTcpAssembler(outputItems, streamsMap, opts) return NewTcpAssembler(outputItems, streamsMap, opts)
return assembler
} }
func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) { func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) {
@@ -233,14 +248,13 @@ func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) {
staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds) staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds)
cleaner := Cleaner{ cleaner := Cleaner{
assembler: assembler.Assembler, assembler: assembler.Assembler,
assemblerMutex: &assembler.assemblerMutex,
cleanPeriod: cleanPeriod, cleanPeriod: cleanPeriod,
connectionTimeout: staleConnectionTimeout, connectionTimeout: staleConnectionTimeout,
streamsMap: streamsMap, streamsMap: streamsMap,
} }
cleaner.start() cleaner.start()
go printPeriodicStats(&cleaner) go printPeriodicStats(&cleaner, assembler)
assembler.processPackets(*hexdumppkt, mainPacketInputChan) assembler.processPackets(*hexdumppkt, mainPacketInputChan)

View File

@@ -160,3 +160,19 @@ func (m *PacketSourceManager) Close() {
src.close() src.close()
} }
} }
func (m *PacketSourceManager) Stats() string {
result := ""
for _, source := range m.sources {
stats, err := source.Stats()
if err != nil {
result = result + fmt.Sprintf("[%s: err:%s]", source.String(), err)
} else {
result = result + fmt.Sprintf("[%s: rec: %d dropped: %d]", source.String(), stats.PacketsReceived, stats.PacketsDropped)
}
}
return result
}

View File

@@ -116,6 +116,10 @@ func (source *tcpPacketSource) close() {
} }
} }
func (source *tcpPacketSource) Stats() (stat *pcap.Stats, err error) {
return source.handle.Stats()
}
func (source *tcpPacketSource) readPackets(ipdefrag bool, packets chan<- TcpPacketInfo) { func (source *tcpPacketSource) readPackets(ipdefrag bool, packets chan<- TcpPacketInfo) {
if dbgctl.MizuTapperDisablePcap { if dbgctl.MizuTapperDisablePcap {
return return

View File

@@ -2,14 +2,15 @@ package tap
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"os" "os"
"os/signal" "os/signal"
"sync"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly" "github.com/google/gopacket/reassembly"
"github.com/hashicorp/golang-lru/simplelru"
"github.com/up9inc/mizu/logger" "github.com/up9inc/mizu/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/dbgctl" "github.com/up9inc/mizu/tap/dbgctl"
@@ -17,14 +18,33 @@ import (
"github.com/up9inc/mizu/tap/source" "github.com/up9inc/mizu/tap/source"
) )
const PACKETS_SEEN_LOG_THRESHOLD = 1000 const (
lastClosedConnectionsMaxItems = 1000
packetsSeenLogThreshold = 1000
lastAckThreshold = time.Duration(3) * time.Second
)
type connectionId string
func NewConnectionId(c string) connectionId {
return connectionId(c)
}
type AssemblerStats struct {
flushedConnections int
closedConnections int
}
type tcpAssembler struct { type tcpAssembler struct {
*reassembly.Assembler *reassembly.Assembler
streamPool *reassembly.StreamPool streamPool *reassembly.StreamPool
streamFactory *tcpStreamFactory streamFactory *tcpStreamFactory
assemblerMutex sync.Mutex ignoredPorts []uint16
ignoredPorts []uint16 lastClosedConnections *simplelru.LRU // Actual type is map[string]int64 which is "connId -> lastSeen"
liveConnections map[connectionId]bool
maxLiveStreams int
staleConnectionTimeout time.Duration
stats AssemblerStats
} }
// Context // Context
@@ -38,108 +58,166 @@ func (c *context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo return c.CaptureInfo
} }
func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpAssembler { func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap, opts *TapOpts) (*tcpAssembler, error) {
var emitter api.Emitter = &api.Emitting{ var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats, AppStats: &diagnose.AppStats,
OutputChannel: outputItems, OutputChannel: outputItems,
} }
streamFactory := NewTcpStreamFactory(emitter, streamsMap, opts) lastClosedConnections, err := simplelru.NewLRU(lastClosedConnectionsMaxItems, func(key interface{}, value interface{}) {})
streamPool := reassembly.NewStreamPool(streamFactory)
assembler := reassembly.NewAssembler(streamPool) if err != nil {
return nil, err
}
a := &tcpAssembler{
ignoredPorts: opts.IgnoredPorts,
lastClosedConnections: lastClosedConnections,
liveConnections: make(map[connectionId]bool),
maxLiveStreams: opts.maxLiveStreams,
staleConnectionTimeout: opts.staleConnectionTimeout,
stats: AssemblerStats{},
}
a.streamFactory = NewTcpStreamFactory(emitter, streamsMap, opts, a)
a.streamPool = reassembly.NewStreamPool(a.streamFactory)
a.Assembler = reassembly.NewAssembler(a.streamPool)
maxBufferedPagesTotal := GetMaxBufferedPagesPerConnection() maxBufferedPagesTotal := GetMaxBufferedPagesPerConnection()
maxBufferedPagesPerConnection := GetMaxBufferedPagesTotal() maxBufferedPagesPerConnection := GetMaxBufferedPagesTotal()
logger.Log.Infof("Assembler options: maxBufferedPagesTotal=%d, maxBufferedPagesPerConnection=%d, opts=%v", logger.Log.Infof("Assembler options: maxBufferedPagesTotal=%d, maxBufferedPagesPerConnection=%d, opts=%v",
maxBufferedPagesTotal, maxBufferedPagesPerConnection, opts) maxBufferedPagesTotal, maxBufferedPagesPerConnection, opts)
assembler.AssemblerOptions.MaxBufferedPagesTotal = maxBufferedPagesTotal a.Assembler.AssemblerOptions.MaxBufferedPagesTotal = maxBufferedPagesTotal
assembler.AssemblerOptions.MaxBufferedPagesPerConnection = maxBufferedPagesPerConnection a.Assembler.AssemblerOptions.MaxBufferedPagesPerConnection = maxBufferedPagesPerConnection
return &tcpAssembler{ return a, nil
Assembler: assembler,
streamPool: streamPool,
streamFactory: streamFactory,
ignoredPorts: opts.IgnoredPorts,
}
} }
func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.TcpPacketInfo) { func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.TcpPacketInfo) {
signalChan := make(chan os.Signal, 1) signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt) signal.Notify(signalChan, os.Interrupt)
ticker := time.NewTicker(a.staleConnectionTimeout)
for packetInfo := range packets { out:
packetsCount := diagnose.AppStats.IncPacketsCount() for {
if packetsCount%PACKETS_SEEN_LOG_THRESHOLD == 0 {
logger.Log.Debugf("Packets seen: #%d", packetsCount)
}
packet := packetInfo.Packet
data := packet.Data()
diagnose.AppStats.UpdateProcessedBytes(uint64(len(data)))
if dumpPacket {
logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data))
}
tcp := packet.Layer(layers.LayerTypeTCP)
if tcp != nil {
diagnose.AppStats.IncTcpPacketsCount()
tcp := tcp.(*layers.TCP)
if a.shouldIgnorePort(uint16(tcp.DstPort)) {
diagnose.AppStats.IncIgnoredPacketsCount()
} else {
c := context{
CaptureInfo: packet.Metadata().CaptureInfo,
Origin: packetInfo.Source.Origin,
}
diagnose.InternalStats.Totalsz += len(tcp.Payload)
if !dbgctl.MizuTapperDisableTcpReassembly {
a.assemblerMutex.Lock()
a.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c)
a.assemblerMutex.Unlock()
}
}
}
done := *maxcount > 0 && int64(diagnose.AppStats.PacketsCount) >= *maxcount
if done {
errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)",
diagnose.AppStats.PacketsCount,
diagnose.AppStats.ProcessedBytes,
time.Since(diagnose.AppStats.StartTime),
diagnose.TapErrors.ErrorsCount,
errorMapLen)
}
select { select {
case packetInfo, ok := <-packets:
if !ok {
break out
}
if a.processPacket(packetInfo, dumpPacket) {
break out
}
case <-signalChan: case <-signalChan:
logger.Log.Infof("Caught SIGINT: aborting") logger.Log.Infof("Caught SIGINT: aborting")
done = true break out
default: case <-ticker.C:
// NOP: continue a.periodicClean()
}
if done {
break
} }
} }
a.assemblerMutex.Lock()
closed := a.FlushAll() closed := a.FlushAll()
a.assemblerMutex.Unlock()
logger.Log.Debugf("Final flush: %d closed", closed) logger.Log.Debugf("Final flush: %d closed", closed)
} }
func (a *tcpAssembler) processPacket(packetInfo source.TcpPacketInfo, dumpPacket bool) bool {
packetsCount := diagnose.AppStats.IncPacketsCount()
if packetsCount%packetsSeenLogThreshold == 0 {
logger.Log.Debugf("Packets seen: #%d", packetsCount)
}
packet := packetInfo.Packet
data := packet.Data()
diagnose.AppStats.UpdateProcessedBytes(uint64(len(data)))
if dumpPacket {
logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data))
}
tcp := packet.Layer(layers.LayerTypeTCP)
if tcp != nil {
a.processTcpPacket(packetInfo.Source.Origin, packet, tcp.(*layers.TCP))
}
done := *maxcount > 0 && int64(diagnose.AppStats.PacketsCount) >= *maxcount
if done {
errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)",
diagnose.AppStats.PacketsCount,
diagnose.AppStats.ProcessedBytes,
time.Since(diagnose.AppStats.StartTime),
diagnose.TapErrors.ErrorsCount,
errorMapLen)
}
return done
}
func (a *tcpAssembler) processTcpPacket(origin api.Capture, packet gopacket.Packet, tcp *layers.TCP) {
diagnose.AppStats.IncTcpPacketsCount()
if a.shouldIgnorePort(uint16(tcp.DstPort)) || a.shouldIgnorePort(uint16(tcp.SrcPort)) {
diagnose.AppStats.IncIgnoredPacketsCount()
return
}
id := getConnectionId(packet.NetworkLayer().NetworkFlow().Src().String(),
packet.TransportLayer().TransportFlow().Src().String(),
packet.NetworkLayer().NetworkFlow().Dst().String(),
packet.TransportLayer().TransportFlow().Dst().String())
if a.isRecentlyClosed(id) {
diagnose.AppStats.IncIgnoredLastAckCount()
return
}
if a.shouldThrottle(id) {
diagnose.AppStats.IncThrottledPackets()
return
}
c := context{
CaptureInfo: packet.Metadata().CaptureInfo,
Origin: origin,
}
diagnose.InternalStats.Totalsz += len(tcp.Payload)
if !dbgctl.MizuTapperDisableTcpReassembly {
a.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c)
}
}
func (a *tcpAssembler) tcpStreamCreated(stream *tcpStream) {
a.liveConnections[stream.connectionId] = true
}
func (a *tcpAssembler) tcpStreamClosed(stream *tcpStream) {
a.lastClosedConnections.Add(stream.connectionId, time.Now().UnixMilli())
delete(a.liveConnections, stream.connectionId)
}
func (a *tcpAssembler) isRecentlyClosed(c connectionId) bool {
if closedTimeMillis, ok := a.lastClosedConnections.Get(c); ok {
timeSinceClosed := time.Since(time.UnixMilli(closedTimeMillis.(int64)))
if timeSinceClosed < lastAckThreshold {
return true
}
}
return false
}
func (a *tcpAssembler) shouldThrottle(c connectionId) bool {
if _, ok := a.liveConnections[c]; ok {
return false
}
return len(a.liveConnections) > a.maxLiveStreams
}
func (a *tcpAssembler) dumpStreamPool() { func (a *tcpAssembler) dumpStreamPool() {
a.streamPool.Dump() a.streamPool.Dump()
} }
func (a *tcpAssembler) waitAndDump() { func (a *tcpAssembler) waitAndDump() {
a.streamFactory.WaitGoRoutines() a.streamFactory.WaitGoRoutines()
a.assemblerMutex.Lock()
logger.Log.Debugf("%s", a.Dump()) logger.Log.Debugf("%s", a.Dump())
a.assemblerMutex.Unlock()
} }
func (a *tcpAssembler) shouldIgnorePort(port uint16) bool { func (a *tcpAssembler) shouldIgnorePort(port uint16) bool {
@@ -151,3 +229,26 @@ func (a *tcpAssembler) shouldIgnorePort(port uint16) bool {
return false return false
} }
func (a *tcpAssembler) periodicClean() {
flushed, closed := a.FlushCloseOlderThan(time.Now().Add(-a.staleConnectionTimeout))
stats := a.stats
stats.closedConnections += closed
stats.flushedConnections += flushed
}
func (a *tcpAssembler) DumpStats() AssemblerStats {
result := a.stats
a.stats = AssemblerStats{}
return result
}
func getConnectionId(saddr string, sport string, daddr string, dport string) connectionId {
s := fmt.Sprintf("%s:%s", saddr, sport)
d := fmt.Sprintf("%s:%s", daddr, dport)
if s > d {
return NewConnectionId(fmt.Sprintf("%s#%s", s, d))
} else {
return NewConnectionId(fmt.Sprintf("%s#%s", d, s))
}
}

View File

@@ -151,6 +151,6 @@ func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext)
if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() { if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() {
t.tcpStream.close() t.tcpStream.close()
} }
// do not remove the connection to allow last ACK
return false return true
} }

View File

@@ -8,6 +8,11 @@ import (
"github.com/up9inc/mizu/tap/dbgctl" "github.com/up9inc/mizu/tap/dbgctl"
) )
type tcpStreamCallbacks interface {
tcpStreamCreated(stream *tcpStream)
tcpStreamClosed(stream *tcpStream)
}
/* It's a connection (bidirectional) /* It's a connection (bidirectional)
* Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete)
* ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete)
@@ -25,16 +30,25 @@ type tcpStream struct {
reqResMatchers []api.RequestResponseMatcher reqResMatchers []api.RequestResponseMatcher
createdAt time.Time createdAt time.Time
streamsMap api.TcpStreamMap streamsMap api.TcpStreamMap
connectionId connectionId
callbacks tcpStreamCallbacks
sync.Mutex sync.Mutex
} }
func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream { func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture,
return &tcpStream{ connectionId connectionId, callbacks tcpStreamCallbacks) *tcpStream {
isTapTarget: isTapTarget, t := &tcpStream{
streamsMap: streamsMap, isTapTarget: isTapTarget,
origin: capture, streamsMap: streamsMap,
createdAt: time.Now(), origin: capture,
createdAt: time.Now(),
connectionId: connectionId,
callbacks: callbacks,
} }
t.callbacks.tcpStreamCreated(t)
return t
} }
func (t *tcpStream) getId() int64 { func (t *tcpStream) getId() int64 {
@@ -56,9 +70,9 @@ func (t *tcpStream) close() {
t.isClosed = true t.isClosed = true
t.streamsMap.Delete(t.id) t.streamsMap.Delete(t.id)
t.client.close() t.client.close()
t.server.close() t.server.close()
t.callbacks.tcpStreamClosed(t)
} }
func (t *tcpStream) addCounterPair(counterPair *api.CounterPair) { func (t *tcpStream) addCounterPair(counterPair *api.CounterPair) {

View File

@@ -19,14 +19,15 @@ import (
* Generates a new tcp stream for each new tcp connection. Closes the stream when the connection closes. * Generates a new tcp stream for each new tcp connection. Closes the stream when the connection closes.
*/ */
type tcpStreamFactory struct { type tcpStreamFactory struct {
wg sync.WaitGroup wg sync.WaitGroup
emitter api.Emitter emitter api.Emitter
streamsMap api.TcpStreamMap streamsMap api.TcpStreamMap
ownIps []string ownIps []string
opts *TapOpts opts *TapOpts
streamsCallbacks tcpStreamCallbacks
} }
func NewTcpStreamFactory(emitter api.Emitter, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory { func NewTcpStreamFactory(emitter api.Emitter, streamsMap api.TcpStreamMap, opts *TapOpts, streamsCallbacks tcpStreamCallbacks) *tcpStreamFactory {
var ownIps []string var ownIps []string
if localhostIPs, err := getLocalhostIPs(); err != nil { if localhostIPs, err := getLocalhostIPs(); err != nil {
@@ -39,10 +40,11 @@ func NewTcpStreamFactory(emitter api.Emitter, streamsMap api.TcpStreamMap, opts
} }
return &tcpStreamFactory{ return &tcpStreamFactory{
emitter: emitter, emitter: emitter,
streamsMap: streamsMap, streamsMap: streamsMap,
ownIps: ownIps, ownIps: ownIps,
opts: opts, opts: opts,
streamsCallbacks: streamsCallbacks,
} }
} }
@@ -57,7 +59,8 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort) props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
isTapTarget := props.isTapTarget isTapTarget := props.isTapTarget
stream := NewTcpStream(isTapTarget, factory.streamsMap, getPacketOrigin(ac)) connectionId := getConnectionId(srcIp, srcPort, dstIp, dstPort)
stream := NewTcpStream(isTapTarget, factory.streamsMap, getPacketOrigin(ac), connectionId, factory.streamsCallbacks)
reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream) reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream)
if stream.GetIsTapTarget() { if stream.GetIsTapTarget() {
stream.setId(factory.streamsMap.NextId()) stream.setId(factory.streamsMap.NextId())

View File

@@ -34,7 +34,7 @@ type tlsPoller struct {
extension *api.Extension extension *api.Extension
procfs string procfs string
pidToNamespace sync.Map pidToNamespace sync.Map
fdCache *simplelru.LRU // Actual typs is map[string]addressPair fdCache *simplelru.LRU // Actual type is map[string]addressPair
evictedCounter int evictedCounter int
} }