diff --git a/tap/tcp_reader.go b/tap/tcp_reader.go index 85dd1b46b..5af4b828d 100644 --- a/tap/tcp_reader.go +++ b/tap/tcp_reader.go @@ -26,7 +26,7 @@ type tcpReader struct { data []byte progress *api.ReadProgress captureTime time.Time - parent api.TcpStream + parent *tcpStream packetsSeen uint extension *api.Extension emitter api.Emitter @@ -35,7 +35,7 @@ type tcpReader struct { sync.Mutex } -func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader { +func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent *tcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) *tcpReader { return &tcpReader{ msgQueue: msgQueue, progress: progress, diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index 5c5914ff1..d67d639fd 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -26,7 +26,7 @@ type tcpStream struct { sync.Mutex } -func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) api.TcpStream { +func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream { return &tcpStream{ isTapTarget: isTapTarget, protoIdentifier: &api.ProtoIdentifier{}, diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index f33389a5a..09d826f15 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -61,15 +61,14 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay stream := NewTcpStream(isTapTarget, factory.streamsMap, getPacketOrigin(ac)) reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream) if stream.GetIsTapTarget() { - _stream := stream.(*tcpStream) - _stream.setId(factory.streamsMap.NextId()) + stream.setId(factory.streamsMap.NextId()) for i, extension := range extensions { reqResMatcher := extension.Dissector.NewResponseRequestMatcher() counterPair := &api.CounterPair{ Request: 0, Response: 0, } - _stream.addClient( + stream.addClient( NewTcpReader( make(chan api.TcpReaderDataMsg), &api.ReadProgress{}, @@ -90,7 +89,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay reqResMatcher, ), ) - _stream.addServer( + stream.addServer( NewTcpReader( make(chan api.TcpReaderDataMsg), &api.ReadProgress{}, @@ -112,11 +111,11 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay ), ) - factory.streamsMap.Store(stream.(*tcpStream).getId(), stream) + factory.streamsMap.Store(stream.getId(), stream) factory.wg.Add(2) - go _stream.getClient(i).(*tcpReader).run(filteringOptions, &factory.wg) - go _stream.getServer(i).(*tcpReader).run(filteringOptions, &factory.wg) + go stream.getClient(i).(*tcpReader).run(filteringOptions, &factory.wg) + go stream.getServer(i).(*tcpReader).run(filteringOptions, &factory.wg) } } return reassemblyStream diff --git a/tap/tlstapper/tls_poller.go b/tap/tlstapper/tls_poller.go index 50ddea510..86e06dad9 100644 --- a/tap/tlstapper/tls_poller.go +++ b/tap/tlstapper/tls_poller.go @@ -21,34 +21,25 @@ import ( ) type tlsPoller struct { - tls *TlsTapper - readers map[string]api.TcpReader - closedReaders chan string - reqResMatcher api.RequestResponseMatcher - chunksReader *perf.Reader - extension *api.Extension - procfs string - pidToNamespace sync.Map - isClosed bool - protoIdentifier *api.ProtoIdentifier - isTapTarget bool - origin api.Capture - createdAt time.Time + tls *TlsTapper + readers map[string]*tlsReader + closedReaders chan string + reqResMatcher api.RequestResponseMatcher + chunksReader *perf.Reader + extension *api.Extension + procfs string + pidToNamespace sync.Map } func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller { return &tlsPoller{ - tls: tls, - readers: make(map[string]api.TcpReader), - closedReaders: make(chan string, 100), - reqResMatcher: extension.Dissector.NewResponseRequestMatcher(), - extension: extension, - chunksReader: nil, - procfs: procfs, - protoIdentifier: &api.ProtoIdentifier{}, - isTapTarget: true, - origin: api.Ebpf, - createdAt: time.Now(), + tls: tls, + readers: make(map[string]*tlsReader), + closedReaders: make(chan string, 100), + reqResMatcher: extension.Dissector.NewResponseRequestMatcher(), + extension: extension, + chunksReader: nil, + procfs: procfs, } } @@ -135,24 +126,13 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, key := buildTlsKey(chunk, ip, port) reader, exists := p.readers[key] - newReader := NewTlsReader( - key, - func(r *tlsReader) { - p.closeReader(key, r) - }, - chunk.isRequest(), - p, - ) - if !exists { - reader = p.startNewTlsReader(chunk, ip, port, key, extension, newReader, options) + reader = p.startNewTlsReader(chunk, ip, port, key, emitter, extension, options) p.readers[key] = reader } - tlsReader := reader.(*tlsReader) - - tlsReader.setCaptureTime(time.Now()) - tlsReader.sendChunk(chunk) + reader.captureTime = time.Now() + reader.chunks <- chunk if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" { p.logTls(chunk, ip, port) @@ -161,25 +141,46 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, return nil } -func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension, - reader api.TcpReader, options *api.TrafficFilteringOptions) api.TcpReader { +func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, + emitter api.Emitter, extension *api.Extension, options *api.TrafficFilteringOptions) *tlsReader { tcpid := p.buildTcpId(chunk, ip, port) - tlsReader := reader.(*tlsReader) - tlsReader.setTcpID(&tcpid) + doneHandler := func(r *tlsReader) { + p.closeReader(key, r) + } - tlsReader.setEmitter(&tlsEmitter{ - delegate: reader.GetEmitter(), + tlsEmitter := &tlsEmitter{ + delegate: emitter, namespace: p.getNamespace(chunk.Pid), - }) + } + + reader := &tlsReader{ + key: key, + chunks: make(chan *tlsChunk, 1), + doneHandler: doneHandler, + progress: &api.ReadProgress{}, + tcpID: &tcpid, + isClient: chunk.isRequest(), + captureTime: time.Now(), + extension: extension, + emitter: tlsEmitter, + counterPair: &api.CounterPair{}, + reqResMatcher: p.reqResMatcher, + } + + stream := &tlsStream{ + reader: reader, + protoIdentifier: &api.ProtoIdentifier{}, + } + + reader.parent = stream go dissect(extension, reader, options) return reader } -func dissect(extension *api.Extension, reader api.TcpReader, - options *api.TrafficFilteringOptions) { +func dissect(extension *api.Extension, reader *tlsReader, options *api.TrafficFilteringOptions) { b := bufio.NewReader(reader) err := extension.Dissector.Dissect(b, reader, options) @@ -279,27 +280,3 @@ func (p *tlsPoller) logTls(chunk *tlsChunk, ip net.IP, port uint16) { srcIp, srcPort, dstIp, dstPort, chunk.Recorded, chunk.Len, chunk.Start, str, hex.EncodeToString(chunk.Data[0:chunk.Recorded])) } - -func (p *tlsPoller) SetProtocol(protocol *api.Protocol) { - // TODO: Implement -} - -func (p *tlsPoller) GetOrigin() api.Capture { - return p.origin -} - -func (p *tlsPoller) GetProtoIdentifier() *api.ProtoIdentifier { - return p.protoIdentifier -} - -func (p *tlsPoller) GetReqResMatcher() api.RequestResponseMatcher { - return p.reqResMatcher -} - -func (p *tlsPoller) GetIsTapTarget() bool { - return p.isTapTarget -} - -func (p *tlsPoller) GetIsClosed() bool { - return p.isClosed -} diff --git a/tap/tlstapper/tls_reader.go b/tap/tlstapper/tls_reader.go index ed2e6fc00..45ebf0ce2 100644 --- a/tap/tlstapper/tls_reader.go +++ b/tap/tlstapper/tls_reader.go @@ -14,41 +14,15 @@ type tlsReader struct { doneHandler func(r *tlsReader) progress *api.ReadProgress tcpID *api.TcpID - isClosed bool isClient bool captureTime time.Time - parent api.TcpStream extension *api.Extension emitter api.Emitter counterPair *api.CounterPair + parent api.TcpStream reqResMatcher api.RequestResponseMatcher } -func NewTlsReader(key string, doneHandler func(r *tlsReader), isClient bool, stream api.TcpStream) api.TcpReader { - return &tlsReader{ - key: key, - chunks: make(chan *tlsChunk, 1), - doneHandler: doneHandler, - parent: stream, - } -} - -func (r *tlsReader) sendChunk(chunk *tlsChunk) { - r.chunks <- chunk -} - -func (r *tlsReader) setTcpID(tcpID *api.TcpID) { - r.tcpID = tcpID -} - -func (r *tlsReader) setCaptureTime(captureTime time.Time) { - r.captureTime = captureTime -} - -func (r *tlsReader) setEmitter(emitter api.Emitter) { - r.emitter = emitter -} - func (r *tlsReader) Read(p []byte) (int, error) { var chunk *tlsChunk @@ -111,7 +85,7 @@ func (r *tlsReader) GetEmitter() api.Emitter { } func (r *tlsReader) GetIsClosed() bool { - return r.isClosed + return false } func (r *tlsReader) GetExtension() *api.Extension { diff --git a/tap/tlstapper/tls_stream.go b/tap/tlstapper/tls_stream.go new file mode 100644 index 000000000..99134daa5 --- /dev/null +++ b/tap/tlstapper/tls_stream.go @@ -0,0 +1,32 @@ +package tlstapper + +import "github.com/up9inc/mizu/tap/api" + +type tlsStream struct { + reader *tlsReader + protoIdentifier *api.ProtoIdentifier +} + +func (t *tlsStream) GetOrigin() api.Capture { + return api.Ebpf +} + +func (t *tlsStream) GetProtoIdentifier() *api.ProtoIdentifier { + return t.protoIdentifier +} + +func (t *tlsStream) SetProtocol(protocol *api.Protocol) { + t.protoIdentifier.Protocol = protocol +} + +func (t *tlsStream) GetReqResMatcher() api.RequestResponseMatcher { + return t.reader.reqResMatcher +} + +func (t *tlsStream) GetIsTapTarget() bool { + return true +} + +func (t *tlsStream) GetIsClosed() bool { + return false +}