Define TlsChunk interface and make tlsReader implement TcpReader

This commit is contained in:
M. Mert Yildiran
2022-04-26 17:23:51 +03:00
parent 6e9064dc56
commit f10eff3170
5 changed files with 150 additions and 66 deletions

View File

@@ -417,6 +417,7 @@ type TcpReader interface {
Close() Close()
Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup)
SendMsgIfNotClosed(msg TcpReaderDataMsg) SendMsgIfNotClosed(msg TcpReaderDataMsg)
SendChunk(chunk TlsChunk)
GetReqResMatcher() RequestResponseMatcher GetReqResMatcher() RequestResponseMatcher
GetIsClient() bool GetIsClient() bool
GetReadProgress() *ReadProgress GetReadProgress() *ReadProgress
@@ -457,3 +458,13 @@ type TcpStreamMap interface {
NextId() int64 NextId() int64
CloseTimedoutTcpStreamChannels() CloseTimedoutTcpStreamChannels()
} }
type TlsChunk interface {
GetAddress() (net.IP, uint16, error)
IsClient() bool
IsServer() bool
IsRead() bool
IsWrite() bool
GetRecordedData() []byte
IsRequest() bool
}

View File

@@ -108,6 +108,8 @@ func (reader *tcpReader) SendMsgIfNotClosed(msg api.TcpReaderDataMsg) {
reader.Unlock() reader.Unlock()
} }
func (reader *tcpReader) SendChunk(chunk api.TlsChunk) {}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher { func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher return reader.reqResMatcher
} }

View File

