Refactor tap module and move some of the code to tap/api module

This commit is contained in:
M. Mert Yildiran
2022-04-20 11:32:48 +03:00
parent ba4bab3ed5
commit 110325598a
18 changed files with 1319 additions and 344 deletions

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/google/martian/har" "github.com/google/martian/har"
"github.com/up9inc/mizu/tap/api/diagnose"
) )
const mizuTestEnvVar = "MIZU_TEST" const mizuTestEnvVar = "MIZU_TEST"
@@ -144,7 +145,7 @@ type RequestResponseMatcher interface {
} }
type Emitting struct { type Emitting struct {
AppStats *AppStats AppStats *diagnose.AppStats
OutputChannel chan *OutputChannelItem OutputChannel chan *OutputChannelItem
} }

View File

@@ -9,10 +9,9 @@ import (
"time" "time"
"github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
) )
var AppStats = api.AppStats{} var AppStatsInst = AppStats{}
func StartMemoryProfiler(envDumpPath string, envTimeInterval string) { func StartMemoryProfiler(envDumpPath string, envTimeInterval string) {
dumpPath := "/app/pprof" dumpPath := "/app/pprof"

View File

@@ -1,4 +1,4 @@
package api package diagnose
import ( import (
"sync/atomic" "sync/atomic"

View File

@@ -2,4 +2,12 @@ module github.com/up9inc/mizu/tap/api
go 1.17 go 1.17
require github.com/google/martian v2.1.0+incompatible require (
github.com/google/gopacket v1.1.19
github.com/google/martian v2.1.0+incompatible
github.com/up9inc/mizu/shared v0.0.0
)
require github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 // indirect
replace github.com/up9inc/mizu/shared v0.0.0 => ../../shared

File diff suppressed because it is too large Load Diff

85
tap/api/tcp_reader.go Normal file
View File

@@ -0,0 +1,85 @@
package api
import (
"bufio"
"io"
"io/ioutil"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
)
// TcpReaderDataMsg carries one chunk of reassembled TCP payload together with
// the capture timestamp of the packet batch it was reassembled from.
type TcpReaderDataMsg struct {
	bytes     []byte    // reassembled payload bytes (may be empty for pure ACKs)
	timestamp time.Time // capture time, propagated to SuperTimer on Read
}

/* TcpReader reads from a channel of bytes of tcp payload, and parses it into requests and responses.
 * The payload is written to the channel by a TcpStream object that is dedicated to one tcp connection.
 * A TcpReader object is unidirectional: it parses either a client stream or a server stream.
 * Implements io.Reader interface (Read)
 */
type TcpReader struct {
	Ident         string
	TcpID         *TcpID
	isClosed      bool // guarded by the embedded Mutex; set once by Close
	IsClient      bool // true when this reader parses the client->server direction
	IsOutgoing    bool
	MsgQueue      chan TcpReaderDataMsg // Channel of captured reassembled tcp payload
	data          []byte                // leftover bytes from the last MsgQueue message, drained by Read
	Progress      *ReadProgress
	SuperTimer    *SuperTimer
	Parent        *TcpStream
	packetsSeen   uint // count of non-empty payload messages consumed
	Extension     *Extension
	Emitter       Emitter
	CounterPair   *CounterPair
	ReqResMatcher RequestResponseMatcher
	sync.Mutex
}
// Read implements io.Reader. It blocks until a non-empty payload message
// arrives on MsgQueue (updating SuperTimer.CaptureTime from the message),
// copies as much as fits into p, and returns io.EOF once the queue is
// closed and all buffered data has been drained.
func (h *TcpReader) Read(p []byte) (int, error) {
	var msg TcpReaderDataMsg

	ok := true
	for ok && len(h.data) == 0 {
		msg, ok = <-h.MsgQueue
		if !ok {
			// Channel closed: do not consume the zero-value msg — doing so
			// would clobber SuperTimer.CaptureTime with the zero time.
			break
		}
		h.data = msg.bytes
		h.SuperTimer.CaptureTime = msg.timestamp
		if len(h.data) > 0 {
			h.packetsSeen += 1
		}
	}
	if !ok || len(h.data) == 0 {
		return 0, io.EOF
	}

	l := copy(p, h.data)
	h.data = h.data[l:]
	h.Progress.Feed(l)
	return l, nil
}
// Close marks the reader closed and closes its message queue exactly once;
// subsequent calls are no-ops. Safe for concurrent use via the embedded Mutex.
func (h *TcpReader) Close() {
	h.Lock()
	defer h.Unlock()
	if h.isClosed {
		return
	}
	h.isClosed = true
	close(h.MsgQueue)
}
// Run feeds this reader's byte stream through the extension's dissector.
// If dissection fails, the remaining buffered bytes are drained and
// discarded so the producing goroutine is not blocked on MsgQueue.
// Signals wg when the stream is fully consumed.
func (h *TcpReader) Run(filteringOptions *TrafficFilteringOptions, wg *sync.WaitGroup) {
	defer wg.Done()

	buffered := bufio.NewReader(h)
	if err := h.Extension.Dissector.Dissect(buffered, h.Progress, h.Parent.Origin, h.IsClient, h.TcpID, h.CounterPair, h.SuperTimer, h.Parent.SuperIdentifier, h.Emitter, filteringOptions, h.ReqResMatcher); err != nil {
		if _, discardErr := io.Copy(ioutil.Discard, buffered); discardErr != nil {
			logger.Log.Errorf("%v", discardErr)
		}
	}
}

View File

@@ -1,4 +1,4 @@
package tap package api
import ( import (
"encoding/binary" "encoding/binary"
@@ -8,79 +8,57 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly" "github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api/diagnose"
"github.com/up9inc/mizu/tap/diagnose"
) )
/* It's a connection (bidirectional) /* It's a connection (bidirectional)
* Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete)
* ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete)
* In our implementation, we pass information from ReassembledSG to the tcpReader through a shared channel. * In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel.
*/ */
type tcpStream struct { type TcpStream struct {
id int64 Id int64
isClosed bool isClosed bool
superIdentifier *api.SuperIdentifier SuperIdentifier *SuperIdentifier
tcpstate *reassembly.TCPSimpleFSM TcpState *reassembly.TCPSimpleFSM
fsmerr bool fsmerr bool
optchecker reassembly.TCPOptionCheck Optchecker reassembly.TCPOptionCheck
net, transport gopacket.Flow Net, Transport gopacket.Flow
isDNS bool IsDNS bool
isTapTarget bool IsTapTarget bool
clients []tcpReader Clients []TcpReader
servers []tcpReader Servers []TcpReader
ident string Ident string
origin api.Capture Origin Capture
reqResMatcher api.RequestResponseMatcher ReqResMatcher RequestResponseMatcher
createdAt time.Time createdAt time.Time
StreamsMap *TcpStreamMap
sync.Mutex sync.Mutex
streamsMap *tcpStreamMap
} }
func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
// FSM // FSM
if !t.tcpstate.CheckState(tcp, dir) { if !t.TcpState.CheckState(tcp, dir) {
diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpstate.String()) diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.Ident, t.TcpState.String())
diagnose.InternalStats.RejectFsm++ diagnose.InternalStats.RejectFsm++
if !t.fsmerr { if !t.fsmerr {
t.fsmerr = true t.fsmerr = true
diagnose.InternalStats.RejectConnFsm++ diagnose.InternalStats.RejectConnFsm++
} }
if !*ignorefsmerr {
return false
}
} }
// Options // Options
err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) err := t.Optchecker.Accept(tcp, ci, dir, nextSeq, start)
if err != nil { if err != nil {
diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err) diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.Ident, err)
diagnose.InternalStats.RejectOpt++
if !*nooptcheck {
return false
}
}
// Checksum
accept := true
if *checksum {
c, err := tcp.ComputeChecksum()
if err != nil {
diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err)
accept = false
} else if c != 0x0 {
diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c)
accept = false
}
}
if !accept {
diagnose.InternalStats.RejectOpt++ diagnose.InternalStats.RejectOpt++
} }
*start = true *start = true
return accept return true
} }
func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { func (t *TcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
dir, _, _, skip := sg.Info() dir, _, _, skip := sg.Info()
length, saved := sg.Lengths() length, saved := sg.Lengths()
// update stats // update stats
@@ -109,14 +87,12 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes
diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets
if skip == -1 && *allowmissinginit { if skip != -1 && skip != 0 {
// this is allowed
} else if skip != 0 {
// Missing bytes in stream: do not even try to parse it // Missing bytes in stream: do not even try to parse it
return return
} }
data := sg.Fetch(length) data := sg.Fetch(length)
if t.isDNS { if t.IsDNS {
dns := &layers.DNS{} dns := &layers.DNS{}
var decoded []gopacket.LayerType var decoded []gopacket.LayerType
if len(data) < 2 { if len(data) < 2 {
@@ -143,27 +119,27 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
if len(data) > 2+int(dnsSize) { if len(data) > 2+int(dnsSize) {
sg.KeepFrom(2 + int(dnsSize)) sg.KeepFrom(2 + int(dnsSize))
} }
} else if t.isTapTarget { } else if t.IsTapTarget {
if length > 0 { if length > 0 {
// This is where we pass the reassembled information onwards // This is where we pass the reassembled information onwards
// This channel is read by an tcpReader object // This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount() diagnose.AppStatsInst.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp timestamp := ac.GetCaptureInfo().Timestamp
if dir == reassembly.TCPDirClientToServer { if dir == reassembly.TCPDirClientToServer {
for i := range t.clients { for i := range t.Clients {
reader := &t.clients[i] reader := &t.Clients[i]
reader.Lock() reader.Lock()
if !reader.isClosed { if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp} reader.MsgQueue <- TcpReaderDataMsg{data, timestamp}
} }
reader.Unlock() reader.Unlock()
} }
} else { } else {
for i := range t.servers { for i := range t.Servers {
reader := &t.servers[i] reader := &t.Servers[i]
reader.Lock() reader.Lock()
if !reader.isClosed { if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp} reader.MsgQueue <- TcpReaderDataMsg{data, timestamp}
} }
reader.Unlock() reader.Unlock()
} }
@@ -172,15 +148,15 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
} }
} }
func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { func (t *TcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.isTapTarget && !t.isClosed { if t.IsTapTarget && !t.isClosed {
t.Close() t.Close()
} }
// do not remove the connection to allow last ACK // do not remove the connection to allow last ACK
return false return false
} }
func (t *tcpStream) Close() { func (t *TcpStream) Close() {
shouldReturn := false shouldReturn := false
t.Lock() t.Lock()
if t.isClosed { if t.isClosed {
@@ -192,14 +168,14 @@ func (t *tcpStream) Close() {
if shouldReturn { if shouldReturn {
return return
} }
t.streamsMap.Delete(t.id) t.StreamsMap.Delete(t.Id)
for i := range t.clients { for i := range t.Clients {
reader := &t.clients[i] reader := &t.Clients[i]
reader.Close() reader.Close()
} }
for i := range t.servers { for i := range t.Servers {
reader := &t.servers[i] reader := &t.Servers[i]
reader.Close() reader.Close()
} }
} }

View File

@@ -0,0 +1,29 @@
package api
import (
	"sync"
	"sync/atomic"
)
// TcpStreamMap tracks all live TCP streams by id and hands out
// monotonically increasing stream ids.
type TcpStreamMap struct {
	Streams  *sync.Map
	streamId int64 // accessed atomically via NextId
}

// NewTcpStreamMap returns an empty, ready-to-use stream map.
func NewTcpStreamMap() *TcpStreamMap {
	return &TcpStreamMap{
		Streams: &sync.Map{},
	}
}

// Store adds or replaces the stream stored under key.
func (streamMap *TcpStreamMap) Store(key, value interface{}) {
	streamMap.Streams.Store(key, value)
}

// Delete removes the stream stored under key, if any.
func (streamMap *TcpStreamMap) Delete(key interface{}) {
	streamMap.Streams.Delete(key)
}

// NextId returns the next stream id, starting from 1. The increment is
// atomic: Streams is a sync.Map built for concurrent use, so id
// generation must be race-free as well — a plain increment could hand
// out duplicate ids to streams created on concurrent packet paths.
func (streamMap *TcpStreamMap) NextId() int64 {
	return atomic.AddInt64(&streamMap.streamId, 1)
}

View File

@@ -22,7 +22,7 @@ type Cleaner struct {
connectionTimeout time.Duration connectionTimeout time.Duration
stats CleanerStats stats CleanerStats
statsMutex sync.Mutex statsMutex sync.Mutex
streamsMap *tcpStreamMap streamsMap *api.TcpStreamMap
} }
func (cl *Cleaner) clean() { func (cl *Cleaner) clean() {
@@ -33,8 +33,8 @@ func (cl *Cleaner) clean() {
flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout))
cl.assemblerMutex.Unlock() cl.assemblerMutex.Unlock()
cl.streamsMap.streams.Range(func(k, v interface{}) bool { cl.streamsMap.Streams.Range(func(k, v interface{}) bool {
reqResMatcher := v.(*tcpStream).reqResMatcher reqResMatcher := v.(*api.TcpStream).ReqResMatcher
if reqResMatcher == nil { if reqResMatcher == nil {
return true return true
} }

View File

@@ -4,7 +4,7 @@ import (
"net" "net"
"strings" "strings"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/api/diagnose"
) )
var privateIPBlocks []*net.IPNet var privateIPBlocks []*net.IPNet

View File

@@ -19,7 +19,7 @@ import (
"github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/api/diagnose"
"github.com/up9inc/mizu/tap/source" "github.com/up9inc/mizu/tap/source"
"github.com/up9inc/mizu/tap/tlstapper" "github.com/up9inc/mizu/tap/tlstapper"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
@@ -31,10 +31,7 @@ var maxcount = flag.Int64("c", -1, "Only grab this many packets, then exit")
var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)") var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)")
var statsevery = flag.Int("stats", 60, "Output statistics every N seconds") var statsevery = flag.Int("stats", 60, "Output statistics every N seconds")
var lazy = flag.Bool("lazy", false, "If true, do lazy decoding") var lazy = flag.Bool("lazy", false, "If true, do lazy decoding")
var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag") var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag") // global
var checksum = flag.Bool("checksum", false, "Check TCP checksum") // global
var nooptcheck = flag.Bool("nooptcheck", true, "Do not check TCP options (useful to ignore MSS on captures with TSO)") // global
var ignorefsmerr = flag.Bool("ignorefsmerr", true, "Ignore TCP FSM errors") // global
var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global
var verbose = flag.Bool("verbose", false, "Be verbose") var verbose = flag.Bool("verbose", false, "Be verbose")
var debug = flag.Bool("debug", false, "Display debug information") var debug = flag.Bool("debug", false, "Display debug information")
@@ -128,7 +125,7 @@ func printPeriodicStats(cleaner *Cleaner) {
errorMapLen, errorsSummery := diagnose.TapErrors.GetErrorsSummary() errorMapLen, errorsSummery := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("%v (errors: %v, errTypes:%v) - Errors Summary: %s", logger.Log.Infof("%v (errors: %v, errTypes:%v) - Errors Summary: %s",
time.Since(diagnose.AppStats.StartTime), time.Since(diagnose.AppStatsInst.StartTime),
diagnose.TapErrors.ErrorsCount, diagnose.TapErrors.ErrorsCount,
errorMapLen, errorMapLen,
errorsSummery, errorsSummery,
@@ -151,7 +148,7 @@ func printPeriodicStats(cleaner *Cleaner) {
cleanStats.closed, cleanStats.closed,
cleanStats.deleted, cleanStats.deleted,
) )
currentAppStats := diagnose.AppStats.DumpStats() currentAppStats := diagnose.AppStatsInst.DumpStats()
appStatsJSON, _ := json.Marshal(currentAppStats) appStatsJSON, _ := json.Marshal(currentAppStats)
logger.Log.Infof("app stats - %v", string(appStatsJSON)) logger.Log.Infof("app stats - %v", string(appStatsJSON))
} }
@@ -181,8 +178,8 @@ func initializePacketSources() error {
return err return err
} }
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*tcpStreamMap, *tcpAssembler) { func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*api.TcpStreamMap, *tcpAssembler) {
streamsMap := NewTcpStreamMap() streamsMap := api.NewTcpStreamMap()
diagnose.InitializeErrorsMap(*debug, *verbose, *quiet) diagnose.InitializeErrorsMap(*debug, *verbose, *quiet)
diagnose.InitializeTapperInternalStats() diagnose.InitializeTapperInternalStats()
@@ -198,10 +195,8 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI
return streamsMap, assembler return streamsMap, assembler
} }
func startPassiveTapper(streamsMap *tcpStreamMap, assembler *tcpAssembler) { func startPassiveTapper(streamsMap *api.TcpStreamMap, assembler *tcpAssembler) {
go streamsMap.closeTimedoutTcpStreamChannels() diagnose.AppStatsInst.SetStartTime(time.Now())
diagnose.AppStats.SetStartTime(time.Now())
staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds) staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds)
cleaner := Cleaner{ cleaner := Cleaner{
@@ -229,7 +224,7 @@ func startPassiveTapper(streamsMap *tcpStreamMap, assembler *tcpAssembler) {
diagnose.InternalStats.PrintStatsSummary() diagnose.InternalStats.PrintStatsSummary()
diagnose.TapErrors.PrintSummary() diagnose.TapErrors.PrintSummary()
logger.Log.Infof("AppStats: %v", diagnose.AppStats) logger.Log.Infof("AppStats: %v", diagnose.AppStatsInst)
} }
func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper { func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper {
@@ -257,7 +252,7 @@ func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChanne
} }
var emitter api.Emitter = &api.Emitting{ var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats, AppStats: &diagnose.AppStatsInst,
OutputChannel: outputItems, OutputChannel: outputItems,
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/gopacket/pcap" "github.com/google/gopacket/pcap"
"github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/api/diagnose"
) )
type tcpPacketSource struct { type tcpPacketSource struct {

View File

@@ -12,7 +12,7 @@ import (
"github.com/google/gopacket/reassembly" "github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/api/diagnose"
"github.com/up9inc/mizu/tap/source" "github.com/up9inc/mizu/tap/source"
) )
@@ -36,9 +36,9 @@ func (c *context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo return c.CaptureInfo
} }
func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *tcpStreamMap, opts *TapOpts) *tcpAssembler { func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpAssembler {
var emitter api.Emitter = &api.Emitting{ var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats, AppStats: &diagnose.AppStatsInst,
OutputChannel: outputItems, OutputChannel: outputItems,
} }
@@ -65,7 +65,7 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
signal.Notify(signalChan, os.Interrupt) signal.Notify(signalChan, os.Interrupt)
for packetInfo := range packets { for packetInfo := range packets {
packetsCount := diagnose.AppStats.IncPacketsCount() packetsCount := diagnose.AppStatsInst.IncPacketsCount()
if packetsCount%PACKETS_SEEN_LOG_THRESHOLD == 0 { if packetsCount%PACKETS_SEEN_LOG_THRESHOLD == 0 {
logger.Log.Debugf("Packets seen: #%d", packetsCount) logger.Log.Debugf("Packets seen: #%d", packetsCount)
@@ -73,21 +73,15 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
packet := packetInfo.Packet packet := packetInfo.Packet
data := packet.Data() data := packet.Data()
diagnose.AppStats.UpdateProcessedBytes(uint64(len(data))) diagnose.AppStatsInst.UpdateProcessedBytes(uint64(len(data)))
if dumpPacket { if dumpPacket {
logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data)) logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data))
} }
tcp := packet.Layer(layers.LayerTypeTCP) tcp := packet.Layer(layers.LayerTypeTCP)
if tcp != nil { if tcp != nil {
diagnose.AppStats.IncTcpPacketsCount() diagnose.AppStatsInst.IncTcpPacketsCount()
tcp := tcp.(*layers.TCP) tcp := tcp.(*layers.TCP)
if *checksum {
err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer())
if err != nil {
logger.Log.Fatalf("Failed to set network layer for checksum: %s", err)
}
}
c := context{ c := context{
CaptureInfo: packet.Metadata().CaptureInfo, CaptureInfo: packet.Metadata().CaptureInfo,
@@ -99,13 +93,13 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
a.assemblerMutex.Unlock() a.assemblerMutex.Unlock()
} }
done := *maxcount > 0 && int64(diagnose.AppStats.PacketsCount) >= *maxcount done := *maxcount > 0 && int64(diagnose.AppStatsInst.PacketsCount) >= *maxcount
if done { if done {
errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary() errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)", logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)",
diagnose.AppStats.PacketsCount, diagnose.AppStatsInst.PacketsCount,
diagnose.AppStats.ProcessedBytes, diagnose.AppStatsInst.ProcessedBytes,
time.Since(diagnose.AppStats.StartTime), time.Since(diagnose.AppStatsInst.StartTime),
diagnose.TapErrors.ErrorsCount, diagnose.TapErrors.ErrorsCount,
errorMapLen) errorMapLen)
} }

View File

@@ -1,94 +0,0 @@
package tap
import (
"bufio"
"io"
"io/ioutil"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
)
type tcpReaderDataMsg struct {
bytes []byte
timestamp time.Time
}
type ConnectionInfo struct {
ClientIP string
ClientPort string
ServerIP string
ServerPort string
IsOutgoing bool
}
/* tcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses.
* The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection.
* An tcpReader object is unidirectional: it parses either a client stream or a server stream.
* Implements io.Reader interface (Read)
*/
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
msgQueue chan tcpReaderDataMsg // Channel of captured reassembled tcp payload
data []byte
progress *api.ReadProgress
superTimer *api.SuperTimer
parent *tcpStream
packetsSeen uint
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func (h *tcpReader) Read(p []byte) (int, error) {
var msg tcpReaderDataMsg
ok := true
for ok && len(h.data) == 0 {
msg, ok = <-h.msgQueue
h.data = msg.bytes
h.superTimer.CaptureTime = msg.timestamp
if len(h.data) > 0 {
h.packetsSeen += 1
}
}
if !ok || len(h.data) == 0 {
return 0, io.EOF
}
l := copy(p, h.data)
h.data = h.data[l:]
h.progress.Feed(l)
return l, nil
}
func (h *tcpReader) Close() {
h.Lock()
if !h.isClosed {
h.isClosed = true
close(h.msgQueue)
}
h.Unlock()
}
func (h *tcpReader) run(wg *sync.WaitGroup) {
defer wg.Done()
b := bufio.NewReader(h)
err := h.extension.Dissector.Dissect(b, h.progress, h.parent.origin, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter, filteringOptions, h.reqResMatcher)
if err != nil {
_, err = io.Copy(ioutil.Discard, b)
if err != nil {
logger.Log.Errorf("%v", err)
}
}
}

View File

@@ -21,12 +21,12 @@ import (
type tcpStreamFactory struct { type tcpStreamFactory struct {
wg sync.WaitGroup wg sync.WaitGroup
Emitter api.Emitter Emitter api.Emitter
streamsMap *tcpStreamMap streamsMap *api.TcpStreamMap
ownIps []string ownIps []string
opts *TapOpts opts *TapOpts
} }
func NewTcpStreamFactory(emitter api.Emitter, streamsMap *tcpStreamMap, opts *TapOpts) *tcpStreamFactory { func NewTcpStreamFactory(emitter api.Emitter, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory {
var ownIps []string var ownIps []string
if localhostIPs, err := getLocalhostIPs(); err != nil { if localhostIPs, err := getLocalhostIPs(); err != nil {
@@ -57,71 +57,71 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort) props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
isTapTarget := props.isTapTarget isTapTarget := props.isTapTarget
stream := &tcpStream{ stream := &api.TcpStream{
net: net, Net: net,
transport: transport, Transport: transport,
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, IsDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
isTapTarget: isTapTarget, IsTapTarget: isTapTarget,
tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions), TcpState: reassembly.NewTCPSimpleFSM(fsmOptions),
ident: fmt.Sprintf("%s:%s", net, transport), Ident: fmt.Sprintf("%s:%s", net, transport),
optchecker: reassembly.NewTCPOptionCheck(), Optchecker: reassembly.NewTCPOptionCheck(),
superIdentifier: &api.SuperIdentifier{}, SuperIdentifier: &api.SuperIdentifier{},
streamsMap: factory.streamsMap, StreamsMap: factory.streamsMap,
origin: getPacketOrigin(ac), Origin: getPacketOrigin(ac),
} }
if stream.isTapTarget { if stream.IsTapTarget {
stream.id = factory.streamsMap.nextId() stream.Id = factory.streamsMap.NextId()
for i, extension := range extensions { for i, extension := range extensions {
reqResMatcher := extension.Dissector.NewResponseRequestMatcher() reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
counterPair := &api.CounterPair{ counterPair := &api.CounterPair{
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
stream.clients = append(stream.clients, tcpReader{ stream.Clients = append(stream.Clients, api.TcpReader{
msgQueue: make(chan tcpReaderDataMsg), MsgQueue: make(chan api.TcpReaderDataMsg),
progress: &api.ReadProgress{}, Progress: &api.ReadProgress{},
superTimer: &api.SuperTimer{}, SuperTimer: &api.SuperTimer{},
ident: fmt.Sprintf("%s %s", net, transport), Ident: fmt.Sprintf("%s %s", net, transport),
tcpID: &api.TcpID{ TcpID: &api.TcpID{
SrcIP: srcIp, SrcIP: srcIp,
DstIP: dstIp, DstIP: dstIp,
SrcPort: srcPort, SrcPort: srcPort,
DstPort: dstPort, DstPort: dstPort,
}, },
parent: stream, Parent: stream,
isClient: true, IsClient: true,
isOutgoing: props.isOutgoing, IsOutgoing: props.isOutgoing,
extension: extension, Extension: extension,
emitter: factory.Emitter, Emitter: factory.Emitter,
counterPair: counterPair, CounterPair: counterPair,
reqResMatcher: reqResMatcher, ReqResMatcher: reqResMatcher,
}) })
stream.servers = append(stream.servers, tcpReader{ stream.Servers = append(stream.Servers, api.TcpReader{
msgQueue: make(chan tcpReaderDataMsg), MsgQueue: make(chan api.TcpReaderDataMsg),
progress: &api.ReadProgress{}, Progress: &api.ReadProgress{},
superTimer: &api.SuperTimer{}, SuperTimer: &api.SuperTimer{},
ident: fmt.Sprintf("%s %s", net, transport), Ident: fmt.Sprintf("%s %s", net, transport),
tcpID: &api.TcpID{ TcpID: &api.TcpID{
SrcIP: net.Dst().String(), SrcIP: net.Dst().String(),
DstIP: net.Src().String(), DstIP: net.Src().String(),
SrcPort: transport.Dst().String(), SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(), DstPort: transport.Src().String(),
}, },
parent: stream, Parent: stream,
isClient: false, IsClient: false,
isOutgoing: props.isOutgoing, IsOutgoing: props.isOutgoing,
extension: extension, Extension: extension,
emitter: factory.Emitter, Emitter: factory.Emitter,
counterPair: counterPair, CounterPair: counterPair,
reqResMatcher: reqResMatcher, ReqResMatcher: reqResMatcher,
}) })
factory.streamsMap.Store(stream.id, stream) factory.streamsMap.Store(stream.Id, stream)
factory.wg.Add(2) factory.wg.Add(2)
// Start reading from channel stream.reader.bytes // Start reading from channel stream.reader.bytes
go stream.clients[i].run(&factory.wg) go stream.Clients[i].Run(filteringOptions, &factory.wg)
go stream.servers[i].run(&factory.wg) go stream.Servers[i].Run(filteringOptions, &factory.wg)
} }
} }
return stream return stream

View File

@@ -1,98 +0,0 @@
package tap
import (
"os"
"runtime"
_debug "runtime/debug"
"strconv"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/diagnose"
)
type tcpStreamMap struct {
streams *sync.Map
streamId int64
}
func NewTcpStreamMap() *tcpStreamMap {
return &tcpStreamMap{
streams: &sync.Map{},
}
}
func (streamMap *tcpStreamMap) Store(key, value interface{}) {
streamMap.streams.Store(key, value)
}
func (streamMap *tcpStreamMap) Delete(key interface{}) {
streamMap.streams.Delete(key)
}
func (streamMap *tcpStreamMap) nextId() int64 {
streamMap.streamId++
return streamMap.streamId
}
func (streamMap *tcpStreamMap) getCloseTimedoutTcpChannelsInterval() time.Duration {
defaultDuration := 1000 * time.Millisecond
rangeMin := 10
rangeMax := 10000
closeTimedoutTcpChannelsIntervalMsStr := os.Getenv(CloseTimedoutTcpChannelsIntervalMsEnvVar)
if closeTimedoutTcpChannelsIntervalMsStr == "" {
return defaultDuration
} else {
closeTimedoutTcpChannelsIntervalMs, err := strconv.Atoi(closeTimedoutTcpChannelsIntervalMsStr)
if err != nil {
logger.Log.Warningf("Error parsing environment variable %s: %v\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, err)
return defaultDuration
} else {
if closeTimedoutTcpChannelsIntervalMs < rangeMin || closeTimedoutTcpChannelsIntervalMs > rangeMax {
logger.Log.Warningf("The value of environment variable %s is not in acceptable range: %d - %d\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, rangeMin, rangeMax)
return defaultDuration
} else {
return time.Duration(closeTimedoutTcpChannelsIntervalMs) * time.Millisecond
}
}
}
}
func (streamMap *tcpStreamMap) closeTimedoutTcpStreamChannels() {
tcpStreamChannelTimeout := GetTcpChannelTimeoutMs()
closeTimedoutTcpChannelsIntervalMs := streamMap.getCloseTimedoutTcpChannelsInterval()
logger.Log.Infof("Using %d ms as the close timedout TCP stream channels interval", closeTimedoutTcpChannelsIntervalMs/time.Millisecond)
for {
time.Sleep(closeTimedoutTcpChannelsIntervalMs)
_debug.FreeOSMemory()
streamMap.streams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*tcpStream)
if stream.superIdentifier.Protocol == nil {
if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeout)) {
stream.Close()
diagnose.AppStats.IncDroppedTcpStreams()
logger.Log.Debugf("Dropped an unidentified TCP stream because of timeout. Total dropped: %d Total Goroutines: %d Timeout (ms): %d",
diagnose.AppStats.DroppedTcpStreams, runtime.NumGoroutine(), tcpStreamChannelTimeout/time.Millisecond)
}
} else {
if !stream.superIdentifier.IsClosedOthers {
for i := range stream.clients {
reader := &stream.clients[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
for i := range stream.servers {
reader := &stream.servers[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
stream.superIdentifier.IsClosedOthers = true
}
}
return true
})
}
}