From 97072427f67b8c00982a1abc4c99060eab8971e4 Mon Sep 17 00:00:00 2001 From: "M. Mert Yildiran" Date: Tue, 26 Apr 2022 18:34:23 +0300 Subject: [PATCH] Define `ReassemblyStream` interface and separate `gopacket`-specific fields into `tcpReassemblyStream` struct So that `tap/api` doesn't depend on `gopacket` --- tap/api/api.go | 13 +-- tap/api/go.mod | 1 - tap/api/go.sum | 2 - tap/passive_tapper.go | 7 +- tap/tcp/tcp_stream.go | 176 ++++------------------------------- tap/tcp_reassembly_stream.go | 170 +++++++++++++++++++++++++++++++++ tap/tcp_stream_factory.go | 24 ++--- 7 files changed, 209 insertions(+), 184 deletions(-) create mode 100644 tap/tcp_reassembly_stream.go diff --git a/tap/api/api.go b/tap/api/api.go index 7f3348d86..6020c02b5 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -14,9 +14,6 @@ import ( "sync" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/reassembly" "github.com/google/martian/har" "github.com/up9inc/mizu/shared" ) @@ -434,19 +431,19 @@ type TcpReader interface { } type TcpStream interface { - Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool - ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) - ReassemblyComplete(ac reassembly.AssemblerContext) bool Close() CloseOtherProtocolDissectors(protocol *Protocol) AddClient(reader TcpReader) AddServer(reader TcpReader) - ClientRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) - ServerRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) + GetClients() []TcpReader + GetServers() []TcpReader + GetClient(index int) TcpReader + GetServer(index int) TcpReader GetOrigin() Capture GetProtoIdentifier() *ProtoIdentifier GetReqResMatcher() RequestResponseMatcher GetIsTapTarget() bool + GetIsClosed() bool GetId() int64 SetId(id int64) } diff 
--git a/tap/api/go.mod b/tap/api/go.mod index ba82fee6d..a06f80ce5 100644 --- a/tap/api/go.mod +++ b/tap/api/go.mod @@ -3,7 +3,6 @@ module github.com/up9inc/mizu/tap/api go 1.17 require ( - github.com/google/gopacket v1.1.19 github.com/google/martian v2.1.0+incompatible github.com/up9inc/mizu/shared v0.0.0 ) diff --git a/tap/api/go.sum b/tap/api/go.sum index 85df69f5a..33f8b5ed8 100644 --- a/tap/api/go.sum +++ b/tap/api/go.sum @@ -255,8 +255,6 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index f27933d3f..88790ab63 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -33,8 +33,11 @@ var maxcount = flag.Int64("c", -1, "Only grab this many packets, then exit") var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)") var statsevery = flag.Int("stats", 60, "Output statistics every N seconds") var lazy = flag.Bool("lazy", false, "If true, do lazy decoding") -var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag") // global -var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global +var nodefrag = flag.Bool("nodefrag", false, "If true, 
do not do IPv4 defrag") // global +var checksum = flag.Bool("checksum", false, "Check TCP checksum") // global +var nooptcheck = flag.Bool("nooptcheck", true, "Do not check TCP options (useful to ignore MSS on captures with TSO)") // global +var ignorefsmerr = flag.Bool("ignorefsmerr", true, "Ignore TCP FSM errors") // global +var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global var verbose = flag.Bool("verbose", false, "Be verbose") var debug = flag.Bool("debug", false, "Display debug information") var quiet = flag.Bool("quiet", false, "Be quiet regarding errors") diff --git a/tap/tcp/tcp_stream.go b/tap/tcp/tcp_stream.go index 77297faa6..477f46042 100644 --- a/tap/tcp/tcp_stream.go +++ b/tap/tcp/tcp_stream.go @@ -1,24 +1,13 @@ package tcp import ( - "encoding/binary" - "flag" - "fmt" "sync" "time" - "github.com/google/gopacket" "github.com/google/gopacket/layers" // pulls in all layers decoders - "github.com/google/gopacket/reassembly" - "github.com/up9inc/mizu/shared" "github.com/up9inc/mizu/tap/api" - "github.com/up9inc/mizu/tap/diagnose" ) -var checksum = flag.Bool("checksum", false, "Check TCP checksum") // global -var nooptcheck = flag.Bool("nooptcheck", true, "Do not check TCP options (useful to ignore MSS on captures with TSO)") // global -var ignorefsmerr = flag.Bool("ignorefsmerr", true, "Ignore TCP FSM errors") // global - /* It's a connection (bidirectional) * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) * ReassembledSG gets called when new reassembled data is ready (i.e. 
bytes in order, no duplicates, complete) @@ -28,10 +17,6 @@ type tcpStream struct { id int64 isClosed bool protoIdentifier *api.ProtoIdentifier - tcpState *reassembly.TCPSimpleFSM - fsmerr bool - optchecker reassembly.TCPOptionCheck - net, transport gopacket.Flow isDNS bool isTapTarget bool clients []api.TcpReader @@ -44,15 +29,9 @@ type tcpStream struct { sync.Mutex } -func NewTcpStream(net gopacket.Flow, transport gopacket.Flow, tcp *layers.TCP, isTapTarget bool, fsmOptions reassembly.TCPSimpleFSMOptions, streamsMap api.TcpStreamMap, capture api.Capture) api.TcpStream { +func NewTcpStream(tcp *layers.TCP, isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) api.TcpStream { return &tcpStream{ - net: net, - transport: transport, - isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, isTapTarget: isTapTarget, - tcpState: reassembly.NewTCPSimpleFSM(fsmOptions), - ident: fmt.Sprintf("%s:%s", net, transport), - optchecker: reassembly.NewTCPOptionCheck(), protoIdentifier: &api.ProtoIdentifier{}, streamsMap: streamsMap, origin: capture, @@ -66,139 +45,6 @@ func NewTcpStreamDummy(capture api.Capture) api.TcpStream { } } -func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { - // FSM - if !t.tcpState.CheckState(tcp, dir) { - diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpState.String()) - diagnose.InternalStats.RejectFsm++ - if !t.fsmerr { - t.fsmerr = true - diagnose.InternalStats.RejectConnFsm++ - } - if !*ignorefsmerr { - return false - } - } - // Options - err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) - if err != nil { - diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err) - diagnose.InternalStats.RejectOpt++ - if !*nooptcheck { - return false - } - } - // Checksum - accept := true - if *checksum { - c, 
err := tcp.ComputeChecksum() - if err != nil { - diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err) - accept = false - } else if c != 0x0 { - diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c) - accept = false - } - } - if !accept { - diagnose.InternalStats.RejectOpt++ - } - - *start = true - - return accept -} - -func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { - dir, _, _, skip := sg.Info() - length, saved := sg.Lengths() - // update stats - sgStats := sg.Stats() - if skip > 0 { - diagnose.InternalStats.MissedBytes += skip - } - diagnose.InternalStats.Sz += length - saved - diagnose.InternalStats.Pkt += sgStats.Packets - if sgStats.Chunks > 1 { - diagnose.InternalStats.Reassembled++ - } - diagnose.InternalStats.OutOfOrderPackets += sgStats.QueuedPackets - diagnose.InternalStats.OutOfOrderBytes += sgStats.QueuedBytes - if length > diagnose.InternalStats.BiggestChunkBytes { - diagnose.InternalStats.BiggestChunkBytes = length - } - if sgStats.Packets > diagnose.InternalStats.BiggestChunkPackets { - diagnose.InternalStats.BiggestChunkPackets = sgStats.Packets - } - if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 { - // In the original example this was handled with panic(). - // I don't know what this error means or how to handle it properly. 
- diagnose.TapErrors.SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets) - } - diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes - diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets - - if skip != -1 && skip != 0 { - // Missing bytes in stream: do not even try to parse it - return - } - data := sg.Fetch(length) - if t.isDNS { - dns := &layers.DNS{} - var decoded []gopacket.LayerType - if len(data) < 2 { - if len(data) > 0 { - sg.KeepFrom(0) - } - return - } - dnsSize := binary.BigEndian.Uint16(data[:2]) - missing := int(dnsSize) - len(data[2:]) - diagnose.TapErrors.Debug("dnsSize: %d, missing: %d", dnsSize, missing) - if missing > 0 { - diagnose.TapErrors.Debug("Missing some bytes: %d", missing) - sg.KeepFrom(0) - return - } - p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns) - err := p.DecodeLayers(data[2:], &decoded) - if err != nil { - diagnose.TapErrors.SilentError("DNS-parser", "Failed to decode DNS: %v", err) - } else { - diagnose.TapErrors.Debug("DNS: %s", gopacket.LayerDump(dns)) - } - if len(data) > 2+int(dnsSize) { - sg.KeepFrom(2 + int(dnsSize)) - } - } else if t.isTapTarget { - if length > 0 { - // This is where we pass the reassembled information onwards - // This channel is read by an tcpReader object - diagnose.AppStatsInst.IncReassembledTcpPayloadsCount() - timestamp := ac.GetCaptureInfo().Timestamp - if dir == reassembly.TCPDirClientToServer { - for i := range t.clients { - reader := t.clients[i] - reader.SendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) - } - } else { - for i := range t.servers { - reader := t.servers[i] - reader.SendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) - } - } - } - } -} - -func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { - if t.isTapTarget && !t.isClosed { - t.Close() - } - // do not remove the connection to allow last ACK - return false -} - func (t *tcpStream) Close() { t.Lock() defer t.Unlock() 
@@ -255,12 +101,20 @@ func (t *tcpStream) AddServer(reader api.TcpReader) { t.servers = append(t.servers, reader) } -func (t *tcpStream) ClientRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { - t.clients[index].Run(filteringOptions, wg) +func (t *tcpStream) GetClients() []api.TcpReader { + return t.clients } -func (t *tcpStream) ServerRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { - t.servers[index].Run(filteringOptions, wg) +func (t *tcpStream) GetServers() []api.TcpReader { + return t.servers +} + +func (t *tcpStream) GetClient(index int) api.TcpReader { + return t.clients[index] +} + +func (t *tcpStream) GetServer(index int) api.TcpReader { + return t.servers[index] } func (t *tcpStream) GetOrigin() api.Capture { @@ -279,6 +133,10 @@ func (t *tcpStream) GetIsTapTarget() bool { return t.isTapTarget } +func (t *tcpStream) GetIsClosed() bool { + return t.isClosed +} + func (t *tcpStream) GetId() int64 { return t.id } diff --git a/tap/tcp_reassembly_stream.go b/tap/tcp_reassembly_stream.go new file mode 100644 index 000000000..4b3a9229a --- /dev/null +++ b/tap/tcp_reassembly_stream.go @@ -0,0 +1,170 @@ +package tap + +import ( + "encoding/binary" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/reassembly" + "github.com/up9inc/mizu/tap/api" + "github.com/up9inc/mizu/tap/diagnose" + "github.com/up9inc/mizu/tap/tcp" +) + +type ReassemblyStream interface { + Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool + ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) + ReassemblyComplete(ac reassembly.AssemblerContext) bool +} + +type tcpReassemblyStream struct { + ident string + tcpState *reassembly.TCPSimpleFSM + fsmerr bool + optchecker reassembly.TCPOptionCheck + isDNS bool + tcpStream api.TcpStream +} + +func 
NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream api.TcpStream) ReassemblyStream { + return &tcpReassemblyStream{ + ident: ident, + tcpState: reassembly.NewTCPSimpleFSM(fsmOptions), + optchecker: reassembly.NewTCPOptionCheck(), + isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, + tcpStream: stream, + } +} + +func (t *tcpReassemblyStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { + // FSM + if !t.tcpState.CheckState(tcp, dir) { + diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpState.String()) + diagnose.InternalStats.RejectFsm++ + if !t.fsmerr { + t.fsmerr = true + diagnose.InternalStats.RejectConnFsm++ + } + if !*ignorefsmerr { + return false + } + } + // Options + err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) + if err != nil { + diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err) + diagnose.InternalStats.RejectOpt++ + if !*nooptcheck { + return false + } + } + // Checksum + accept := true + if *checksum { + c, err := tcp.ComputeChecksum() + if err != nil { + diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err) + accept = false + } else if c != 0x0 { + diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c) + accept = false + } + } + if !accept { + diagnose.InternalStats.RejectOpt++ + } + + *start = true + + return accept +} + +func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { + dir, _, _, skip := sg.Info() + length, saved := sg.Lengths() + // update stats + sgStats := sg.Stats() + if skip > 0 { + diagnose.InternalStats.MissedBytes += skip + } + diagnose.InternalStats.Sz += length - saved + diagnose.InternalStats.Pkt += 
sgStats.Packets + if sgStats.Chunks > 1 { + diagnose.InternalStats.Reassembled++ + } + diagnose.InternalStats.OutOfOrderPackets += sgStats.QueuedPackets + diagnose.InternalStats.OutOfOrderBytes += sgStats.QueuedBytes + if length > diagnose.InternalStats.BiggestChunkBytes { + diagnose.InternalStats.BiggestChunkBytes = length + } + if sgStats.Packets > diagnose.InternalStats.BiggestChunkPackets { + diagnose.InternalStats.BiggestChunkPackets = sgStats.Packets + } + if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 { + // In the original example this was handled with panic(). + // I don't know what this error means or how to handle it properly. + diagnose.TapErrors.SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets) + } + diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes + diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets + + if skip != -1 && skip != 0 { + // Missing bytes in stream: do not even try to parse it + return + } + data := sg.Fetch(length) + if t.isDNS { + dns := &layers.DNS{} + var decoded []gopacket.LayerType + if len(data) < 2 { + if len(data) > 0 { + sg.KeepFrom(0) + } + return + } + dnsSize := binary.BigEndian.Uint16(data[:2]) + missing := int(dnsSize) - len(data[2:]) + diagnose.TapErrors.Debug("dnsSize: %d, missing: %d", dnsSize, missing) + if missing > 0 { + diagnose.TapErrors.Debug("Missing some bytes: %d", missing) + sg.KeepFrom(0) + return + } + p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns) + err := p.DecodeLayers(data[2:], &decoded) + if err != nil { + diagnose.TapErrors.SilentError("DNS-parser", "Failed to decode DNS: %v", err) + } else { + diagnose.TapErrors.Debug("DNS: %s", gopacket.LayerDump(dns)) + } + if len(data) > 2+int(dnsSize) { + sg.KeepFrom(2 + int(dnsSize)) + } + } else if t.tcpStream.GetIsTapTarget() { + if length > 0 { + // This is where we pass the reassembled information onwards + // This channel is read by an tcpReader object + 
diagnose.AppStatsInst.IncReassembledTcpPayloadsCount() + timestamp := ac.GetCaptureInfo().Timestamp + if dir == reassembly.TCPDirClientToServer { + for i := range t.tcpStream.GetClients() { + reader := t.tcpStream.GetClient(i) + reader.SendMsgIfNotClosed(tcp.NewTcpReaderDataMsg(data, timestamp)) + } + } else { + for i := range t.tcpStream.GetServers() { + reader := t.tcpStream.GetServer(i) + reader.SendMsgIfNotClosed(tcp.NewTcpReaderDataMsg(data, timestamp)) + } + } + } + } +} + +func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { + if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() { + t.tcpStream.Close() + } + // do not remove the connection to allow last ACK + return false +} diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index b4f95343b..7b0bc6f0c 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -59,16 +59,17 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort) isTapTarget := props.isTapTarget - stream := tcp.NewTcpStream(net, transport, tcpLayer, isTapTarget, fsmOptions, factory.streamsMap, getPacketOrigin(ac)) - if stream.GetIsTapTarget() { - stream.SetId(factory.streamsMap.NextId()) + tcpStream := tcp.NewTcpStream(tcpLayer, isTapTarget, factory.streamsMap, getPacketOrigin(ac)) + reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, tcpStream) + if tcpStream.GetIsTapTarget() { + tcpStream.SetId(factory.streamsMap.NextId()) for i, extension := range extensions { reqResMatcher := extension.Dissector.NewResponseRequestMatcher() counterPair := &api.CounterPair{ Request: 0, Response: 0, } - stream.AddClient( + tcpStream.AddClient( tcp.NewTcpReader( make(chan api.TcpReaderDataMsg), &api.ReadProgress{}, @@ -80,7 +81,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay DstPort: dstPort, }, time.Time{}, 
- stream, + tcpStream, true, props.isOutgoing, extension, @@ -89,7 +90,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay reqResMatcher, ), ) - stream.AddServer( + tcpStream.AddServer( tcp.NewTcpReader( make(chan api.TcpReaderDataMsg), &api.ReadProgress{}, @@ -101,7 +102,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay DstPort: transport.Src().String(), }, time.Time{}, - stream, + tcpStream, false, props.isOutgoing, extension, @@ -111,15 +112,14 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay ), ) - factory.streamsMap.Store(stream.GetId(), stream) + factory.streamsMap.Store(tcpStream.GetId(), tcpStream) factory.wg.Add(2) - // Start reading from channel stream.reader.bytes - go stream.ClientRun(i, filteringOptions, &factory.wg) - go stream.ServerRun(i, filteringOptions, &factory.wg) + go tcpStream.GetClient(i).Run(filteringOptions, &factory.wg) + go tcpStream.GetServer(i).Run(filteringOptions, &factory.wg) } } - return stream + return reassemblyStream } func (factory *tcpStreamFactory) WaitGoRoutines() {