From f8cf8848e87291a1b57d1d89e04be5df069c5aff Mon Sep 17 00:00:00 2001 From: "M. Mert Yildiran" Date: Mon, 25 Apr 2022 19:15:49 +0300 Subject: [PATCH] Define a bunch of interfaces and don't export any new structs from `tap/api` --- tap/api/api.go | 2 +- tap/api/tcp_reader.go | 129 +++++++++++++++++---- tap/api/tcp_reader_data_msg.go | 25 ++++ tap/api/tcp_stream.go | 186 +++++++++++++++++++++--------- tap/api/tcp_streams_map.go | 40 ++++--- tap/cleaner.go | 6 +- tap/extensions/amqp/main.go | 44 +++---- tap/extensions/amqp/main_test.go | 54 +++++---- tap/extensions/http/main.go | 52 ++++----- tap/extensions/http/main_test.go | 54 +++++---- tap/extensions/kafka/main.go | 16 +-- tap/extensions/kafka/main_test.go | 54 +++++---- tap/extensions/redis/main.go | 10 +- tap/extensions/redis/main_test.go | 54 +++++---- tap/passive_tapper.go | 4 +- tap/tcp_assembler.go | 2 +- tap/tcp_stream_factory.go | 110 +++++++++--------- tap/tlstapper/tls_poller.go | 27 +++-- 18 files changed, 537 insertions(+), 332 deletions(-) create mode 100644 tap/api/tcp_reader_data_msg.go diff --git a/tap/api/api.go b/tap/api/api.go index 26fa9fe15..14f62bbc7 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -133,7 +133,7 @@ func (p *ReadProgress) Current() (n int) { type Dissector interface { Register(*Extension) Ping() - Dissect(b *bufio.Reader, reader *TcpReader, options *shared.TrafficFilteringOptions) error + Dissect(b *bufio.Reader, reader TcpReader, options *shared.TrafficFilteringOptions) error Analyze(item *OutputChannelItem, resolvedSource string, resolvedDestination string, namespace string) *Entry Summarize(entry *Entry) *BaseEntry Represent(request map[string]interface{}, response map[string]interface{}) (object []byte, err error) diff --git a/tap/api/tcp_reader.go b/tap/api/tcp_reader.go index 7858a1888..11a663ff7 100644 --- a/tap/api/tcp_reader.go +++ b/tap/api/tcp_reader.go @@ -11,9 +11,21 @@ import ( "github.com/up9inc/mizu/shared/logger" ) -type TcpReaderDataMsg struct { - bytes []byte - timestamp time.Time +type TcpReader interface { + Read(p []byte) (int, error) + Close() + Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) + SendMsgIfNotClosed(msg TcpReaderDataMsg) + GetReqResMatcher() RequestResponseMatcher + GetIsClient() bool + GetReadProgress() *ReadProgress + GetParent() TcpStream + GetTcpID() *TcpID + GetCounterPair() *CounterPair + GetCaptureTime() time.Time + GetEmitter() Emitter + GetIsClosed() bool + GetExtension() *Extension } /* TcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses. @@ -21,34 +33,53 @@ type TcpReaderDataMsg struct { * An TcpReader object is unidirectional: it parses either a client stream or a server stream. * Implements io.Reader interface (Read) */ -type TcpReader struct { - Ident string - TcpID *TcpID +type tcpReader struct { + ident string + tcpID *TcpID isClosed bool - IsClient bool - IsOutgoing bool - MsgQueue chan TcpReaderDataMsg // Channel of captured reassembled tcp payload + isClient bool + isOutgoing bool + msgQueue chan TcpReaderDataMsg // Channel of captured reassembled tcp payload data []byte - Progress *ReadProgress - CaptureTime time.Time - Parent *TcpStream + progress *ReadProgress + captureTime time.Time + parent TcpStream packetsSeen uint - Extension *Extension - Emitter Emitter - CounterPair *CounterPair - ReqResMatcher RequestResponseMatcher + extension *Extension + emitter Emitter + counterPair *CounterPair + reqResMatcher RequestResponseMatcher sync.Mutex } -func (reader *TcpReader) Read(p []byte) (int, error) { +func NewTcpReader(msgQueue chan TcpReaderDataMsg, progress *ReadProgress, ident string, tcpId *TcpID, captureTime time.Time, parent TcpStream, isClient bool, isOutgoing bool, extension *Extension, emitter Emitter, counterPair *CounterPair, reqResMatcher RequestResponseMatcher) TcpReader { + return &tcpReader{ + msgQueue: msgQueue, + progress: progress, + ident: ident, + tcpID: tcpId, + captureTime: captureTime, + parent: parent, + isClient: isClient, + isOutgoing: isOutgoing, + extension: extension, + emitter: emitter, + counterPair: counterPair, + reqResMatcher: reqResMatcher, + } +} + +func (reader *tcpReader) Read(p []byte) (int, error) { var msg TcpReaderDataMsg ok := true for ok && len(reader.data) == 0 { - msg, ok = <-reader.MsgQueue - reader.data = msg.bytes + msg, ok = <-reader.msgQueue + if msg != nil { + reader.data = msg.GetBytes() + reader.captureTime = msg.GetTimestamp() + } - reader.CaptureTime = msg.timestamp if len(reader.data) > 0 { reader.packetsSeen += 1 } @@ -59,24 +90,24 @@ func (reader *TcpReader) Read(p []byte) (int, error) { l := copy(p, reader.data) reader.data = reader.data[l:] - reader.Progress.Feed(l) + reader.progress.Feed(l) return l, nil } -func (reader *TcpReader) Close() { +func (reader *tcpReader) Close() { reader.Lock() if !reader.isClosed { reader.isClosed = true - close(reader.MsgQueue) + close(reader.msgQueue) } reader.Unlock() } -func (reader *TcpReader) Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { +func (reader *tcpReader) Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { defer wg.Done() b := bufio.NewReader(reader) - err := reader.Extension.Dissector.Dissect(b, reader, options) + err := reader.extension.Dissector.Dissect(b, reader, options) if err != nil { _, err = io.Copy(ioutil.Discard, reader) if err != nil { @@ -84,3 +115,51 @@ func (reader *TcpReader) Run(options *shared.TrafficFilteringOptions, wg *sync.W } } } + +func (reader *tcpReader) SendMsgIfNotClosed(msg TcpReaderDataMsg) { + reader.Lock() + if !reader.isClosed { + reader.msgQueue <- msg + } + reader.Unlock() +} + +func (reader *tcpReader) GetReqResMatcher() RequestResponseMatcher { + return reader.reqResMatcher +} + +func (reader *tcpReader) GetIsClient() bool { + return reader.isClient +} + +func (reader *tcpReader) GetReadProgress() *ReadProgress { + return reader.progress +} + +func (reader *tcpReader) GetParent() TcpStream { + return reader.parent +} + +func (reader *tcpReader) GetTcpID() *TcpID { + return reader.tcpID +} + +func (reader *tcpReader) GetCounterPair() *CounterPair { + return reader.counterPair +} + +func (reader *tcpReader) GetCaptureTime() time.Time { + return reader.captureTime +} + +func (reader *tcpReader) GetEmitter() Emitter { + return reader.emitter +} + +func (reader *tcpReader) GetIsClosed() bool { + return reader.isClosed +} + +func (reader *tcpReader) GetExtension() *Extension { + return reader.extension +} diff --git a/tap/api/tcp_reader_data_msg.go b/tap/api/tcp_reader_data_msg.go new file mode 100644 index 000000000..3d3692c1a --- /dev/null +++ b/tap/api/tcp_reader_data_msg.go @@ -0,0 +1,25 @@ +package api + +import "time" + +type TcpReaderDataMsg interface { + GetBytes() []byte + GetTimestamp() time.Time +} + +type tcpReaderDataMsg struct { + bytes []byte + timestamp time.Time +} + +func NewTcpReaderDataMsg(data []byte, timestamp time.Time) TcpReaderDataMsg { + return &tcpReaderDataMsg{data, timestamp} +} + +func (dataMsg *tcpReaderDataMsg) GetBytes() []byte { + return dataMsg.bytes +} + +func (dataMsg *tcpReaderDataMsg) GetTimestamp() time.Time { + return dataMsg.timestamp +} diff --git a/tap/api/tcp_stream.go b/tap/api/tcp_stream.go index 408bf186d..428bc4fdc 100644 --- a/tap/api/tcp_stream.go +++ b/tap/api/tcp_stream.go @@ -2,44 +2,86 @@ package api import ( "encoding/binary" + "fmt" "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/reassembly" + "github.com/up9inc/mizu/shared" "github.com/up9inc/mizu/tap/api/diagnose" ) +type TcpStream interface { + Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool + ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) + ReassemblyComplete(ac reassembly.AssemblerContext) bool + Close() + CloseOtherProtocolDissectors(protocol *Protocol) + AddClient(reader TcpReader) + AddServer(reader TcpReader) + ClientRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) + ServerRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) + GetOrigin() Capture + GetProtoIdentifier() *ProtoIdentifier + GetReqResMatcher() RequestResponseMatcher + GetIsTapTarget() bool + GetId() int64 + SetId(id int64) +} + /* It's a connection (bidirectional) * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) * In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel. */ -type TcpStream struct { - Id int64 +type tcpStream struct { + id int64 isClosed bool - ProtoIdentifier *ProtoIdentifier - TcpState *reassembly.TCPSimpleFSM + protoIdentifier *ProtoIdentifier + tcpState *reassembly.TCPSimpleFSM fsmerr bool - Optchecker reassembly.TCPOptionCheck - Net, Transport gopacket.Flow - IsDNS bool - IsTapTarget bool - Clients []TcpReader - Servers []TcpReader - Ident string - Origin Capture - ReqResMatcher RequestResponseMatcher + optchecker reassembly.TCPOptionCheck + net, transport gopacket.Flow + isDNS bool + isTapTarget bool + clients []TcpReader + servers []TcpReader + ident string + origin Capture + reqResMatcher RequestResponseMatcher createdAt time.Time - StreamsMap *TcpStreamMap + streamsMap TcpStreamMap sync.Mutex } -func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { +func NewTcpStream(net gopacket.Flow, transport gopacket.Flow, tcp *layers.TCP, isTapTarget bool, fsmOptions reassembly.TCPSimpleFSMOptions, streamsMap TcpStreamMap, capture Capture) TcpStream { + return &tcpStream{ + net: net, + transport: transport, + isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, + isTapTarget: isTapTarget, + tcpState: reassembly.NewTCPSimpleFSM(fsmOptions), + ident: fmt.Sprintf("%s:%s", net, transport), + optchecker: reassembly.NewTCPOptionCheck(), + protoIdentifier: &ProtoIdentifier{}, + streamsMap: streamsMap, + origin: capture, + } +} + +func NewTcpStreamDummy(capture Capture) TcpStream { + return &tcpStream{ + origin: capture, + protoIdentifier: &ProtoIdentifier{}, + } +} + +func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { // FSM - if !t.TcpState.CheckState(tcp, dir) { - diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.Ident, t.TcpState.String()) + if !t.tcpState.CheckState(tcp, dir) { + diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpState.String()) diagnose.InternalStats.RejectFsm++ if !t.fsmerr { t.fsmerr = true @@ -50,9 +92,9 @@ func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassem } } // Options - err := t.Optchecker.Accept(tcp, ci, dir, nextSeq, start) + err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) if err != nil { - diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.Ident, err) + diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err) diagnose.InternalStats.RejectOpt++ if !*nooptcheck { return false @@ -63,10 +105,10 @@ func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassem if *checksum { c, err := tcp.ComputeChecksum() if err != nil { - diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.Ident, err) + diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err) accept = false } else if c != 0x0 { - diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.Ident, c) + diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c) accept = false } } @@ -79,7 +121,7 @@ func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassem return accept } -func (t *TcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { +func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { dir, _, _, skip := sg.Info() length, saved := sg.Lengths() // update stats @@ -113,7 +155,7 @@ func (t *TcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass return } data := sg.Fetch(length) - if t.IsDNS { + if t.isDNS { dns := &layers.DNS{} var decoded []gopacket.LayerType if len(data) < 2 { @@ -140,44 +182,36 @@ func (t *TcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass if len(data) > 2+int(dnsSize) { sg.KeepFrom(2 + int(dnsSize)) } - } else if t.IsTapTarget { + } else if t.isTapTarget { if length > 0 { // This is where we pass the reassembled information onwards // This channel is read by an tcpReader object diagnose.AppStatsInst.IncReassembledTcpPayloadsCount() timestamp := ac.GetCaptureInfo().Timestamp if dir == reassembly.TCPDirClientToServer { - for i := range t.Clients { - reader := &t.Clients[i] - reader.Lock() - if !reader.isClosed { - reader.MsgQueue <- TcpReaderDataMsg{data, timestamp} - } - reader.Unlock() + for i := range t.clients { + reader := t.clients[i] + reader.SendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } } else { - for i := range t.Servers { - reader := &t.Servers[i] - reader.Lock() - if !reader.isClosed { - reader.MsgQueue <- TcpReaderDataMsg{data, timestamp} - } - reader.Unlock() + for i := range t.servers { + reader := t.servers[i] + reader.SendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } } } } } -func (t *TcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { - if t.IsTapTarget && !t.isClosed { +func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { + if t.isTapTarget && !t.isClosed { t.Close() } // do not remove the connection to allow last ACK return false } -func (t *TcpStream) Close() { +func (t *tcpStream) Close() { t.Lock() defer t.Unlock() @@ -187,40 +221,80 @@ func (t *TcpStream) Close() { t.isClosed = true - t.StreamsMap.Delete(t.Id) + t.streamsMap.Delete(t.id) - for i := range t.Clients { - reader := &t.Clients[i] + for i := range t.clients { + reader := t.clients[i] reader.Close() } - for i := range t.Servers { - reader := &t.Servers[i] + for i := range t.servers { + reader := t.servers[i] reader.Close() } } -func (t *TcpStream) CloseOtherProtocolDissectors(protocol *Protocol) { +func (t *tcpStream) CloseOtherProtocolDissectors(protocol *Protocol) { t.Lock() defer t.Unlock() - if t.ProtoIdentifier.IsClosedOthers { + if t.protoIdentifier.IsClosedOthers { return } - t.ProtoIdentifier.Protocol = protocol + t.protoIdentifier.Protocol = protocol - for i := range t.Clients { - reader := &t.Clients[i] - if reader.Extension.Protocol != t.ProtoIdentifier.Protocol { + for i := range t.clients { + reader := t.clients[i] + if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { reader.Close() } } - for i := range t.Servers { - reader := &t.Servers[i] - if reader.Extension.Protocol != t.ProtoIdentifier.Protocol { + for i := range t.servers { + reader := t.servers[i] + if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { reader.Close() } } - t.ProtoIdentifier.IsClosedOthers = true + t.protoIdentifier.IsClosedOthers = true +} + +func (t *tcpStream) AddClient(reader TcpReader) { + t.clients = append(t.clients, reader) +} + +func (t *tcpStream) AddServer(reader TcpReader) { + t.servers = append(t.servers, reader) +} + +func (t *tcpStream) ClientRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { + t.clients[index].Run(filteringOptions, wg) +} + +func (t *tcpStream) ServerRun(index int, filteringOptions *shared.TrafficFilteringOptions, wg *sync.WaitGroup) { + t.servers[index].Run(filteringOptions, wg) +} + +func (t *tcpStream) GetOrigin() Capture { + return t.origin +} + +func (t *tcpStream) GetProtoIdentifier() *ProtoIdentifier { + return t.protoIdentifier +} + +func (t *tcpStream) GetReqResMatcher() RequestResponseMatcher { + return t.reqResMatcher +} + +func (t *tcpStream) GetIsTapTarget() bool { + return t.isTapTarget +} + +func (t *tcpStream) GetId() int64 { + return t.id +} + +func (t *tcpStream) SetId(id int64) { + t.id = id } diff --git a/tap/api/tcp_streams_map.go b/tap/api/tcp_streams_map.go index 750721508..6983cc789 100644 --- a/tap/api/tcp_streams_map.go +++ b/tap/api/tcp_streams_map.go @@ -10,31 +10,43 @@ import ( "github.com/up9inc/mizu/tap/api/diagnose" ) -type TcpStreamMap struct { - Streams *sync.Map +type TcpStreamMap interface { + Range(f func(key, value interface{}) bool) + Store(key, value interface{}) + Delete(key interface{}) + NextId() int64 + CloseTimedoutTcpStreamChannels() +} + +type tcpStreamMap struct { + streams *sync.Map streamId int64 } -func NewTcpStreamMap() *TcpStreamMap { - return &TcpStreamMap{ - Streams: &sync.Map{}, +func NewTcpStreamMap() TcpStreamMap { + return &tcpStreamMap{ + streams: &sync.Map{}, } } -func (streamMap *TcpStreamMap) Store(key, value interface{}) { - streamMap.Streams.Store(key, value) +func (streamMap *tcpStreamMap) Range(f func(key, value interface{}) bool) { + streamMap.streams.Range(f) } -func (streamMap *TcpStreamMap) Delete(key interface{}) { - streamMap.Streams.Delete(key) +func (streamMap *tcpStreamMap) Store(key, value interface{}) { + streamMap.streams.Store(key, value) } -func (streamMap *TcpStreamMap) NextId() int64 { +func (streamMap *tcpStreamMap) Delete(key interface{}) { + streamMap.streams.Delete(key) +} + +func (streamMap *tcpStreamMap) NextId() int64 { streamMap.streamId++ return streamMap.streamId } -func (streamMap *TcpStreamMap) CloseTimedoutTcpStreamChannels() { +func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() { tcpStreamChannelTimeoutMs := GetTcpChannelTimeoutMs() closeTimedoutTcpChannelsIntervalMs := GetCloseTimedoutTcpChannelsInterval() logger.Log.Infof("Using %d ms as the close timedout TCP stream channels interval", closeTimedoutTcpChannelsIntervalMs/time.Millisecond) @@ -44,9 +56,9 @@ func (streamMap *TcpStreamMap) CloseTimedoutTcpStreamChannels() { <-ticker.C debug.FreeOSMemory() - streamMap.Streams.Range(func(key interface{}, value interface{}) bool { - stream := value.(*TcpStream) - if stream.ProtoIdentifier.Protocol == nil { + streamMap.streams.Range(func(key interface{}, value interface{}) bool { + stream := value.(*tcpStream) + if stream.protoIdentifier.Protocol == nil { if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) { stream.Close() diagnose.AppStatsInst.IncDroppedTcpStreams() diff --git a/tap/cleaner.go b/tap/cleaner.go index 796b72433..204e45ecc 100644 --- a/tap/cleaner.go +++ b/tap/cleaner.go @@ -22,7 +22,7 @@ type Cleaner struct { connectionTimeout time.Duration stats CleanerStats statsMutex sync.Mutex - streamsMap *api.TcpStreamMap + streamsMap api.TcpStreamMap } func (cl *Cleaner) clean() { @@ -33,8 +33,8 @@ func (cl *Cleaner) clean() { flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) cl.assemblerMutex.Unlock() - cl.streamsMap.Streams.Range(func(k, v interface{}) bool { - reqResMatcher := v.(*api.TcpStream).ReqResMatcher + cl.streamsMap.Range(func(k, v interface{}) bool { + reqResMatcher := v.(api.TcpStream).GetReqResMatcher() if reqResMatcher == nil { return true } diff --git a/tap/extensions/amqp/main.go b/tap/extensions/amqp/main.go index f4e5300c0..036aaac19 100644 --- a/tap/extensions/amqp/main.go +++ b/tap/extensions/amqp/main.go @@ -40,17 +40,17 @@ func (d dissecting) Ping() { const amqpRequest string = "amqp_request" -func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *shared.TrafficFilteringOptions) error { +func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *shared.TrafficFilteringOptions) error { r := AmqpReader{b} var remaining int var header *HeaderFrame connectionInfo := &api.ConnectionInfo{ - ClientIP: reader.TcpID.SrcIP, - ClientPort: reader.TcpID.SrcPort, - ServerIP: reader.TcpID.DstIP, - ServerPort: reader.TcpID.DstPort, + ClientIP: reader.GetTcpID().SrcIP, + ClientPort: reader.GetTcpID().SrcPort, + ServerIP: reader.GetTcpID().DstIP, + ServerPort: reader.GetTcpID().DstPort, IsOutgoing: true, } @@ -76,7 +76,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha var lastMethodFrameMessage Message for { - if reader.Parent.ProtoIdentifier.Protocol != nil && reader.Parent.ProtoIdentifier.Protocol != &protocol { + if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &protocol { return errors.New("Identified by another protocol") } @@ -113,12 +113,12 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha switch lastMethodFrameMessage.(type) { case *BasicPublish: eventBasicPublish.Body = f.Body - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *BasicDeliver: eventBasicDeliver.Body = f.Body - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) } case *MethodFrame: @@ -138,8 +138,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha NoWait: m.NoWait, Arguments: m.Arguments, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *BasicConsume: eventBasicConsume := &BasicConsume{ @@ -151,8 +151,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha NoWait: m.NoWait, Arguments: m.Arguments, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *BasicDeliver: eventBasicDeliver.ConsumerTag = m.ConsumerTag @@ -171,8 +171,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha NoWait: m.NoWait, Arguments: m.Arguments, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *ExchangeDeclare: eventExchangeDeclare := &ExchangeDeclare{ @@ -185,8 +185,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha NoWait: m.NoWait, Arguments: m.Arguments, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *ConnectionStart: eventConnectionStart := &ConnectionStart{ @@ -196,8 +196,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha Mechanisms: m.Mechanisms, Locales: m.Locales, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) case *ConnectionClose: eventConnectionClose := &ConnectionClose{ @@ -206,8 +206,8 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha ClassId: m.ClassId, MethodId: m.MethodId, } - reader.Parent.CloseOtherProtocolDissectors(&protocol) - emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, reader.CaptureTime, reader.Progress.Current(), reader.Emitter, reader.Parent.Origin) + reader.GetParent().CloseOtherProtocolDissectors(&protocol) + emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin()) } default: diff --git a/tap/extensions/amqp/main_test.go b/tap/extensions/amqp/main_test.go index 86d4e3d72..596456387 100644 --- a/tap/extensions/amqp/main_test.go +++ b/tap/extensions/amqp/main_test.go @@ -108,7 +108,6 @@ func TestDissect(t *testing.T) { Request: 0, Response: 0, } - protoIdentifier := &api.ProtoIdentifier{} // Request pathClient := _path @@ -124,18 +123,21 @@ func TestDissect(t *testing.T) { DstPort: "2", } reqResMatcher := dissector.NewResponseRequestMatcher() - reader := &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: true, - TcpID: tcpIDClient, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + stream := api.NewTcpStreamDummy(api.Pcap) + reader := api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDClient, + time.Time{}, + stream, + true, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferClient, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { panic(err) @@ -154,18 +156,20 @@ func TestDissect(t *testing.T) { SrcPort: "2", DstPort: "1", } - reader = &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: false, - TcpID: tcpIDServer, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + reader = api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDServer, + time.Time{}, + stream, + false, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferServer, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { panic(err) diff --git a/tap/extensions/http/main.go b/tap/extensions/http/main.go index 0256a46b6..d8d495c5c 100644 --- a/tap/extensions/http/main.go +++ b/tap/extensions/http/main.go @@ -87,15 +87,15 @@ func (d dissecting) Ping() { log.Printf("pong %s", http11protocol.Name) } -func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *shared.TrafficFilteringOptions) error { - reqResMatcher := reader.ReqResMatcher.(*requestResponseMatcher) +func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *shared.TrafficFilteringOptions) error { + reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher) var err error - isHTTP2, _ := checkIsHTTP2Connection(b, reader.IsClient) + isHTTP2, _ := checkIsHTTP2Connection(b, reader.GetIsClient()) var http2Assembler *Http2Assembler if isHTTP2 { - err = prepareHTTP2Connection(b, reader.IsClient) + err = prepareHTTP2Connection(b, reader.GetIsClient()) if err != nil { return err } @@ -106,74 +106,74 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha for { if switchingProtocolsHTTP2 { switchingProtocolsHTTP2 = false - isHTTP2, err = checkIsHTTP2Connection(b, reader.IsClient) + isHTTP2, err = checkIsHTTP2Connection(b, reader.GetIsClient()) if err != nil { break } - err = prepareHTTP2Connection(b, reader.IsClient) + err = prepareHTTP2Connection(b, reader.GetIsClient()) if err != nil { break } http2Assembler = createHTTP2Assembler(b) } - if reader.Parent.ProtoIdentifier.Protocol != nil && reader.Parent.ProtoIdentifier.Protocol != &http11protocol { + if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &http11protocol { return errors.New("Identified by another protocol") } if isHTTP2 { - err = handleHTTP2Stream(http2Assembler, reader.Progress, reader.Parent.Origin, reader.TcpID, reader.CaptureTime, reader.Emitter, options, reqResMatcher) + err = handleHTTP2Stream(http2Assembler, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher) if err == io.EOF || err == io.ErrUnexpectedEOF { break } else if err != nil { continue } - reader.Parent.CloseOtherProtocolDissectors(&http11protocol) - } else if reader.IsClient { + reader.GetParent().CloseOtherProtocolDissectors(&http11protocol) + } else if reader.GetIsClient() { var req *http.Request - switchingProtocolsHTTP2, req, err = handleHTTP1ClientStream(b, reader.Progress, reader.Parent.Origin, reader.TcpID, reader.CounterPair, reader.CaptureTime, reader.Emitter, options, reqResMatcher) + switchingProtocolsHTTP2, req, err = handleHTTP1ClientStream(b, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher) if err == io.EOF || err == io.ErrUnexpectedEOF { break } else if err != nil { continue } - reader.Parent.CloseOtherProtocolDissectors(&http11protocol) + reader.GetParent().CloseOtherProtocolDissectors(&http11protocol) // In case of an HTTP2 upgrade, duplicate the HTTP1 request into HTTP2 with stream ID 1 if switchingProtocolsHTTP2 { ident := fmt.Sprintf( "%s_%s_%s_%s_1_%s", - reader.TcpID.SrcIP, - reader.TcpID.DstIP, - reader.TcpID.SrcPort, - reader.TcpID.DstPort, + reader.GetTcpID().SrcIP, + reader.GetTcpID().DstIP, + reader.GetTcpID().SrcPort, + reader.GetTcpID().DstPort, "HTTP2", ) - item := reqResMatcher.registerRequest(ident, req, reader.CaptureTime, reader.Progress.Current(), req.ProtoMinor) + item := reqResMatcher.registerRequest(ident, req, reader.GetCaptureTime(), reader.GetReadProgress().Current(), req.ProtoMinor) if item != nil { item.ConnectionInfo = &api.ConnectionInfo{ - ClientIP: reader.TcpID.SrcIP, - ClientPort: reader.TcpID.SrcPort, - ServerIP: reader.TcpID.DstIP, - ServerPort: reader.TcpID.DstPort, + ClientIP: reader.GetTcpID().SrcIP, + ClientPort: reader.GetTcpID().SrcPort, + ServerIP: reader.GetTcpID().DstIP, + ServerPort: reader.GetTcpID().DstPort, IsOutgoing: true, } - item.Capture = reader.Parent.Origin - filterAndEmit(item, reader.Emitter, options) + item.Capture = reader.GetParent().GetOrigin() + filterAndEmit(item, reader.GetEmitter(), options) } } } else { - switchingProtocolsHTTP2, err = handleHTTP1ServerStream(b, reader.Progress, reader.Parent.Origin, reader.TcpID, reader.CounterPair, reader.CaptureTime, reader.Emitter, options, reqResMatcher) + switchingProtocolsHTTP2, err = handleHTTP1ServerStream(b, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher) if err == io.EOF || err == io.ErrUnexpectedEOF { break } else if err != nil { continue } - reader.Parent.CloseOtherProtocolDissectors(&http11protocol) + reader.GetParent().CloseOtherProtocolDissectors(&http11protocol) } } - if reader.Parent.ProtoIdentifier.Protocol == nil { + if reader.GetParent().GetProtoIdentifier().Protocol == nil { return err } diff --git a/tap/extensions/http/main_test.go b/tap/extensions/http/main_test.go index 780fd28fe..2d51c58c5 100644 --- a/tap/extensions/http/main_test.go +++ b/tap/extensions/http/main_test.go @@ -110,7 +110,6 @@ func TestDissect(t *testing.T) { Request: 0, Response: 0, } - protoIdentifier := &api.ProtoIdentifier{} // Request pathClient := _path @@ -126,18 +125,21 @@ func TestDissect(t *testing.T) { DstPort: "2", } reqResMatcher := dissector.NewResponseRequestMatcher() - reader := &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: true, - TcpID: tcpIDClient, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + stream := api.NewTcpStreamDummy(api.Pcap) + reader := api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDClient, + time.Time{}, + stream, + true, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferClient, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { panic(err) @@ -156,18 +158,20 @@ func TestDissect(t *testing.T) { SrcPort: "2", DstPort: "1", } - reader = &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: false, - TcpID: tcpIDServer, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + reader = api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDServer, + time.Time{}, + stream, + false, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferServer, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { panic(err) diff --git a/tap/extensions/kafka/main.go b/tap/extensions/kafka/main.go index 4c3c0c863..fff4148a8 100644 --- a/tap/extensions/kafka/main.go +++ b/tap/extensions/kafka/main.go @@ -36,25 +36,25 @@ func (d dissecting) Ping() { log.Printf("pong %s", _protocol.Name) } -func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *shared.TrafficFilteringOptions) error { - reqResMatcher := reader.ReqResMatcher.(*requestResponseMatcher) +func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *shared.TrafficFilteringOptions) error { + reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher) for { - if reader.Parent.ProtoIdentifier.Protocol != nil && reader.Parent.ProtoIdentifier.Protocol != &_protocol { + if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &_protocol { return errors.New("Identified by another protocol") } - if reader.IsClient { - _, _, err := ReadRequest(b, reader.TcpID, reader.CounterPair, reader.CaptureTime, reqResMatcher) + if reader.GetIsClient() { + _, _, err := ReadRequest(b, reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reqResMatcher) if err != nil { return err } - reader.Parent.CloseOtherProtocolDissectors(&_protocol) + reader.GetParent().CloseOtherProtocolDissectors(&_protocol) } else { - err := ReadResponse(b, reader.Parent.Origin, reader.TcpID, reader.CounterPair, reader.CaptureTime, reader.Emitter, reqResMatcher) + err := ReadResponse(b, reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), reqResMatcher) if err != nil { return err } - reader.Parent.CloseOtherProtocolDissectors(&_protocol) + reader.GetParent().CloseOtherProtocolDissectors(&_protocol) } } } diff --git a/tap/extensions/kafka/main_test.go b/tap/extensions/kafka/main_test.go index 166e1f82c..a076b8c13 100644 --- a/tap/extensions/kafka/main_test.go +++ b/tap/extensions/kafka/main_test.go @@ -108,7 +108,6 @@ func TestDissect(t *testing.T) { Request: 0, Response: 0, } - protoIdentifier := &api.ProtoIdentifier{} // Request pathClient := _path @@ -125,18 +124,21 @@ func TestDissect(t *testing.T) { } reqResMatcher := dissector.NewResponseRequestMatcher() reqResMatcher.SetMaxTry(10) - reader := &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: true, - TcpID: tcpIDClient, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + stream := api.NewTcpStreamDummy(api.Pcap) + reader := api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDClient, + time.Time{}, + stream, + true, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferClient, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { log.Println(err) @@ -155,18 +157,20 @@ func TestDissect(t *testing.T) { SrcPort: "2", DstPort: "1", } - reader = &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: false, - TcpID: tcpIDServer, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + reader = api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDServer, + time.Time{}, + stream, + false, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferServer, reader, options) if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { log.Println(err) diff --git a/tap/extensions/redis/main.go b/tap/extensions/redis/main.go index a7b751449..fc08e87e8 100644 --- a/tap/extensions/redis/main.go +++ b/tap/extensions/redis/main.go @@ -35,8 +35,8 @@ func (d dissecting) Ping() { log.Printf("pong %s", protocol.Name) } -func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *shared.TrafficFilteringOptions) error { - reqResMatcher := reader.ReqResMatcher.(*requestResponseMatcher) +func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *shared.TrafficFilteringOptions) error { + reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher) is := &RedisInputStream{ Reader: b, Buf: make([]byte, 8192), @@ -48,10 +48,10 @@ func (d dissecting) Dissect(b *bufio.Reader, reader *api.TcpReader, options *sha return err } - if reader.IsClient { - err = handleClientStream(reader.Progress, reader.Parent.Origin, reader.TcpID, reader.CounterPair, reader.CaptureTime, reader.Emitter, redisPacket, reqResMatcher) + if reader.GetIsClient() { + err = handleClientStream(reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), redisPacket, reqResMatcher) } else { - err = handleServerStream(reader.Progress, reader.Parent.Origin, reader.TcpID, reader.CounterPair, reader.CaptureTime, reader.Emitter, redisPacket, reqResMatcher) + err = handleServerStream(reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), redisPacket, reqResMatcher) } if err != nil { diff --git a/tap/extensions/redis/main_test.go b/tap/extensions/redis/main_test.go index baacdb747..d6c2e366f 100644 --- a/tap/extensions/redis/main_test.go +++ b/tap/extensions/redis/main_test.go @@ -109,7 +109,6 @@ func TestDissect(t *testing.T) { Request: 0, Response: 0, } - protoIdentifier := &api.ProtoIdentifier{} // Request pathClient := _path @@ -125,18 +124,21 @@ func TestDissect(t *testing.T) { DstPort: "2", } reqResMatcher := dissector.NewResponseRequestMatcher() - reader := &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: true, - TcpID: tcpIDClient, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + stream := api.NewTcpStreamDummy(api.Pcap) + reader := api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDClient, + time.Time{}, + stream, + true, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferClient, reader, options) if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF { log.Println(err) @@ -155,18 +157,20 @@ func TestDissect(t *testing.T) { SrcPort: "2", DstPort: "1", } - reader = &api.TcpReader{ - Progress: &api.ReadProgress{}, - Parent: &api.TcpStream{ - Origin: api.Pcap, - ProtoIdentifier: protoIdentifier, - }, - IsClient: false, - TcpID: tcpIDServer, - Emitter: emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - } + reader = api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + "", + tcpIDServer, + time.Time{}, + stream, + false, + false, + nil, + emitter, + counterPair, + reqResMatcher, + ) err = dissector.Dissect(bufferServer, reader, options) if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF { log.Println(err) diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index a1a3c5de6..eef4fe14c 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -179,7 +179,7 @@ func initializePacketSources() error { return err } -func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*api.TcpStreamMap, *tcpAssembler) { +func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (api.TcpStreamMap, *tcpAssembler) { streamsMap := api.NewTcpStreamMap() diagnose.InitializeErrorsMap(*debug, *verbose, *quiet) @@ -196,7 +196,7 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI return streamsMap, assembler } -func startPassiveTapper(streamsMap *api.TcpStreamMap, assembler *tcpAssembler) { +func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) { go streamsMap.CloseTimedoutTcpStreamChannels() diagnose.AppStatsInst.SetStartTime(time.Now()) diff --git a/tap/tcp_assembler.go b/tap/tcp_assembler.go index 8abfe48c2..3e7717c64 100644 --- a/tap/tcp_assembler.go +++ b/tap/tcp_assembler.go @@ -36,7 +36,7 @@ func (c *context) GetCaptureInfo() gopacket.CaptureInfo { return c.CaptureInfo } -func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpAssembler { +func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpAssembler { var emitter api.Emitter = &api.Emitting{ AppStats: &diagnose.AppStatsInst, OutputChannel: outputItems, diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 7df7a41be..edc973a33 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -3,6 +3,7 @@ package tap import ( "fmt" "sync" + "time" "github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/tap/api" @@ -20,13 +21,13 @@ import ( */ type tcpStreamFactory struct { wg sync.WaitGroup - Emitter api.Emitter - streamsMap *api.TcpStreamMap + emitter api.Emitter + streamsMap api.TcpStreamMap ownIps []string opts *TapOpts } -func NewTcpStreamFactory(emitter api.Emitter, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory { +func NewTcpStreamFactory(emitter api.Emitter, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory { var ownIps []string if localhostIPs, err := getLocalhostIPs(); err != nil { @@ -39,7 +40,7 @@ func NewTcpStreamFactory(emitter api.Emitter, streamsMap *api.TcpStreamMap, opts } return &tcpStreamFactory{ - Emitter: emitter, + emitter: emitter, streamsMap: streamsMap, ownIps: ownIps, opts: opts, @@ -57,69 +58,64 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort) isTapTarget := props.isTapTarget - stream := &api.TcpStream{ - Net: net, - Transport: transport, - IsDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, - IsTapTarget: isTapTarget, - TcpState: reassembly.NewTCPSimpleFSM(fsmOptions), - Ident: fmt.Sprintf("%s:%s", net, transport), - Optchecker: reassembly.NewTCPOptionCheck(), - ProtoIdentifier: &api.ProtoIdentifier{}, - StreamsMap: factory.streamsMap, - Origin: getPacketOrigin(ac), - } - if stream.IsTapTarget { - stream.Id = factory.streamsMap.NextId() + stream := api.NewTcpStream(net, transport, tcp, isTapTarget, fsmOptions, factory.streamsMap, getPacketOrigin(ac)) + if stream.GetIsTapTarget() { + stream.SetId(factory.streamsMap.NextId()) for i, extension := range extensions { reqResMatcher := extension.Dissector.NewResponseRequestMatcher() counterPair := &api.CounterPair{ Request: 0, Response: 0, } - stream.Clients = append(stream.Clients, api.TcpReader{ - MsgQueue: make(chan api.TcpReaderDataMsg), - Progress: &api.ReadProgress{}, - Ident: fmt.Sprintf("%s %s", net, transport), - TcpID: &api.TcpID{ - SrcIP: srcIp, - DstIP: dstIp, - SrcPort: srcPort, - DstPort: dstPort, - }, - Parent: stream, - IsClient: true, - IsOutgoing: props.isOutgoing, - Extension: extension, - Emitter: factory.Emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - }) - stream.Servers = append(stream.Servers, api.TcpReader{ - MsgQueue: make(chan api.TcpReaderDataMsg), - Progress: &api.ReadProgress{}, - Ident: fmt.Sprintf("%s %s", net, transport), - TcpID: &api.TcpID{ - SrcIP: net.Dst().String(), - DstIP: net.Src().String(), - SrcPort: transport.Dst().String(), - DstPort: transport.Src().String(), - }, - Parent: stream, - IsClient: false, - IsOutgoing: props.isOutgoing, - Extension: extension, - Emitter: factory.Emitter, - CounterPair: counterPair, - ReqResMatcher: reqResMatcher, - }) + stream.AddClient( + api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + fmt.Sprintf("%s %s", net, transport), + &api.TcpID{ + SrcIP: srcIp, + DstIP: dstIp, + SrcPort: srcPort, + DstPort: dstPort, + }, + time.Time{}, + stream, + true, + props.isOutgoing, + extension, + factory.emitter, + counterPair, + reqResMatcher, + ), + ) + stream.AddServer( + api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + &api.ReadProgress{}, + fmt.Sprintf("%s %s", net, transport), + &api.TcpID{ + SrcIP: net.Dst().String(), + DstIP: net.Src().String(), + SrcPort: transport.Dst().String(), + DstPort: transport.Src().String(), + }, + time.Time{}, + stream, + false, + props.isOutgoing, + extension, + factory.emitter, + counterPair, + reqResMatcher, + ), + ) - factory.streamsMap.Store(stream.Id, stream) + factory.streamsMap.Store(stream.GetId(), stream) factory.wg.Add(2) // Start reading from channel stream.reader.bytes - go stream.Clients[i].Run(filteringOptions, &factory.wg) - go stream.Servers[i].Run(filteringOptions, &factory.wg) + go stream.ClientRun(i, filteringOptions, &factory.wg) + go stream.ServerRun(i, filteringOptions, &factory.wg) } } return stream diff --git a/tap/tlstapper/tls_poller.go b/tap/tlstapper/tls_poller.go index 8b79a1ebe..c9236c018 100644 --- a/tap/tlstapper/tls_poller.go +++ b/tap/tlstapper/tls_poller.go @@ -167,18 +167,21 @@ func dissect(extension *api.Extension, reader *tlsReader, isRequest bool, tcpid tlsEmitter *tlsEmitter, options *shared.TrafficFilteringOptions, reqResMatcher api.RequestResponseMatcher) { b := bufio.NewReader(reader) - tcpReader := &api.TcpReader{ - Progress: reader.progress, - CaptureTime: time.Now(), - Parent: &api.TcpStream{ - Origin: api.Ebpf, - ProtoIdentifier: &api.ProtoIdentifier{}, - }, - IsClient: isRequest, - TcpID: tcpid, - Emitter: tlsEmitter, - ReqResMatcher: reqResMatcher, - } + tcpStream := api.NewTcpStreamDummy(api.Ebpf) + tcpReader := api.NewTcpReader( + make(chan api.TcpReaderDataMsg), + reader.progress, + "", + tcpid, + time.Now(), + tcpStream, + isRequest, + false, + nil, + tlsEmitter, + &api.CounterPair{}, + reqResMatcher, + ) err := extension.Dissector.Dissect(b, tcpReader, options)