diff --git a/tap/api/api.go b/tap/api/api.go index 74929210b..586ff592b 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -104,11 +104,6 @@ type OutputChannelItem struct { Namespace string } -type ProtoIdentifier struct { - Protocol *Protocol - IsClosedOthers bool -} - type ReadProgress struct { readBytes int lastCurrent int @@ -123,6 +118,11 @@ func (p *ReadProgress) Current() (n int) { return p.lastCurrent } +func (p *ReadProgress) Reset() { + p.readBytes = 0 + p.lastCurrent = 0 +} + type Dissector interface { Register(*Extension) Ping() @@ -419,13 +419,12 @@ type TcpReader interface { GetCaptureTime() time.Time GetEmitter() Emitter GetIsClosed() bool - GetExtension() *Extension } type TcpStream interface { SetProtocol(protocol *Protocol) GetOrigin() Capture - GetProtoIdentifier() *ProtoIdentifier + GetProtocol() *Protocol GetReqResMatchers() []RequestResponseMatcher GetIsTapTarget() bool GetIsClosed() bool diff --git a/tap/extensions/amqp/main.go b/tap/extensions/amqp/main.go index a1c0dd9aa..6d024287d 100644 --- a/tap/extensions/amqp/main.go +++ b/tap/extensions/amqp/main.go @@ -75,14 +75,14 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api. var lastMethodFrameMessage Message for { - if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &protocol { + if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &protocol { return errors.New("Identified by another protocol") } frame, err := r.ReadFrame() if err == io.EOF { // We must read until we see an EOF... very important! - return nil + return err } switch f := frame.(type) { diff --git a/tap/extensions/amqp/tcp_reader_mock_test.go b/tap/extensions/amqp/tcp_reader_mock_test.go index dd37fc7a4..3081e449e 100644 --- a/tap/extensions/amqp/tcp_reader_mock_test.go +++ b/tap/extensions/amqp/tcp_reader_mock_test.go @@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter { func (reader *tcpReader) GetIsClosed() bool { return reader.isClosed } - -func (reader *tcpReader) GetExtension() *api.Extension { - return reader.extension -} diff --git a/tap/extensions/amqp/tcp_stream_mock_test.go b/tap/extensions/amqp/tcp_stream_mock_test.go index ae68e5982..29138a2e1 100644 --- a/tap/extensions/amqp/tcp_stream_mock_test.go +++ b/tap/extensions/amqp/tcp_stream_mock_test.go @@ -7,18 +7,17 @@ import ( ) type tcpStream struct { - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - origin api.Capture - reqResMatchers []api.RequestResponseMatcher + isClosed bool + protocol *api.Protocol + isTapTarget bool + origin api.Capture + reqResMatchers []api.RequestResponseMatcher sync.Mutex } func NewTcpStream(capture api.Capture) api.TcpStream { return &tcpStream{ - origin: capture, - protoIdentifier: &api.ProtoIdentifier{}, + origin: capture, } } @@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture { return t.origin } -func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tcpStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { diff --git a/tap/extensions/http/main.go b/tap/extensions/http/main.go index cf0e1fac6..19f26c122 100644 --- a/tap/extensions/http/main.go +++ b/tap/extensions/http/main.go @@ -144,7 +144,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api. http2Assembler = createHTTP2Assembler(b) } - if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &http11protocol { + if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &http11protocol { return errors.New("Identified by another protocol") } @@ -200,7 +200,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api. } } - if reader.GetParent().GetProtoIdentifier().Protocol == nil { + if reader.GetParent().GetProtocol() == nil { return err } diff --git a/tap/extensions/http/tcp_reader_mock_test.go b/tap/extensions/http/tcp_reader_mock_test.go index bad15a0fd..87baf4293 100644 --- a/tap/extensions/http/tcp_reader_mock_test.go +++ b/tap/extensions/http/tcp_reader_mock_test.go @@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter { func (reader *tcpReader) GetIsClosed() bool { return reader.isClosed } - -func (reader *tcpReader) GetExtension() *api.Extension { - return reader.extension -} diff --git a/tap/extensions/http/tcp_stream_mock_test.go b/tap/extensions/http/tcp_stream_mock_test.go index ca1b5ee8a..9d3342364 100644 --- a/tap/extensions/http/tcp_stream_mock_test.go +++ b/tap/extensions/http/tcp_stream_mock_test.go @@ -7,18 +7,17 @@ import ( ) type tcpStream struct { - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - origin api.Capture - reqResMatchers []api.RequestResponseMatcher + isClosed bool + protocol *api.Protocol + isTapTarget bool + origin api.Capture + reqResMatchers []api.RequestResponseMatcher sync.Mutex } func NewTcpStream(capture api.Capture) api.TcpStream { return &tcpStream{ - origin: capture, - protoIdentifier: &api.ProtoIdentifier{}, + origin: capture, } } @@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture { return t.origin } -func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tcpStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { diff --git a/tap/extensions/kafka/main.go b/tap/extensions/kafka/main.go index 0996ad5d1..750ca65ed 100644 --- a/tap/extensions/kafka/main.go +++ b/tap/extensions/kafka/main.go @@ -38,7 +38,7 @@ func (d dissecting) Ping() { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error { reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher) for { - if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &_protocol { + if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &_protocol { return errors.New("Identified by another protocol") } diff --git a/tap/extensions/kafka/tcp_reader_mock_test.go b/tap/extensions/kafka/tcp_reader_mock_test.go index 0e9355b92..9bcc5619b 100644 --- a/tap/extensions/kafka/tcp_reader_mock_test.go +++ b/tap/extensions/kafka/tcp_reader_mock_test.go @@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter { func (reader *tcpReader) GetIsClosed() bool { return reader.isClosed } - -func (reader *tcpReader) GetExtension() *api.Extension { - return reader.extension -} diff --git a/tap/extensions/kafka/tcp_stream_mock_test.go b/tap/extensions/kafka/tcp_stream_mock_test.go index d53006e88..9a99d42b6 100644 --- a/tap/extensions/kafka/tcp_stream_mock_test.go +++ b/tap/extensions/kafka/tcp_stream_mock_test.go @@ -7,18 +7,17 @@ import ( ) type tcpStream struct { - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - origin api.Capture - reqResMatchers []api.RequestResponseMatcher + isClosed bool + protocol *api.Protocol + isTapTarget bool + origin api.Capture + reqResMatchers []api.RequestResponseMatcher sync.Mutex } func NewTcpStream(capture api.Capture) api.TcpStream { return &tcpStream{ - origin: capture, - protoIdentifier: &api.ProtoIdentifier{}, + origin: capture, } } @@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture { return t.origin } -func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tcpStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { diff --git a/tap/extensions/redis/tcp_reader_mock_test.go b/tap/extensions/redis/tcp_reader_mock_test.go index 223d7fd9a..6b7f3618e 100644 --- a/tap/extensions/redis/tcp_reader_mock_test.go +++ b/tap/extensions/redis/tcp_reader_mock_test.go @@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter { func (reader *tcpReader) GetIsClosed() bool { return reader.isClosed } - -func (reader *tcpReader) GetExtension() *api.Extension { - return reader.extension -} diff --git a/tap/extensions/redis/tcp_stream_mock_test.go b/tap/extensions/redis/tcp_stream_mock_test.go index 304c85da0..450fe7575 100644 --- a/tap/extensions/redis/tcp_stream_mock_test.go +++ b/tap/extensions/redis/tcp_stream_mock_test.go @@ -7,18 +7,17 @@ import ( ) type tcpStream struct { - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - origin api.Capture - reqResMatchers []api.RequestResponseMatcher + isClosed bool + protocol *api.Protocol + isTapTarget bool + origin api.Capture + reqResMatchers []api.RequestResponseMatcher sync.Mutex } func NewTcpStream(capture api.Capture) api.TcpStream { return &tcpStream{ - origin: capture, - protoIdentifier: &api.ProtoIdentifier{}, + origin: capture, } } @@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture { return t.origin } -func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tcpStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { diff --git a/tap/tcp_reader.go b/tap/tcp_reader.go index 5af4b828d..e15ae87db 100644 --- a/tap/tcp_reader.go +++ b/tap/tcp_reader.go @@ -3,11 +3,9 @@ package tap import ( "bufio" "io" - "io/ioutil" "sync" "time" - "github.com/up9inc/mizu/logger" "github.com/up9inc/mizu/tap/api" ) @@ -17,50 +15,48 @@ import ( * Implements io.Reader interface (Read) */ type tcpReader struct { - ident string - tcpID *api.TcpID - isClosed bool - isClient bool - isOutgoing bool - msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload - data []byte - progress *api.ReadProgress - captureTime time.Time - parent *tcpStream - packetsSeen uint - extension *api.Extension - emitter api.Emitter - counterPair *api.CounterPair - reqResMatcher api.RequestResponseMatcher + ident string + tcpID *api.TcpID + isClosed bool + isClient bool + isOutgoing bool + msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload + msgBuffer []api.TcpReaderDataMsg + msgBufferMaster []api.TcpReaderDataMsg + data []byte + progress *api.ReadProgress + captureTime time.Time + parent *tcpStream + emitter api.Emitter + counterPair *api.CounterPair + reqResMatcher api.RequestResponseMatcher sync.Mutex } -func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent *tcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) *tcpReader { +func NewTcpReader(ident string, tcpId *api.TcpID, parent *tcpStream, isClient bool, isOutgoing bool, emitter api.Emitter) *tcpReader { return &tcpReader{ - msgQueue: msgQueue, - progress: progress, - ident: ident, - tcpID: tcpId, - captureTime: captureTime, - parent: parent, - isClient: isClient, - isOutgoing: isOutgoing, - extension: extension, - emitter: emitter, - counterPair: counterPair, - reqResMatcher: reqResMatcher, + msgQueue: make(chan api.TcpReaderDataMsg), + progress: &api.ReadProgress{}, + ident: ident, + tcpID: tcpId, + parent: parent, + isClient: isClient, + isOutgoing: isOutgoing, + emitter: emitter, } } func (reader *tcpReader) run(options *api.TrafficFilteringOptions, wg *sync.WaitGroup) { defer wg.Done() - b := bufio.NewReader(reader) - err := reader.extension.Dissector.Dissect(b, reader, options) - if err != nil { - _, err = io.Copy(ioutil.Discard, reader) - if err != nil { - logger.Log.Errorf("%v", err) + for i, extension := range extensions { + reader.reqResMatcher = reader.parent.reqResMatchers[i] + reader.counterPair = reader.parent.counterPairs[i] + b := bufio.NewReader(reader) + extension.Dissector.Dissect(b, reader, options) //nolint + if reader.isProtocolIdentified() { + break } + reader.rewind() } } @@ -81,21 +77,56 @@ func (reader *tcpReader) sendMsgIfNotClosed(msg api.TcpReaderDataMsg) { reader.Unlock() } +func (reader *tcpReader) isProtocolIdentified() bool { + return reader.parent.protocol != nil +} + +func (reader *tcpReader) rewind() { + // Reset the data and msgBuffer from the master record + reader.data = make([]byte, 0) + reader.msgBuffer = make([]api.TcpReaderDataMsg, len(reader.msgBufferMaster)) + copy(reader.msgBuffer, reader.msgBufferMaster) + + // Reset the read progress + reader.progress.Reset() +} + +func (reader *tcpReader) populateData(msg api.TcpReaderDataMsg) { + reader.data = msg.GetBytes() + reader.captureTime = msg.GetTimestamp() +} + func (reader *tcpReader) Read(p []byte) (int, error) { var msg api.TcpReaderDataMsg + for len(reader.msgBuffer) > 0 && len(reader.data) == 0 { + // Pop first message + if len(reader.msgBuffer) > 1 { + msg, reader.msgBuffer = reader.msgBuffer[0], reader.msgBuffer[1:] + } else { + msg = reader.msgBuffer[0] + reader.msgBuffer = make([]api.TcpReaderDataMsg, 0) + } + + // Get the bytes + reader.populateData(msg) + } + ok := true for ok && len(reader.data) == 0 { msg, ok = <-reader.msgQueue if msg != nil { - reader.data = msg.GetBytes() - reader.captureTime = msg.GetTimestamp() - } + reader.populateData(msg) - if len(reader.data) > 0 { - reader.packetsSeen += 1 + if !reader.isProtocolIdentified() { + reader.msgBufferMaster = append( + reader.msgBufferMaster, + msg, + ) + } } } + if !ok || len(reader.data) == 0 { return 0, io.EOF } @@ -142,7 +173,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter { func (reader *tcpReader) GetIsClosed() bool { return reader.isClosed } - -func (reader *tcpReader) GetExtension() *api.Extension { - return reader.extension -} diff --git a/tap/tcp_reassembly_stream.go b/tap/tcp_reassembly_stream.go index 16ed07afe..c5a9e864b 100644 --- a/tap/tcp_reassembly_stream.go +++ b/tap/tcp_reassembly_stream.go @@ -6,7 +6,6 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/reassembly" - "github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/diagnose" ) @@ -16,10 +15,10 @@ type tcpReassemblyStream struct { fsmerr bool optchecker reassembly.TCPOptionCheck isDNS bool - tcpStream api.TcpStream + tcpStream *tcpStream } -func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream api.TcpStream) reassembly.Stream { +func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream *tcpStream) reassembly.Stream { return &tcpReassemblyStream{ ident: ident, tcpState: reassembly.NewTCPSimpleFSM(fsmOptions), @@ -139,17 +138,10 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas // This channel is read by an tcpReader object diagnose.AppStats.IncReassembledTcpPayloadsCount() timestamp := ac.GetCaptureInfo().Timestamp - stream := t.tcpStream.(*tcpStream) if dir == reassembly.TCPDirClientToServer { - for i := range stream.getClients() { - reader := stream.getClient(i) - reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) - } + t.tcpStream.client.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } else { - for i := range stream.getServers() { - reader := stream.getServer(i) - reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) - } + t.tcpStream.server.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp)) } } } @@ -157,7 +149,7 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() { - t.tcpStream.(*tcpStream).close() + t.tcpStream.close() } // do not remove the connection to allow last ACK return false diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index 753d00b20..ca99da6c6 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -13,25 +13,26 @@ import ( * In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel. */ type tcpStream struct { - id int64 - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - clients []*tcpReader - servers []*tcpReader - origin api.Capture - reqResMatchers []api.RequestResponseMatcher - createdAt time.Time - streamsMap api.TcpStreamMap + id int64 + isClosed bool + protocol *api.Protocol + isTapTarget bool + client *tcpReader + server *tcpReader + origin api.Capture + counterPairs []*api.CounterPair + reqResMatchers []api.RequestResponseMatcher + createdAt time.Time + streamsMap api.TcpStreamMap sync.Mutex } func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream { return &tcpStream{ - isTapTarget: isTapTarget, - protoIdentifier: &api.ProtoIdentifier{}, - streamsMap: streamsMap, - origin: capture, + isTapTarget: isTapTarget, + streamsMap: streamsMap, + origin: capture, + createdAt: time.Now(), } } @@ -55,38 +56,12 @@ func (t *tcpStream) close() { t.streamsMap.Delete(t.id) - for i := range t.clients { - reader := t.clients[i] - reader.close() - } - for i := range t.servers { - reader := t.servers[i] - reader.close() - } + t.client.close() + t.server.close() } -func (t *tcpStream) addClient(reader *tcpReader) { - t.clients = append(t.clients, reader) -} - -func (t *tcpStream) addServer(reader *tcpReader) { - t.servers = append(t.servers, reader) -} - -func (t *tcpStream) getClients() []*tcpReader { - return t.clients -} - -func (t *tcpStream) getServers() []*tcpReader { - return t.servers -} - -func (t *tcpStream) getClient(index int) *tcpReader { - return t.clients[index] -} - -func (t *tcpStream) getServer(index int) *tcpReader { - return t.servers[index] +func (t *tcpStream) addCounterPair(counterPair *api.CounterPair) { + t.counterPairs = append(t.counterPairs, counterPair) } func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) { @@ -94,37 +69,19 @@ func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) { } func (t *tcpStream) SetProtocol(protocol *api.Protocol) { - t.Lock() - defer t.Unlock() + t.protocol = protocol - if t.protoIdentifier.IsClosedOthers { - return - } - - t.protoIdentifier.Protocol = protocol - - for i := range t.clients { - reader := t.clients[i] - if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { - reader.close() - } - } - for i := range t.servers { - reader := t.servers[i] - if reader.GetExtension().Protocol != t.protoIdentifier.Protocol { - reader.close() - } - } - - t.protoIdentifier.IsClosedOthers = true + // Clean the buffers + t.client.msgBufferMaster = make([]api.TcpReaderDataMsg, 0) + t.server.msgBufferMaster = make([]api.TcpReaderDataMsg, 0) } func (t *tcpStream) GetOrigin() api.Capture { return t.origin } -func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tcpStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 7803d9ec2..a6ced5712 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -3,7 +3,6 @@ package tap import ( "fmt" "sync" - "time" "github.com/up9inc/mizu/logger" "github.com/up9inc/mizu/tap/api" @@ -62,62 +61,50 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream) if stream.GetIsTapTarget() { stream.setId(factory.streamsMap.NextId()) - for i, extension := range extensions { - reqResMatcher := extension.Dissector.NewResponseRequestMatcher() - stream.addReqResMatcher(reqResMatcher) + for _, extension := range extensions { counterPair := &api.CounterPair{ Request: 0, Response: 0, } - stream.addClient( - NewTcpReader( - make(chan api.TcpReaderDataMsg), - &api.ReadProgress{}, - fmt.Sprintf("%s %s", net, transport), - &api.TcpID{ - SrcIP: srcIp, - DstIP: dstIp, - SrcPort: srcPort, - DstPort: dstPort, - }, - time.Time{}, - stream, - true, - props.isOutgoing, - extension, - factory.emitter, - counterPair, - reqResMatcher, - ), - ) - stream.addServer( - NewTcpReader( - make(chan api.TcpReaderDataMsg), - &api.ReadProgress{}, - fmt.Sprintf("%s %s", net, transport), - &api.TcpID{ - SrcIP: net.Dst().String(), - DstIP: net.Src().String(), - SrcPort: transport.Dst().String(), - DstPort: transport.Src().String(), - }, - time.Time{}, - stream, - false, - props.isOutgoing, - extension, - factory.emitter, - counterPair, - reqResMatcher, - ), - ) + stream.addCounterPair(counterPair) - factory.streamsMap.Store(stream.getId(), stream) - - factory.wg.Add(2) - go stream.getClient(i).run(filteringOptions, &factory.wg) - go stream.getServer(i).run(filteringOptions, &factory.wg) + reqResMatcher := extension.Dissector.NewResponseRequestMatcher() + stream.addReqResMatcher(reqResMatcher) } + + stream.client = NewTcpReader( + fmt.Sprintf("%s %s", net, transport), + &api.TcpID{ + SrcIP: srcIp, + DstIP: dstIp, + SrcPort: srcPort, + DstPort: dstPort, + }, + stream, + true, + props.isOutgoing, + factory.emitter, + ) + + stream.server = NewTcpReader( + fmt.Sprintf("%s %s", net, transport), + &api.TcpID{ + SrcIP: net.Dst().String(), + DstIP: net.Src().String(), + SrcPort: transport.Dst().String(), + DstPort: transport.Src().String(), + }, + stream, + false, + props.isOutgoing, + factory.emitter, + ) + + factory.streamsMap.Store(stream.getId(), stream) + + factory.wg.Add(2) + go stream.client.run(filteringOptions, &factory.wg) + go stream.server.run(filteringOptions, &factory.wg) } return reassemblyStream } diff --git a/tap/tcp_streams_map.go b/tap/tcp_streams_map.go index 9133b46c3..fc96aa5c7 100644 --- a/tap/tcp_streams_map.go +++ b/tap/tcp_streams_map.go @@ -57,7 +57,7 @@ func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() { return true } - if stream.protoIdentifier.Protocol == nil { + if stream.protocol == nil { if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) { stream.close() diagnose.AppStats.IncDroppedTcpStreams() diff --git a/tap/tlstapper/tls_poller.go b/tap/tlstapper/tls_poller.go index faba94529..23643fab1 100644 --- a/tap/tlstapper/tls_poller.go +++ b/tap/tlstapper/tls_poller.go @@ -188,8 +188,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, address *addressPair, key } stream := &tlsStream{ - reader: reader, - protoIdentifier: &api.ProtoIdentifier{}, + reader: reader, } streamsMap.Store(streamsMap.NextId(), stream) diff --git a/tap/tlstapper/tls_reader.go b/tap/tlstapper/tls_reader.go index 7856b057f..b76a0008d 100644 --- a/tap/tlstapper/tls_reader.go +++ b/tap/tlstapper/tls_reader.go @@ -94,7 +94,3 @@ func (r *tlsReader) GetEmitter() api.Emitter { func (r *tlsReader) GetIsClosed() bool { return false } - -func (r *tlsReader) GetExtension() *api.Extension { - return r.extension -} diff --git a/tap/tlstapper/tls_stream.go b/tap/tlstapper/tls_stream.go index 09c447f13..4f9f02c15 100644 --- a/tap/tlstapper/tls_stream.go +++ b/tap/tlstapper/tls_stream.go @@ -3,20 +3,20 @@ package tlstapper import "github.com/up9inc/mizu/tap/api" type tlsStream struct { - reader *tlsReader - protoIdentifier *api.ProtoIdentifier + reader *tlsReader + protocol *api.Protocol } func (t *tlsStream) GetOrigin() api.Capture { return api.Ebpf } -func (t *tlsStream) GetProtoIdentifier() *api.ProtoIdentifier { - return t.protoIdentifier +func (t *tlsStream) GetProtocol() *api.Protocol { + return t.protocol } func (t *tlsStream) SetProtocol(protocol *api.Protocol) { - t.protoIdentifier.Protocol = protocol + t.protocol = protocol } func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher {