Add golangReader struct and implement its Read method

This commit is contained in:
M. Mert Yildiran
2022-06-01 15:27:54 +03:00
parent 68dec6cbbc
commit 52c4b18a9d
3 changed files with 145 additions and 78 deletions

View File

@@ -1,18 +1,27 @@
package tlstapper
import "github.com/up9inc/mizu/tap/api"
type golangConnection struct {
Pid uint32
ConnAddr uint32
AddressPair addressPair
Requests [][]byte
Responses [][]byte
Gzipped bool
Pid uint32
ConnAddr uint32
AddressPair addressPair
Requests [][]byte
Responses [][]byte
Gzipped bool
Stream *tlsStream
ClientReader *golangReader
ServerReader *golangReader
}
func NewGolangConnection(pid uint32, connAddr uint32) *golangConnection {
func NewGolangConnection(pid uint32, connAddr uint32, extension *api.Extension, emitter api.Emitter) *golangConnection {
stream := &tlsStream{}
return &golangConnection{
Pid: pid,
ConnAddr: connAddr,
Pid: pid,
ConnAddr: connAddr,
Stream: stream,
ClientReader: NewGolangReader(extension, emitter, stream, true),
ServerReader: NewGolangReader(extension, emitter, stream, false),
}
}

View File

@@ -0,0 +1,105 @@
package tlstapper
import (
"io"
"time"
"github.com/up9inc/mizu/tap/api"
)
type golangReader struct {
key string
msgQueue chan []byte
data []byte
progress *api.ReadProgress
tcpID *api.TcpID
isClient bool
captureTime time.Time
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
parent *tlsStream
reqResMatcher api.RequestResponseMatcher
}
func NewGolangReader(extension *api.Extension, emitter api.Emitter, stream *tlsStream, isClient bool) *golangReader {
return &golangReader{
msgQueue: make(chan []byte, 1),
progress: &api.ReadProgress{},
tcpID: &api.TcpID{},
isClient: isClient,
captureTime: time.Now(),
extension: extension,
emitter: emitter,
counterPair: &api.CounterPair{},
parent: stream,
reqResMatcher: extension.Dissector.NewResponseRequestMatcher(),
}
}
func (r *golangReader) send(b []byte) {
r.captureTime = time.Now()
r.msgQueue <- b
}
func (r *golangReader) Read(p []byte) (int, error) {
var b []byte
for len(r.data) == 0 {
var ok bool
select {
case b, ok = <-r.msgQueue:
if !ok {
return 0, io.EOF
}
r.data = b
}
if len(r.data) > 0 {
break
}
}
l := copy(p, r.data)
r.data = r.data[l:]
r.progress.Feed(l)
return l, nil
}
func (r *golangReader) GetReqResMatcher() api.RequestResponseMatcher {
return r.reqResMatcher
}
func (r *golangReader) GetIsClient() bool {
return r.isClient
}
func (r *golangReader) GetReadProgress() *api.ReadProgress {
return r.progress
}
func (r *golangReader) GetParent() api.TcpStream {
return r.parent
}
func (r *golangReader) GetTcpID() *api.TcpID {
return r.tcpID
}
func (r *golangReader) GetCounterPair() *api.CounterPair {
return r.counterPair
}
func (r *golangReader) GetCaptureTime() time.Time {
return r.captureTime
}
func (r *golangReader) GetEmitter() api.Emitter {
return r.emitter
}
func (r *golangReader) GetIsClosed() bool {
return false
}

View File

@@ -149,8 +149,14 @@ func (p *tlsPoller) pollGolangReadWrite(rd *ringbuf.Reader, emitter api.Emitter,
var _connection interface{}
var ok bool
if _connection, ok = p.golangReadWriteMap.Get(identifier); !ok {
connection = NewGolangConnection(b.Pid, b.ConnAddr)
tlsEmitter := &tlsEmitter{
delegate: emitter,
namespace: p.getNamespace(b.Pid),
}
connection = NewGolangConnection(b.Pid, b.ConnAddr, p.extension, tlsEmitter)
p.golangReadWriteMap.Set(identifier, connection)
streamsMap.Store(streamsMap.NextId(), connection.Stream)
} else {
connection = _connection.(*golangConnection)
}
@@ -166,78 +172,25 @@ func (p *tlsPoller) pollGolangReadWrite(rd *ringbuf.Reader, emitter api.Emitter,
continue
}
tcpid := p.buildTcpId(&connection.AddressPair)
connection.ClientReader.tcpID = &tcpid
connection.ServerReader.tcpID = &api.TcpID{
SrcIP: connection.ClientReader.tcpID.DstIP,
DstIP: connection.ClientReader.tcpID.SrcIP,
SrcPort: connection.ClientReader.tcpID.DstPort,
DstPort: connection.ClientReader.tcpID.SrcPort,
}
go dissect(p.extension, connection.ClientReader, options)
go dissect(p.extension, connection.ServerReader, options)
request := make([]byte, len(b.Data[:]))
copy(request, b.Data[:])
connection.Requests = append(connection.Requests, request)
connection.ClientReader.send(request)
} else {
response := make([]byte, len(b.Data[:]))
copy(response, b.Data[:])
connection.Responses = append(connection.Responses, response)
if !b.IsGzipChunk {
// TODO: Remove these comments
// fmt.Printf("\n\nidentifier: %v\n", identifier)
// fmt.Printf("connection.Pid: %v\n", connection.Pid)
// fmt.Printf("connection.ConnAddr: 0x%x\n", connection.ConnAddr)
// fmt.Printf("connection.AddressPair.srcIp: %v\n", connection.AddressPair.srcIp)
// fmt.Printf("connection.AddressPair.srcPort: %v\n", connection.AddressPair.srcPort)
// fmt.Printf("connection.AddressPair.dstIp: %v\n", connection.AddressPair.dstIp)
// fmt.Printf("connection.AddressPair.dstPort: %v\n", connection.AddressPair.dstPort)
// fmt.Printf("connection.Gzipped: %v\n", connection.Gzipped)
// for i, x := range connection.Requests {
// fmt.Printf("connection.Request[%d]:\n%v\n", i, unix.ByteSliceToString(x))
// }
// for i, y := range connection.Responses {
// fmt.Printf("connection.Response[%d]:\n%v\n", i, unix.ByteSliceToString(y))
// }
// tcpid := p.buildTcpId(&connection.AddressPair)
// tlsEmitter := &tlsEmitter{
// delegate: emitter,
// namespace: p.getNamespace(b.Pid),
// }
// reader := &tlsReader{
// chunks: make(chan *tlsChunk, 1),
// progress: &api.ReadProgress{},
// tcpID: &tcpid,
// isClient: true,
// captureTime: time.Now(),
// extension: p.extension,
// emitter: tlsEmitter,
// counterPair: &api.CounterPair{},
// reqResMatcher: p.extension.Dissector.NewResponseRequestMatcher(),
// }
// stream := &tlsStream{
// reader: reader,
// }
// streamsMap.Store(streamsMap.NextId(), stream)
// reader.parent = stream
// err := p.extension.Dissector.Dissect(bufio.NewReader(bytes.NewReader(connection.Requests[0])), reader, options)
// if err != nil {
// logger.Log.Warningf("Error dissecting TLS %v - %v", reader.GetTcpID(), err)
// }
// reader.isClient = false
// reader.tcpID = &api.TcpID{
// SrcIP: reader.tcpID.DstIP,
// DstIP: reader.tcpID.SrcIP,
// SrcPort: reader.tcpID.DstPort,
// DstPort: reader.tcpID.SrcPort,
// }
// reader.progress = &api.ReadProgress{}
// err = p.extension.Dissector.Dissect(bufio.NewReader(bytes.NewReader(connection.Responses[0])), reader, options)
// if err != nil {
// logger.Log.Warningf("Error dissecting TLS %v - %v", reader.GetTcpID(), err)
// }
}
connection.ServerReader.send(response)
}
}
}
@@ -346,7 +299,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, address *addressPair, key
return reader
}
func dissect(extension *api.Extension, reader *tlsReader, options *api.TrafficFilteringOptions) {
func dissect(extension *api.Extension, reader api.TcpReader, options *api.TrafficFilteringOptions) {
b := bufio.NewReader(reader)
err := extension.Dissector.Dissect(b, reader, options)