diff --git a/tap/api/api.go b/tap/api/api.go index 1e3d86e8e..7f3348d86 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -417,6 +417,7 @@ type TcpReader interface { Close() Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) SendMsgIfNotClosed(msg TcpReaderDataMsg) + SendChunk(chunk TlsChunk) GetReqResMatcher() RequestResponseMatcher GetIsClient() bool GetReadProgress() *ReadProgress @@ -457,3 +458,13 @@ type TcpStreamMap interface { NextId() int64 CloseTimedoutTcpStreamChannels() } + +type TlsChunk interface { + GetAddress() (net.IP, uint16, error) + IsClient() bool + IsServer() bool + IsRead() bool + IsWrite() bool + GetRecordedData() []byte + IsRequest() bool +} diff --git a/tap/tcp/tcp_reader.go b/tap/tcp/tcp_reader.go index d19ebe9d6..ccfff2fb0 100644 --- a/tap/tcp/tcp_reader.go +++ b/tap/tcp/tcp_reader.go @@ -108,6 +108,8 @@ func (reader *tcpReader) SendMsgIfNotClosed(msg api.TcpReaderDataMsg) { reader.Unlock() } +func (reader *tcpReader) SendChunk(chunk api.TlsChunk) {} + func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher { return reader.reqResMatcher } diff --git a/tap/tlstapper/chunk.go b/tap/tlstapper/chunk.go index d88d350da..2dce504b6 100644 --- a/tap/tlstapper/chunk.go +++ b/tap/tlstapper/chunk.go @@ -16,18 +16,18 @@ const FLAGS_IS_READ_BIT uint32 = (1 << 1) // Be careful when editing, alignment and padding should be exactly the same in go/c. // type tlsChunk struct { - Pid uint32 // process id - Tgid uint32 // thread id inside the process - Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[]) - Start uint32 // the start offset withing the native buffer - Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[] - Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor) - Flags uint32 // bitwise flags - Address [16]byte // ipv4 address and port + Pid uint32 // process id + Tgid uint32 // thread id inside the process + Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[]) + Start uint32 // the start offset withing the native buffer + Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[] + Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor) + Flags uint32 // bitwise flags + Address [16]byte // ipv4 address and port Data [4096]byte // actual tls data } -func (c *tlsChunk) getAddress() (net.IP, uint16, error) { +func (c *tlsChunk) GetAddress() (net.IP, uint16, error) { address := bytes.NewReader(c.Address[:]) var family uint16 var port uint16 @@ -50,26 +50,26 @@ func (c *tlsChunk) getAddress() (net.IP, uint16, error) { return ip, port, nil } -func (c *tlsChunk) isClient() bool { +func (c *tlsChunk) IsClient() bool { return c.Flags&FLAGS_IS_CLIENT_BIT != 0 } -func (c *tlsChunk) isServer() bool { - return !c.isClient() +func (c *tlsChunk) IsServer() bool { + return !c.IsClient() } -func (c *tlsChunk) isRead() bool { +func (c *tlsChunk) IsRead() bool { return c.Flags&FLAGS_IS_READ_BIT != 0 } -func (c *tlsChunk) isWrite() bool { - return !c.isRead() +func (c *tlsChunk) IsWrite() bool { + return !c.IsRead() } -func (c *tlsChunk) getRecordedData() []byte { +func (c *tlsChunk) GetRecordedData() []byte { return c.Data[:c.Recorded] } -func (c *tlsChunk) isRequest() bool { - return (c.isClient() && c.isWrite()) || (c.isServer() && c.isRead()) +func (c *tlsChunk) IsRequest() bool { + return (c.IsClient() && c.IsWrite()) || (c.IsServer() && c.IsRead()) } diff --git a/tap/tlstapper/tls_poller.go b/tap/tlstapper/tls_poller.go index b54c0a801..9853068c0 100644 --- a/tap/tlstapper/tls_poller.go +++ b/tap/tlstapper/tls_poller.go @@ -24,7 +24,7 @@ import ( type tlsPoller struct { tls *TlsTapper - readers map[string]*tlsReader + readers map[string]api.TcpReader closedReaders chan string reqResMatcher api.RequestResponseMatcher chunksReader *perf.Reader @@ -36,7 +36,7 @@ type tlsPoller struct { func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller { return &tlsPoller{ tls: tls, - readers: make(map[string]*tlsReader), + readers: make(map[string]api.TcpReader), closedReaders: make(chan string, 100), reqResMatcher: extension.Dissector.NewResponseRequestMatcher(), extension: extension, @@ -119,7 +119,7 @@ func (p *tlsPoller) pollChunksPerfBuffer(chunks chan<- *tlsChunk) { func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, emitter api.Emitter, options *shared.TrafficFilteringOptions) error { - ip, port, err := chunk.getAddress() + ip, port, err := chunk.GetAddress() if err != nil { return err @@ -128,29 +128,22 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, key := buildTlsKey(chunk, ip, port) reader, exists := p.readers[key] - tcpStream := tcp.NewTcpStreamDummy(api.Ebpf) - tcpReader := tcp.NewTcpReader( - make(chan api.TcpReaderDataMsg), - reader.progress, - "", - &api.TcpID{}, - time.Time{}, - tcpStream, - chunk.isRequest(), - false, - nil, - emitter, - &api.CounterPair{}, - p.reqResMatcher, + stream := tcp.NewTcpStreamDummy(api.Ebpf) + tlsReader := NewTlsReader( + key, + func(r *tlsReader) { + p.closeReader(key, r) + }, + stream, ) if !exists { - reader = p.startNewTlsReader(chunk, ip, port, key, extension, tcpReader, options) + reader = p.startNewTlsReader(chunk, ip, port, key, extension, tlsReader, options) p.readers[key] = reader } - tcpReader.SetCaptureTime(time.Now()) - reader.chunks <- chunk + reader.SetCaptureTime(time.Now()) + reader.SendChunk(chunk) if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" { p.logTls(chunk, ip, port) @@ -160,36 +153,28 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, } func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension, - tcpReader api.TcpReader, options *shared.TrafficFilteringOptions) *tlsReader { - - reader := &tlsReader{ - key: key, - chunks: make(chan *tlsChunk, 1), - doneHandler: func(r *tlsReader) { - p.closeReader(key, r) - }, - } + reader api.TcpReader, options *shared.TrafficFilteringOptions) api.TcpReader { tcpid := p.buildTcpId(chunk, ip, port) - tcpReader.SetTcpID(&tcpid) + reader.SetTcpID(&tcpid) - tcpReader.SetEmitter(&tlsEmitter{ - delegate: tcpReader.GetEmitter(), + reader.SetEmitter(&tlsEmitter{ + delegate: reader.GetEmitter(), namespace: p.getNamespace(chunk.Pid), }) - go dissect(extension, reader, tcpReader, options) + go dissect(extension, reader, options) return reader } -func dissect(extension *api.Extension, reader *tlsReader, tcpReader api.TcpReader, +func dissect(extension *api.Extension, reader api.TcpReader, options *shared.TrafficFilteringOptions) { b := bufio.NewReader(reader) - err := extension.Dissector.Dissect(b, tcpReader, options) + err := extension.Dissector.Dissect(b, reader, options) if err != nil { - logger.Log.Warningf("Error dissecting TLS %v - %v", tcpReader.GetTcpID(), err) + logger.Log.Warningf("Error dissecting TLS %v - %v", reader.GetTcpID(), err) } } @@ -199,11 +184,11 @@ func (p *tlsPoller) closeReader(key string, r *tlsReader) { } func buildTlsKey(chunk *tlsChunk, ip net.IP, port uint16) string { - return fmt.Sprintf("%v:%v-%v:%v", chunk.isClient(), chunk.isRead(), ip, port) + return fmt.Sprintf("%v:%v-%v:%v", chunk.IsClient(), chunk.IsRead(), ip, port) } func (p *tlsPoller) buildTcpId(chunk *tlsChunk, ip net.IP, port uint16) api.TcpID { - myIp, myPort, err := getAddressBySockfd(p.procfs, chunk.Pid, chunk.Fd, chunk.isClient()) + myIp, myPort, err := getAddressBySockfd(p.procfs, chunk.Pid, chunk.Fd, chunk.IsClient()) if err != nil { // May happen if the socket already closed, very likely to happen for localhost @@ -212,7 +197,7 @@ func (p *tlsPoller) buildTcpId(chunk *tlsChunk, ip net.IP, port uint16) api.TcpI myPort = api.UnknownPort } - if chunk.isRequest() { + if chunk.IsRequest() { return api.TcpID{ SrcIP: myIp.String(), DstIP: ip.String(), @@ -261,13 +246,13 @@ func (p *tlsPoller) clearPids() { func (p *tlsPoller) logTls(chunk *tlsChunk, ip net.IP, port uint16) { var flagsStr string - if chunk.isClient() { + if chunk.IsClient() { flagsStr = "C" } else { flagsStr = "S" } - if chunk.isRead() { + if chunk.IsRead() { flagsStr += "R" } else { flagsStr += "W" diff --git a/tap/tlstapper/tls_reader.go b/tap/tlstapper/tls_reader.go index 908cbf1ad..e49c202ae 100644 --- a/tap/tlstapper/tls_reader.go +++ b/tap/tlstapper/tls_reader.go @@ -2,21 +2,43 @@ package tlstapper import ( "io" + "sync" "time" + "github.com/up9inc/mizu/shared" "github.com/up9inc/mizu/tap/api" ) type tlsReader struct { - key string - chunks chan *tlsChunk - data []byte - doneHandler func(r *tlsReader) - progress *api.ReadProgress + key string + chunks chan api.TlsChunk + data []byte + doneHandler func(r *tlsReader) + progress *api.ReadProgress + tcpID *api.TcpID + isClosed bool + isClient bool + msgQueue chan api.TcpReaderDataMsg // Unused + captureTime time.Time + parent api.TcpStream + packetsSeen uint + extension *api.Extension + emitter api.Emitter + counterPair *api.CounterPair + reqResMatcher api.RequestResponseMatcher +} + +func NewTlsReader(key string, doneHandler func(r *tlsReader), stream api.TcpStream) api.TcpReader { + return &tlsReader{ + key: key, + chunks: make(chan api.TlsChunk, 1), + doneHandler: doneHandler, + parent: stream, + } } func (r *tlsReader) Read(p []byte) (int, error) { - var chunk *tlsChunk + var chunk api.TlsChunk for len(r.data) == 0 { var ok bool @@ -26,7 +48,7 @@ func (r *tlsReader) Read(p []byte) (int, error) { return 0, io.EOF } - r.data = chunk.getRecordedData() + r.data = chunk.GetRecordedData() case <-time.After(time.Second * 3): r.doneHandler(r) return 0, io.EOF @@ -43,3 +65,67 @@ func (r *tlsReader) Read(p []byte) (int, error) { return l, nil } + +func (r *tlsReader) Close() { + r.doneHandler(r) +} + +func (r *tlsReader) Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) {} + +func (r *tlsReader) SendMsgIfNotClosed(msg api.TcpReaderDataMsg) {} + +func (r *tlsReader) SendChunk(chunk api.TlsChunk) { + r.chunks <- chunk +} + +func (r *tlsReader) GetReqResMatcher() api.RequestResponseMatcher { + return r.reqResMatcher +} + +func (r *tlsReader) GetIsClient() bool { + return r.isClient +} + +func (r *tlsReader) GetReadProgress() *api.ReadProgress { + return r.progress +} + +func (r *tlsReader) GetParent() api.TcpStream { + return r.parent +} + +func (r *tlsReader) GetTcpID() *api.TcpID { + return r.tcpID +} + +func (r *tlsReader) GetCounterPair() *api.CounterPair { + return r.counterPair +} + +func (r *tlsReader) GetCaptureTime() time.Time { + return r.captureTime +} + +func (r *tlsReader) GetEmitter() api.Emitter { + return r.emitter +} + +func (r *tlsReader) GetIsClosed() bool { + return r.isClosed +} + +func (r *tlsReader) GetExtension() *api.Extension { + return r.extension +} + +func (r *tlsReader) SetTcpID(tcpID *api.TcpID) { + r.tcpID = tcpID +} + +func (r *tlsReader) SetCaptureTime(captureTime time.Time) { + r.captureTime = captureTime +} + +func (r *tlsReader) SetEmitter(emitter api.Emitter) { + r.emitter = emitter +}