mirror of
https://github.com/kubeshark/kubeshark.git
synced 2025-06-26 16:24:54 +00:00
Fix the request-response matcher maps iteration in clean()
method and share the streams map with the TLS tapper (#1059)
* Fix `panic: interface conversion: api.RequestResponseMatcher is nil, not *http.requestResponseMatcher` error Also fix the request-response matcher maps iteration in `clean()` method. * Fix the mocks in the unit tests * Remove unnecessary fields from `tlsPoller` and implement `SetProtocol` method * Use concrete types in `tap` package * Share the streams map with the TLS tapper * Check interface conversion error
This commit is contained in:
parent
0881dad17f
commit
1de50b0572
@ -426,7 +426,7 @@ type TcpStream interface {
|
||||
SetProtocol(protocol *Protocol)
|
||||
GetOrigin() Capture
|
||||
GetProtoIdentifier() *ProtoIdentifier
|
||||
GetReqResMatcher() RequestResponseMatcher
|
||||
GetReqResMatchers() []RequestResponseMatcher
|
||||
GetIsTapTarget() bool
|
||||
GetIsClosed() bool
|
||||
}
|
||||
|
@ -34,12 +34,14 @@ func (cl *Cleaner) clean() {
|
||||
cl.assemblerMutex.Unlock()
|
||||
|
||||
cl.streamsMap.Range(func(k, v interface{}) bool {
|
||||
reqResMatcher := v.(api.TcpStream).GetReqResMatcher()
|
||||
reqResMatchers := v.(api.TcpStream).GetReqResMatchers()
|
||||
for _, reqResMatcher := range reqResMatchers {
|
||||
if reqResMatcher == nil {
|
||||
return true
|
||||
continue
|
||||
}
|
||||
deleted := deleteOlderThan(reqResMatcher.GetMap(), startCleanTime.Add(-cl.connectionTimeout))
|
||||
cl.stats.deleted += deleted
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
|
@ -11,7 +11,7 @@ type tcpStream struct {
|
||||
protoIdentifier *api.ProtoIdentifier
|
||||
isTapTarget bool
|
||||
origin api.Capture
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
reqResMatchers []api.RequestResponseMatcher
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
|
||||
return t.protoIdentifier
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reqResMatcher
|
||||
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return t.reqResMatchers
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetIsTapTarget() bool {
|
||||
|
@ -11,7 +11,7 @@ type tcpStream struct {
|
||||
protoIdentifier *api.ProtoIdentifier
|
||||
isTapTarget bool
|
||||
origin api.Capture
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
reqResMatchers []api.RequestResponseMatcher
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
|
||||
return t.protoIdentifier
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reqResMatcher
|
||||
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return t.reqResMatchers
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetIsTapTarget() bool {
|
||||
|
@ -11,7 +11,7 @@ type tcpStream struct {
|
||||
protoIdentifier *api.ProtoIdentifier
|
||||
isTapTarget bool
|
||||
origin api.Capture
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
reqResMatchers []api.RequestResponseMatcher
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
|
||||
return t.protoIdentifier
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reqResMatcher
|
||||
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return t.reqResMatchers
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetIsTapTarget() bool {
|
||||
|
@ -11,7 +11,7 @@ type tcpStream struct {
|
||||
protoIdentifier *api.ProtoIdentifier
|
||||
isTapTarget bool
|
||||
origin api.Capture
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
reqResMatchers []api.RequestResponseMatcher
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
|
||||
return t.protoIdentifier
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reqResMatcher
|
||||
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return t.reqResMatchers
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetIsTapTarget() bool {
|
||||
|
@ -69,10 +69,12 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem,
|
||||
extensions = extensionsRef
|
||||
filteringOptions = options
|
||||
|
||||
streamsMap := NewTcpStreamMap()
|
||||
|
||||
if *tls {
|
||||
for _, e := range extensions {
|
||||
if e.Protocol.Name == "http" {
|
||||
tlsTapperInstance = startTlsTapper(e, outputItems, options)
|
||||
tlsTapperInstance = startTlsTapper(e, outputItems, options, streamsMap)
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -82,7 +84,7 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem,
|
||||
diagnose.StartMemoryProfiler(os.Getenv(MemoryProfilingDumpPath), os.Getenv(MemoryProfilingTimeIntervalSeconds))
|
||||
}
|
||||
|
||||
streamsMap, assembler := initializePassiveTapper(opts, outputItems)
|
||||
assembler := initializePassiveTapper(opts, outputItems, streamsMap)
|
||||
go startPassiveTapper(streamsMap, assembler)
|
||||
}
|
||||
|
||||
@ -181,9 +183,7 @@ func initializePacketSources() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (api.TcpStreamMap, *tcpAssembler) {
|
||||
streamsMap := NewTcpStreamMap()
|
||||
|
||||
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap) *tcpAssembler {
|
||||
diagnose.InitializeErrorsMap(*debug, *verbose, *quiet)
|
||||
diagnose.InitializeTapperInternalStats()
|
||||
|
||||
@ -195,7 +195,7 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI
|
||||
|
||||
assembler := NewTcpAssembler(outputItems, streamsMap, opts)
|
||||
|
||||
return streamsMap, assembler
|
||||
return assembler
|
||||
}
|
||||
|
||||
func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) {
|
||||
@ -232,7 +232,8 @@ func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) {
|
||||
logger.Log.Infof("AppStats: %v", diagnose.AppStats)
|
||||
}
|
||||
|
||||
func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper {
|
||||
func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem,
|
||||
options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) *tlstapper.TlsTapper {
|
||||
tls := tlstapper.TlsTapper{}
|
||||
chunksBufferSize := os.Getpagesize() * 100
|
||||
logBufferSize := os.Getpagesize()
|
||||
@ -262,7 +263,7 @@ func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChanne
|
||||
}
|
||||
|
||||
go tls.PollForLogging()
|
||||
go tls.Poll(emitter, options)
|
||||
go tls.Poll(emitter, options, streamsMap)
|
||||
|
||||
return &tls
|
||||
}
|
||||
|
@ -148,12 +148,12 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas
|
||||
stream := t.tcpStream.(*tcpStream)
|
||||
if dir == reassembly.TCPDirClientToServer {
|
||||
for i := range stream.getClients() {
|
||||
reader := stream.getClient(i).(*tcpReader)
|
||||
reader := stream.getClient(i)
|
||||
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
|
||||
}
|
||||
} else {
|
||||
for i := range stream.getServers() {
|
||||
reader := stream.getServer(i).(*tcpReader)
|
||||
reader := stream.getServer(i)
|
||||
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
|
||||
}
|
||||
}
|
||||
|
@ -17,10 +17,10 @@ type tcpStream struct {
|
||||
isClosed bool
|
||||
protoIdentifier *api.ProtoIdentifier
|
||||
isTapTarget bool
|
||||
clients []api.TcpReader
|
||||
servers []api.TcpReader
|
||||
clients []*tcpReader
|
||||
servers []*tcpReader
|
||||
origin api.Capture
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
reqResMatchers []api.RequestResponseMatcher
|
||||
createdAt time.Time
|
||||
streamsMap api.TcpStreamMap
|
||||
sync.Mutex
|
||||
@ -57,38 +57,42 @@ func (t *tcpStream) close() {
|
||||
|
||||
for i := range t.clients {
|
||||
reader := t.clients[i]
|
||||
reader.(*tcpReader).close()
|
||||
reader.close()
|
||||
}
|
||||
for i := range t.servers {
|
||||
reader := t.servers[i]
|
||||
reader.(*tcpReader).close()
|
||||
reader.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpStream) addClient(reader api.TcpReader) {
|
||||
func (t *tcpStream) addClient(reader *tcpReader) {
|
||||
t.clients = append(t.clients, reader)
|
||||
}
|
||||
|
||||
func (t *tcpStream) addServer(reader api.TcpReader) {
|
||||
func (t *tcpStream) addServer(reader *tcpReader) {
|
||||
t.servers = append(t.servers, reader)
|
||||
}
|
||||
|
||||
func (t *tcpStream) getClients() []api.TcpReader {
|
||||
func (t *tcpStream) getClients() []*tcpReader {
|
||||
return t.clients
|
||||
}
|
||||
|
||||
func (t *tcpStream) getServers() []api.TcpReader {
|
||||
func (t *tcpStream) getServers() []*tcpReader {
|
||||
return t.servers
|
||||
}
|
||||
|
||||
func (t *tcpStream) getClient(index int) api.TcpReader {
|
||||
func (t *tcpStream) getClient(index int) *tcpReader {
|
||||
return t.clients[index]
|
||||
}
|
||||
|
||||
func (t *tcpStream) getServer(index int) api.TcpReader {
|
||||
func (t *tcpStream) getServer(index int) *tcpReader {
|
||||
return t.servers[index]
|
||||
}
|
||||
|
||||
func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) {
|
||||
t.reqResMatchers = append(t.reqResMatchers, reqResMatcher)
|
||||
}
|
||||
|
||||
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
@ -102,13 +106,13 @@ func (t *tcpStream) SetProtocol(protocol *api.Protocol) {
|
||||
for i := range t.clients {
|
||||
reader := t.clients[i]
|
||||
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
|
||||
reader.(*tcpReader).close()
|
||||
reader.close()
|
||||
}
|
||||
}
|
||||
for i := range t.servers {
|
||||
reader := t.servers[i]
|
||||
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
|
||||
reader.(*tcpReader).close()
|
||||
reader.close()
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,8 +127,8 @@ func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
|
||||
return t.protoIdentifier
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reqResMatcher
|
||||
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return t.reqResMatchers
|
||||
}
|
||||
|
||||
func (t *tcpStream) GetIsTapTarget() bool {
|
||||
|
@ -64,6 +64,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay
|
||||
stream.setId(factory.streamsMap.NextId())
|
||||
for i, extension := range extensions {
|
||||
reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
|
||||
stream.addReqResMatcher(reqResMatcher)
|
||||
counterPair := &api.CounterPair{
|
||||
Request: 0,
|
||||
Response: 0,
|
||||
@ -114,8 +115,8 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay
|
||||
factory.streamsMap.Store(stream.getId(), stream)
|
||||
|
||||
factory.wg.Add(2)
|
||||
go stream.getClient(i).(*tcpReader).run(filteringOptions, &factory.wg)
|
||||
go stream.getServer(i).(*tcpReader).run(filteringOptions, &factory.wg)
|
||||
go stream.getClient(i).run(filteringOptions, &factory.wg)
|
||||
go stream.getServer(i).run(filteringOptions, &factory.wg)
|
||||
}
|
||||
}
|
||||
return reassemblyStream
|
||||
|
@ -48,7 +48,13 @@ func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() {
|
||||
<-ticker.C
|
||||
|
||||
streamMap.streams.Range(func(key interface{}, value interface{}) bool {
|
||||
stream := value.(*tcpStream)
|
||||
// `*tlsStream` is not yet applicable to this routine.
|
||||
// So, we cast into `(*tcpStream)` and ignore `*tlsStream`
|
||||
stream, ok := value.(*tcpStream)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if stream.protoIdentifier.Protocol == nil {
|
||||
if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) {
|
||||
stream.close()
|
||||
|
@ -59,7 +59,7 @@ func (p *tlsPoller) close() error {
|
||||
return p.chunksReader.Close()
|
||||
}
|
||||
|
||||
func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptions) {
|
||||
func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) {
|
||||
chunks := make(chan *tlsChunk)
|
||||
|
||||
go p.pollChunksPerfBuffer(chunks)
|
||||
@ -71,7 +71,7 @@ func (p *tlsPoller) poll(emitter api.Emitter, options *api.TrafficFilteringOptio
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.handleTlsChunk(chunk, p.extension, emitter, options); err != nil {
|
||||
if err := p.handleTlsChunk(chunk, p.extension, emitter, options, streamsMap); err != nil {
|
||||
LogError(err)
|
||||
}
|
||||
case key := <-p.closedReaders:
|
||||
@ -115,8 +115,8 @@ func (p *tlsPoller) pollChunksPerfBuffer(chunks chan<- *tlsChunk) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
|
||||
emitter api.Emitter, options *api.TrafficFilteringOptions) error {
|
||||
func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, emitter api.Emitter,
|
||||
options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) error {
|
||||
ip, port, err := chunk.getAddress()
|
||||
|
||||
if err != nil {
|
||||
@ -127,7 +127,7 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
|
||||
reader, exists := p.readers[key]
|
||||
|
||||
if !exists {
|
||||
reader = p.startNewTlsReader(chunk, ip, port, key, emitter, extension, options)
|
||||
reader = p.startNewTlsReader(chunk, ip, port, key, emitter, extension, options, streamsMap)
|
||||
p.readers[key] = reader
|
||||
}
|
||||
|
||||
@ -142,7 +142,8 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
|
||||
}
|
||||
|
||||
func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string,
|
||||
emitter api.Emitter, extension *api.Extension, options *api.TrafficFilteringOptions) *tlsReader {
|
||||
emitter api.Emitter, extension *api.Extension, options *api.TrafficFilteringOptions,
|
||||
streamsMap api.TcpStreamMap) *tlsReader {
|
||||
|
||||
tcpid := p.buildTcpId(chunk, ip, port)
|
||||
|
||||
@ -173,6 +174,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, k
|
||||
reader: reader,
|
||||
protoIdentifier: &api.ProtoIdentifier{},
|
||||
}
|
||||
streamsMap.Store(streamsMap.NextId(), stream)
|
||||
|
||||
reader.parent = stream
|
||||
|
||||
|
@ -19,7 +19,7 @@ type tlsReader struct {
|
||||
extension *api.Extension
|
||||
emitter api.Emitter
|
||||
counterPair *api.CounterPair
|
||||
parent api.TcpStream
|
||||
parent *tlsStream
|
||||
reqResMatcher api.RequestResponseMatcher
|
||||
}
|
||||
|
||||
|
@ -19,8 +19,8 @@ func (t *tlsStream) SetProtocol(protocol *api.Protocol) {
|
||||
t.protoIdentifier.Protocol = protocol
|
||||
}
|
||||
|
||||
func (t *tlsStream) GetReqResMatcher() api.RequestResponseMatcher {
|
||||
return t.reader.reqResMatcher
|
||||
func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher {
|
||||
return []api.RequestResponseMatcher{t.reader.reqResMatcher}
|
||||
}
|
||||
|
||||
func (t *tlsStream) GetIsTapTarget() bool {
|
||||
|
@ -50,8 +50,8 @@ func (t *TlsTapper) Init(chunksBufferSize int, logBufferSize int, procfs string,
|
||||
return t.poller.init(&t.bpfObjects, chunksBufferSize)
|
||||
}
|
||||
|
||||
func (t *TlsTapper) Poll(emitter api.Emitter, options *api.TrafficFilteringOptions) {
|
||||
t.poller.poll(emitter, options)
|
||||
func (t *TlsTapper) Poll(emitter api.Emitter, options *api.TrafficFilteringOptions, streamsMap api.TcpStreamMap) {
|
||||
t.poller.poll(emitter, options, streamsMap)
|
||||
}
|
||||
|
||||
func (t *TlsTapper) PollForLogging() {
|
||||
|
Loading…
Reference in New Issue
Block a user