@@ -16,18 +16,18 @@ const FLAGS_IS_READ_BIT uint32 = (1 << 1)
// Be careful when editing, alignment and padding should be exactly the same in go/c. // Be careful when editing, alignment and padding should be exactly the same in go/c.
// //
type tlsChunk struct { type tlsChunk struct {
Pid uint32 // process id Pid uint32 // process id
Tgid uint32 // thread id inside the process Tgid uint32 // thread id inside the process
Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[]) Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[])
Start uint32 // the start offset withing the native buffer Start uint32 // the start offset withing the native buffer
Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[] Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[]
Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor) Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor)
Flags uint32 // bitwise flags Flags uint32 // bitwise flags
Address [16]byte // ipv4 address and port Address [16]byte // ipv4 address and port
Data [4096]byte // actual tls data Data [4096]byte // actual tls data
} }
func (c *tlsChunk) getAddress() (net.IP, uint16, error) { func (c *tlsChunk) GetAddress() (net.IP, uint16, error) {
address := bytes.NewReader(c.Address[:]) address := bytes.NewReader(c.Address[:])
var family uint16 var family uint16
var port uint16 var port uint16
@@ -50,26 +50,26 @@ func (c *tlsChunk) getAddress() (net.IP, uint16, error) {
return ip, port, nil return ip, port, nil
} }
func (c *tlsChunk) isClient() bool { func (c *tlsChunk) IsClient() bool {
return c.Flags&FLAGS_IS_CLIENT_BIT != 0 return c.Flags&FLAGS_IS_CLIENT_BIT != 0
} }
func (c *tlsChunk) isServer() bool { func (c *tlsChunk) IsServer() bool {
return !c.isClient() return !c.IsClient()
} }
func (c *tlsChunk) isRead() bool { func (c *tlsChunk) IsRead() bool {
return c.Flags&FLAGS_IS_READ_BIT != 0 return c.Flags&FLAGS_IS_READ_BIT != 0
} }
func (c *tlsChunk) isWrite() bool { func (c *tlsChunk) IsWrite() bool {
return !c.isRead() return !c.IsRead()
} }
func (c *tlsChunk) getRecordedData() []byte { func (c *tlsChunk) GetRecordedData() []byte {
return c.Data[:c.Recorded] return c.Data[:c.Recorded]
} }
func (c *tlsChunk) isRequest() bool { func (c *tlsChunk) IsRequest() bool {
return (c.isClient() && c.isWrite()) || (c.isServer() && c.isRead()) return (c.IsClient() && c.IsWrite()) || (c.IsServer() && c.IsRead())
} }

View File

@@ -24,7 +24,7 @@ import (
type tlsPoller struct { type tlsPoller struct {
tls *TlsTapper tls *TlsTapper
readers map[string]*tlsReader readers map[string]api.TcpReader
closedReaders chan string closedReaders chan string
reqResMatcher api.RequestResponseMatcher reqResMatcher api.RequestResponseMatcher
chunksReader *perf.Reader chunksReader *perf.Reader
@@ -36,7 +36,7 @@ type tlsPoller struct {
func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller { func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller {
return &tlsPoller{ return &tlsPoller{
tls: tls, tls: tls,
readers: make(map[string]*tlsReader), readers: make(map[string]api.TcpReader),
closedReaders: make(chan string, 100), closedReaders: make(chan string, 100),
reqResMatcher: extension.Dissector.NewResponseRequestMatcher(), reqResMatcher: extension.Dissector.NewResponseRequestMatcher(),
extension: extension, extension: extension,
@@ -119,7 +119,7 @@ func (p *tlsPoller) pollChunksPerfBuffer(chunks chan<- *tlsChunk) {
func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension, func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
emitter api.Emitter, options *shared.TrafficFilteringOptions) error { emitter api.Emitter, options *shared.TrafficFilteringOptions) error {
ip, port, err := chunk.getAddress() ip, port, err := chunk.GetAddress()
if err != nil { if err != nil {
return err return err
@@ -128,29 +128,22 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
key := buildTlsKey(chunk, ip, port) key := buildTlsKey(chunk, ip, port)
reader, exists := p.readers[key] reader, exists := p.readers[key]
tcpStream := tcp.NewTcpStreamDummy(api.Ebpf) stream := tcp.NewTcpStreamDummy(api.Ebpf)
tcpReader := tcp.NewTcpReader( tlsReader := NewTlsReader(
make(chan api.TcpReaderDataMsg), key,
reader.progress, func(r *tlsReader) {
"", p.closeReader(key, r)
&api.TcpID{}, },
time.Time{}, stream,
tcpStream,
chunk.isRequest(),
false,
nil,
emitter,
&api.CounterPair{},
p.reqResMatcher,
) )
if !exists { if !exists {
reader = p.startNewTlsReader(chunk, ip, port, key, extension, tcpReader, options) reader = p.startNewTlsReader(chunk, ip, port, key, extension, tlsReader, options)
p.readers[key] = reader p.readers[key] = reader
} }
tcpReader.SetCaptureTime(time.Now()) reader.SetCaptureTime(time.Now())
reader.chunks <- chunk reader.SendChunk(chunk)
if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" { if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" {
p.logTls(chunk, ip, port) p.logTls(chunk, ip, port)
@@ -160,36 +153,28 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
} }
func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension, func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension,
tcpReader api.TcpReader, options *shared.TrafficFilteringOptions) *tlsReader { reader api.TcpReader, options *shared.TrafficFilteringOptions) api.TcpReader {
reader := &tlsReader{
key: key,
chunks: make(chan *tlsChunk, 1),
doneHandler: func(r *tlsReader) {
p.closeReader(key, r)
},
}
tcpid := p.buildTcpId(chunk, ip, port) tcpid := p.buildTcpId(chunk, ip, port)
tcpReader.SetTcpID(&tcpid) reader.SetTcpID(&tcpid)
tcpReader.SetEmitter(&tlsEmitter{ reader.SetEmitter(&tlsEmitter{
delegate: tcpReader.GetEmitter(), delegate: reader.GetEmitter(),
namespace: p.getNamespace(chunk.Pid), namespace: p.getNamespace(chunk.Pid),
}) })
go dissect(extension, reader, tcpReader, options) go dissect(extension, reader, options)
return reader return reader
} }
func dissect(extension *api.Extension, reader *tlsReader, tcpReader api.TcpReader, func dissect(extension *api.Extension, reader api.TcpReader,
options *shared.TrafficFilteringOptions) { options *shared.TrafficFilteringOptions) {
b := bufio.NewReader(reader) b := bufio.NewReader(reader)
err := extension.Dissector.Dissect(b, tcpReader, options) err := extension.Dissector.Dissect(b, reader, options)
if err != nil { if err != nil {
logger.Log.Warningf("Error dissecting TLS %v - %v", tcpReader.GetTcpID(), err) logger.Log.Warningf("Error dissecting TLS %v - %v", reader.GetTcpID(), err)
} }
} }
@@ -199,11 +184,11 @@ func (p *tlsPoller) closeReader(key string, r *tlsReader) {
} }
func buildTlsKey(chunk *tlsChunk, ip net.IP, port uint16) string { func buildTlsKey(chunk *tlsChunk, ip net.IP, port uint16) string {
return fmt.Sprintf("%v:%v-%v:%v", chunk.isClient(), chunk.isRead(), ip, port) return fmt.Sprintf("%v:%v-%v:%v", chunk.IsClient(), chunk.IsRead(), ip, port)
} }
func (p *tlsPoller) buildTcpId(chunk *tlsChunk, ip net.IP, port uint16) api.TcpID { func (p *tlsPoller) buildTcpId(chunk *tlsChunk, ip net.IP, port uint16) api.TcpID {
myIp, myPort, err := getAddressBySockfd(p.procfs, chunk.Pid, chunk.Fd, chunk.isClient()) myIp, myPort, err := getAddressBySockfd(p.procfs, chunk.Pid, chunk.Fd, chunk.IsClient())
if err != nil { if err != nil {
// May happen if the socket already closed, very likely to happen for localhost // May happen if the socket already closed, very likely to happen for localhost
@@ -212,7 +197,7 @@ func (p *tlsPoller) buildTcpId(chunk *tlsChunk, ip net.IP, port uint16) api.TcpI
myPort = api.UnknownPort myPort = api.UnknownPort
} }
if chunk.isRequest() { if chunk.IsRequest() {
return api.TcpID{ return api.TcpID{
SrcIP: myIp.String(), SrcIP: myIp.String(),
DstIP: ip.String(), DstIP: ip.String(),
@@ -261,13 +246,13 @@ func (p *tlsPoller) clearPids() {
func (p *tlsPoller) logTls(chunk *tlsChunk, ip net.IP, port uint16) { func (p *tlsPoller) logTls(chunk *tlsChunk, ip net.IP, port uint16) {
var flagsStr string var flagsStr string
if chunk.isClient() { if chunk.IsClient() {
flagsStr = "C" flagsStr = "C"
} else { } else {
flagsStr = "S" flagsStr = "S"
} }
if chunk.isRead() { if chunk.IsRead() {
flagsStr += "R" flagsStr += "R"
} else { } else {
flagsStr += "W" flagsStr += "W"

View File

@@ -2,21 +2,43 @@ package tlstapper
import ( import (
"io" "io"
"sync"
"time" "time"
"github.com/up9inc/mizu/shared"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
) )
type tlsReader struct { type tlsReader struct {
key string key string
chunks chan *tlsChunk chunks chan api.TlsChunk
data []byte data []byte
doneHandler func(r *tlsReader) doneHandler func(r *tlsReader)
progress *api.ReadProgress progress *api.ReadProgress
tcpID *api.TcpID
isClosed bool
isClient bool
msgQueue chan api.TcpReaderDataMsg // Unused
captureTime time.Time
parent api.TcpStream
packetsSeen uint
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
}
func NewTlsReader(key string, doneHandler func(r *tlsReader), stream api.TcpStream) api.TcpReader {
return &tlsReader{
key: key,
chunks: make(chan api.TlsChunk, 1),
doneHandler: doneHandler,
parent: stream,
}
} }
func (r *tlsReader) Read(p []byte) (int, error) { func (r *tlsReader) Read(p []byte) (int, error) {
var chunk *tlsChunk var chunk api.TlsChunk
for len(r.data) == 0 { for len(r.data) == 0 {
var ok bool var ok bool
@@ -26,7 +48,7 @@ func (r *tlsReader) Read(p []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
} }
r.data = chunk.getRecordedData() r.data = chunk.GetRecordedData()
case <-time.After(time.Second * 3): case <-time.After(time.Second * 3):
r.doneHandler(r) r.doneHandler(r)
return 0, io.EOF return 0, io.EOF
@@ -43,3 +65,67 @@ func (r *tlsReader) Read(p []byte) (int, error) {
return l, nil return l, nil
} }
func (r *tlsReader) Close() {
r.doneHandler(r)
}
func (r *tlsReader) Run(options *shared.TrafficFilteringOptions, wg *sync.WaitGroup) {}
func (r *tlsReader) SendMsgIfNotClosed(msg api.TcpReaderDataMsg) {}
func (r *tlsReader) SendChunk(chunk api.TlsChunk) {
r.chunks <- chunk
}
func (r *tlsReader) GetReqResMatcher() api.RequestResponseMatcher {
return r.reqResMatcher
}
func (r *tlsReader) GetIsClient() bool {
return r.isClient
}
func (r *tlsReader) GetReadProgress() *api.ReadProgress {
return r.progress
}
func (r *tlsReader) GetParent() api.TcpStream {
return r.parent
}
func (r *tlsReader) GetTcpID() *api.TcpID {
return r.tcpID
}
func (r *tlsReader) GetCounterPair() *api.CounterPair {
return r.counterPair
}
func (r *tlsReader) GetCaptureTime() time.Time {
return r.captureTime
}
func (r *tlsReader) GetEmitter() api.Emitter {
return r.emitter
}
func (r *tlsReader) GetIsClosed() bool {
return r.isClosed
}
func (r *tlsReader) GetExtension() *api.Extension {
return r.extension
}
func (r *tlsReader) SetTcpID(tcpID *api.TcpID) {
r.tcpID = tcpID
}
func (r *tlsReader) SetCaptureTime(captureTime time.Time) {
r.captureTime = captureTime
}
func (r *tlsReader) SetEmitter(emitter api.Emitter) {
r.emitter = emitter
}