From 1de50b0572f4cd04e348e3fcaa087525e04742b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Mert=20Y=C4=B1ld=C4=B1ran?= Date: Sun, 1 May 2022 06:16:22 -0700 Subject: [PATCH] Fix the request-response matcher maps iteration in `clean()` method and share the streams map with the TLS tapper (#1059) * Fix `panic: interface conversion: api.RequestResponseMatcher is nil, not *http.requestResponseMatcher` error Also fix the request-response matcher maps iteration in `clean()` method. * Fix the mocks in the unit tests * Remove unnecessary fields from `tlsPoller` and implement `SetProtocol` method * Use concrete types in `tap` package * Share the streams map with the TLS tapper * Check interface conversion error --- tap/api/api.go | 2 +- tap/cleaner.go | 12 ++++--- tap/extensions/amqp/tcp_stream_mock_test.go | 6 ++-- tap/extensions/http/tcp_stream_mock_test.go | 6 ++-- tap/extensions/kafka/tcp_stream_mock_test.go | 6 ++-- tap/extensions/redis/tcp_stream_mock_test.go | 6 ++-- tap/passive_tapper.go | 17 +++++----- tap/tcp_reassembly_stream.go | 4 +-- tap/tcp_stream.go | 34 +++++++++++--------- tap/tcp_stream_factory.go | 5 +-- tap/tcp_streams_map.go | 8 ++++- tap/tlstapper/tls_poller.go | 14 ++++---- tap/tlstapper/tls_reader.go | 2 +- tap/tlstapper/tls_stream.go | 4 +-- tap/tlstapper/tls_tapper.go | 4 +-- 15 files changed, 73 insertions(+), 57 deletions(-) diff --git a/tap/api/api.go b/tap/api/api.go index 648b14fb4..74929210b 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -426,7 +426,7 @@ type TcpStream interface { SetProtocol(protocol *Protocol) GetOrigin() Capture GetProtoIdentifier() *ProtoIdentifier - GetReqResMatcher() RequestResponseMatcher + GetReqResMatchers() []RequestResponseMatcher GetIsTapTarget() bool GetIsClosed() bool } diff --git a/tap/cleaner.go b/tap/cleaner.go index 8d3e4616e..04268dbfc 100644 --- a/tap/cleaner.go +++ b/tap/cleaner.go @@ -34,12 +34,14 @@ func (cl *Cleaner) clean() { cl.assemblerMutex.Unlock() cl.streamsMap.Range(func(k, v interface{}) bool { - reqResMatcher := v.(api.TcpStream).GetReqResMatcher() - if reqResMatcher == nil { - return true + reqResMatchers := v.(api.TcpStream).GetReqResMatchers() + for _, reqResMatcher := range reqResMatchers { + if reqResMatcher == nil { + continue + } + deleted := deleteOlderThan(reqResMatcher.GetMap(), startCleanTime.Add(-cl.connectionTimeout)) + cl.stats.deleted += deleted } - deleted := deleteOlderThan(reqResMatcher.GetMap(), startCleanTime.Add(-cl.connectionTimeout)) - cl.stats.deleted += deleted return true }) diff --git a/tap/extensions/amqp/tcp_stream_mock_test.go b/tap/extensions/amqp/tcp_stream_mock_test.go index 23140f169..ae68e5982 100644 --- a/tap/extensions/amqp/tcp_stream_mock_test.go +++ b/tap/extensions/amqp/tcp_stream_mock_test.go @@ -11,7 +11,7 @@ type tcpStream struct { protoIdentifier *api.ProtoIdentifier isTapTarget bool origin api.Capture - reqResMatcher api.RequestResponseMatcher + reqResMatchers []api.RequestResponseMatcher sync.Mutex } @@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { return t.protoIdentifier } -func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reqResMatcher +func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { + return t.reqResMatchers } func (t *tcpStream) GetIsTapTarget() bool { diff --git a/tap/extensions/http/tcp_stream_mock_test.go b/tap/extensions/http/tcp_stream_mock_test.go index 2a06c12bf..ca1b5ee8a 100644 --- a/tap/extensions/http/tcp_stream_mock_test.go +++ b/tap/extensions/http/tcp_stream_mock_test.go @@ -11,7 +11,7 @@ type tcpStream struct { protoIdentifier *api.ProtoIdentifier isTapTarget bool origin api.Capture - reqResMatcher api.RequestResponseMatcher + reqResMatchers []api.RequestResponseMatcher sync.Mutex } @@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { return t.protoIdentifier } -func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reqResMatcher +func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { + return t.reqResMatchers } func (t *tcpStream) GetIsTapTarget() bool { diff --git a/tap/extensions/kafka/tcp_stream_mock_test.go b/tap/extensions/kafka/tcp_stream_mock_test.go index 69bc224a9..d53006e88 100644 --- a/tap/extensions/kafka/tcp_stream_mock_test.go +++ b/tap/extensions/kafka/tcp_stream_mock_test.go @@ -11,7 +11,7 @@ type tcpStream struct { protoIdentifier *api.ProtoIdentifier isTapTarget bool origin api.Capture - reqResMatcher api.RequestResponseMatcher + reqResMatchers []api.RequestResponseMatcher sync.Mutex } @@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { return t.protoIdentifier } -func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reqResMatcher +func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { + return t.reqResMatchers } func (t *tcpStream) GetIsTapTarget() bool { diff --git a/tap/extensions/redis/tcp_stream_mock_test.go b/tap/extensions/redis/tcp_stream_mock_test.go index 038f06738..304c85da0 100644 --- a/tap/extensions/redis/tcp_stream_mock_test.go +++ b/tap/extensions/redis/tcp_stream_mock_test.go @@ -11,7 +11,7 @@ type tcpStream struct { protoIdentifier *api.ProtoIdentifier isTapTarget bool origin api.Capture - reqResMatcher api.RequestResponseMatcher + reqResMatchers []api.RequestResponseMatcher sync.Mutex } @@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { return t.protoIdentifier } -func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reqResMatcher +func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { + return t.reqResMatchers } func (t *tcpStream) GetIsTapTarget() bool { diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index e310a931e..43acb3b84 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -69,10 +69,12 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, extensions = extensionsRef filteringOptions = options + streamsMap := NewTcpStreamMap() + if *tls { for _, e := range extensions { if e.Protocol.Name == "http" { - tlsTapperInstance = startTlsTapper(e, outputItems, options) + tlsTapperInstance = startTlsTapper(e, outputItems, options, streamsMap) break } } @@ -82,7 +84,7 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, diagnose.StartMemoryProfiler(os.Getenv(MemoryProfilingDumpPath), os.Getenv(MemoryProfilingTimeIntervalSeconds)) } - streamsMap, assembler := initializePassiveTapper(opts, outputItems) + assembler := initializePassiveTapper(opts, outputItems, streamsMap) go startPassiveTapper(streamsMap, assembler) } @@ -181,9 +183,7 @@ func initializePacketSources() error { return err } -func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (api.TcpStreamMap, *tcpAssembler) { - streamsMap := NewTcpStreamMap() - +func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap) *tcpAssembler { diagnose.InitializeErrorsMap(*debug, *verbose, *quiet) diagnose.InitializeTapperInternalStats() @@ -195,7 +195,7 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI assembler := NewTcpAssembler(outputItems, streamsMap, opts) - return streamsMap, assembler + return assembler } func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) { @@ -232,7 +232,8 @@ func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) { logger.Log.Infof("AppStats: %v", diagnose.AppStats) } -func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper { +func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, + options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) *tlstapper.TlsTapper { tls := tlstapper.TlsTapper{} chunksBufferSize := os.Getpagesize() * 100 logBufferSize := os.Getpagesize() @@ -262,7 +263,7 @@ func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChanne } go tls.PollForLogging() - go tls.Poll(emitter, options) + go tls.Poll(emitter, options, streamsMap) return &tls } diff --git a/tap/tcp_reassembly_stream.go b/tap/tcp_reassembly_stream.go index 8cc1f035e..36edefc41 100644 --- a/tap/tcp_reassembly_stream.go +++ b/tap/tcp_reassembly_stream.go @@ -148,12 +148,12 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas stream := t.tcpStream.(*tcpStream) if dir == reassembly.TCPDirClientToServer { for i := range stream.getClients() { - reader := stream.getClient(i).(*tcpReader) + reader := stream.getClient(i) reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } } else { for i := range stream.getServers() { - reader := stream.getServer(i).(*tcpReader) + reader := stream.getServer(i) reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } } diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index d67d639fd..753d00b20 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -17,10 +17,10 @@ type tcpStream struct { isClosed bool protoIdentifier *api.ProtoIdentifier isTapTarget bool - clients []api.TcpReader - servers []api.TcpReader + clients []*tcpReader + servers []*tcpReader origin api.Capture - reqResMatcher api.RequestResponseMatcher + reqResMatchers []api.RequestResponseMatcher createdAt time.Time streamsMap api.TcpStreamMap sync.Mutex @@ -57,38 +57,42 @@ func (t *tcpStream) close() { for i := range t.clients { reader := t.clients[i] - reader.(*tcpReader).close() + reader.close() } for i := range t.servers { reader := t.servers[i] - reader.(*tcpReader).close() + reader.close() } } -func (t *tcpStream) addClient(reader api.TcpReader) { +func (t *tcpStream) addClient(reader *tcpReader) { t.clients = append(t.clients, reader) } -func (t *tcpStream) addServer(reader api.TcpReader) { +func (t *tcpStream) addServer(reader *tcpReader) { t.servers = append(t.servers, reader) } -func (t *tcpStream) getClients() []api.TcpReader { +func (t *tcpStream) getClients() []*tcpReader { return t.clients } -func (t *tcpStream) getServers() []api.TcpReader { +func (t *tcpStream) getServers() []*tcpReader { return t.servers } -func (t *tcpStream) getClient(index int) api.TcpReader { +func (t *tcpStream) getClient(index int) *tcpReader { return t.clients[index] } -func (t *tcpStream) getServer(index int) api.TcpReader { +func (t *tcpStream) getServer(index int) *tcpReader { return t.servers[index] } +func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) { + t.reqResMatchers = append(t.reqResMatchers, reqResMatcher) +} + func (t *tcpStream) SetProtocol(protocol *api.Protocol) { t.Lock() defer t.Unlock() @@ -102,13 +106,13 @@ func (t *tcpStream) SetProtocol(protocol *api.Protocol) { for i := range t.clients { reader := t.clients[i] if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { - reader.(*tcpReader).close() + reader.close() } } for i := range t.servers { reader := t.servers[i] if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { - reader.(*tcpReader).close() + reader.close() } } @@ -123,8 +127,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { return t.protoIdentifier } -func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reqResMatcher +func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { + return t.reqResMatchers } func (t *tcpStream) GetIsTapTarget() bool { diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 09d826f15..7803d9ec2 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -64,6 +64,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay stream.setId(factory.streamsMap.NextId()) for i, extension := range extensions { reqResMatcher := extension.Dissector.NewResponseRequestMatcher() + stream.addReqResMatcher(reqResMatcher) counterPair := &api.CounterPair{ Request: 0, Response: 0, @@ -114,8 +115,8 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay factory.streamsMap.Store(stream.getId(), stream) factory.wg.Add(2) - go stream.getClient(i).(*tcpReader).run(filteringOptions, &factory.wg) - go stream.getServer(i).(*tcpReader).run(filteringOptions, &factory.wg) + go stream.getClient(i).run(filteringOptions, &factory.wg) + go stream.getServer(i).run(filteringOptions, &factory.wg) } } return reassemblyStream diff --git a/tap/tcp_streams_map.go b/tap/tcp_streams_map.go index 1f8c8f4e6..523a06f2a 100644 --- a/tap/tcp_streams_map.go +++ b/tap/tcp_streams_map.go @@ -48,7 +48,13 @@ func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() { <-ticker.C streamMap.streams.Range(func(key interface{}, value interface{}) bool { - stream := value.(*tcpStream) + // `*tlsStream` is not yet applicable to this routine. + // So, we cast into `(*tcpStream)` and ignore `*tlsStream` + stream, ok := value.(*tcpStream) + if !ok { + return true + } + if stream.protoIdentifier.Protocol == nil { if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) { stream.close() diff --git a/tap/tlstapper/tls_poller.go b/tap/tlstapper/tls_poller.go index 86e06dad9..4acc8ca0f 100644 --- a/tap/tlstapper/tls_poller.go +++ b/tap/tlstapper/tls_poller.go @@ -59,7 +59,7 @@ func (p *tlsPoller) close() error { return p.chunksReader.Close() } -func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptions) { +func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) { chunks := make(chan *tlsChunk) go p.pollChunksPerfBuffer(chunks) @@ -71,7 +71,7 @@ func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptio return } - if err := p.handleTlsChunk(chunk, p.extension, emitter, options); err != nil { + if err := p.handleTlsChunk(chunk, p.extension, emitter, options, streamsMap); err != nil { LogError(err) } case key := <-p.closedReaders: @@ -115,8 +115,8 @@ func (p *tlsPoller) pollChunksPerfBuffer(chunks chan<- *tlsChunk) { } } -func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, - emitter api.Emitter, options *api.TrafficFilteringOptions) error { +func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, emitter api.Emitter, + options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) error { ip, port, err := chunk.getAddress() if err != nil { @@ -127,7 +127,7 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, reader, exists := p.readers[key] if !exists { - reader = p.startNewTlsReader(chunk, ip, port, key, emitter, extension, options) + reader = p.startNewTlsReader(chunk, ip, port, key, emitter, extension, options, streamsMap) p.readers[key] = reader } @@ -142,7 +142,8 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, } func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, - emitter api.Emitter, extension *api.Extension, options *api.TrafficFilteringOptions) *tlsReader { + emitter api.Emitter, extension *api.Extension, options *api.TrafficFilteringOptions, + streamsMap api.TcpStreamMap) *tlsReader { tcpid := p.buildTcpId(chunk, ip, port) @@ -173,6 +174,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, k reader: reader, protoIdentifier: &api.ProtoIdentifier{}, } + streamsMap.Store(streamsMap.NextId(), stream) reader.parent = stream diff --git a/tap/tlstapper/tls_reader.go b/tap/tlstapper/tls_reader.go index 45ebf0ce2..fa1c91611 100644 --- a/tap/tlstapper/tls_reader.go +++ b/tap/tlstapper/tls_reader.go @@ -19,7 +19,7 @@ type tlsReader struct { extension *api.Extension emitter api.Emitter counterPair *api.CounterPair - parent api.TcpStream + parent *tlsStream reqResMatcher api.RequestResponseMatcher } diff --git a/tap/tlstapper/tls_stream.go b/tap/tlstapper/tls_stream.go index 99134daa5..09c447f13 100644 --- a/tap/tlstapper/tls_stream.go +++ b/tap/tlstapper/tls_stream.go @@ -19,8 +19,8 @@ func (t *tlsStream) SetProtocol(protocol *api.Protocol) { t.protoIdentifier.Protocol = protocol } -func (t *tlsStream) GetReqResMatcher() api.RequestResponseMatcher { - return t.reader.reqResMatcher +func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher { + return []api.RequestResponseMatcher{t.reader.reqResMatcher} } func (t *tlsStream) GetIsTapTarget() bool { diff --git a/tap/tlstapper/tls_tapper.go b/tap/tlstapper/tls_tapper.go index c7fbc8110..63bb340d7 100644 --- a/tap/tlstapper/tls_tapper.go +++ b/tap/tlstapper/tls_tapper.go @@ -50,8 +50,8 @@ func (t *TlsTapper) Init(chunksBufferSize int, logBufferSize int, procfs string, return t.poller.init(&t.bpfObjects, chunksBufferSize) } -func (t *TlsTapper) Poll(emitter api.Emitter, options *api.TrafficFilteringOptions) { - t.poller.poll(emitter, options) +func (t *TlsTapper) Poll(emitter api.Emitter, options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) { + t.poller.poll(emitter, options, streamsMap) } func (t *TlsTapper) PollForLogging() {