Spawn only two Goroutines per TCP stream

This commit is contained in:
M. Mert Yildiran 2022-05-01 19:03:32 +03:00
parent 684c51686f
commit 41527c1be5
No known key found for this signature in database
GPG Key ID: D42ADB236521BF7A
20 changed files with 149 additions and 233 deletions

View File

@ -104,11 +104,6 @@ type OutputChannelItem struct {
Namespace string
}
type ProtoIdentifier struct {
Protocol *Protocol
IsClosedOthers bool
}
type ReadProgress struct {
readBytes int
lastCurrent int
@ -419,13 +414,12 @@ type TcpReader interface {
GetCaptureTime() time.Time
GetEmitter() Emitter
GetIsClosed() bool
GetExtension() *Extension
}
type TcpStream interface {
SetProtocol(protocol *Protocol)
GetOrigin() Capture
GetProtoIdentifier() *ProtoIdentifier
GetProtocol() *Protocol
GetReqResMatchers() []RequestResponseMatcher
GetIsTapTarget() bool
GetIsClosed() bool

View File

@ -75,7 +75,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
var lastMethodFrameMessage Message
for {
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &protocol {
if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &protocol {
return errors.New("Identified by another protocol")
}

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
isClosed bool
protocol *api.Protocol
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
origin: capture,
}
}
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -116,7 +116,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
http2Assembler = createHTTP2Assembler(b)
}
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &http11protocol {
if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &http11protocol {
return errors.New("Identified by another protocol")
}
@ -172,7 +172,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
}
}
if reader.GetParent().GetProtoIdentifier().Protocol == nil {
if reader.GetParent().GetProtocol() == nil {
return err
}

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
isClosed bool
protocol *api.Protocol
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
origin: capture,
}
}
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -38,7 +38,7 @@ func (d dissecting) Ping() {
func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher)
for {
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &_protocol {
if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &_protocol {
return errors.New("Identified by another protocol")
}

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
isClosed bool
protocol *api.Protocol
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
origin: capture,
}
}
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
isClosed bool
protocol *api.Protocol
isTapTarget bool
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
origin: capture,
}
}
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -3,11 +3,9 @@ package tap
import (
"bufio"
"io"
"io/ioutil"
"sync"
"time"
"github.com/up9inc/mizu/logger"
"github.com/up9inc/mizu/tap/api"
)
@ -23,44 +21,44 @@ type tcpReader struct {
isClient bool
isOutgoing bool
msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload
buffer []byte
exhaustBuffer bool
data []byte
progress *api.ReadProgress
captureTime time.Time
parent *tcpStream
packetsSeen uint
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent *tcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) *tcpReader {
func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent *tcpStream, isClient bool, isOutgoing bool, emitter api.Emitter) *tcpReader {
return &tcpReader{
msgQueue: msgQueue,
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
msgQueue: msgQueue,
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
emitter: emitter,
}
}
func (reader *tcpReader) run(options *api.TrafficFilteringOptions, wg *sync.WaitGroup) {
defer wg.Done()
b := bufio.NewReader(reader)
err := reader.extension.Dissector.Dissect(b, reader, options)
if err != nil {
_, err = io.Copy(ioutil.Discard, reader)
if err != nil {
logger.Log.Errorf("%v", err)
for i, extension := range extensions {
reader.reqResMatcher = reader.parent.reqResMatchers[i]
reader.counterPair = reader.parent.counterPairs[i]
b := bufio.NewReader(reader)
extension.Dissector.Dissect(b, reader, options)
if reader.parent.protocol != nil {
break
}
reader.exhaustBuffer = true
}
}
@ -81,7 +79,17 @@ func (reader *tcpReader) sendMsgIfNotClosed(msg api.TcpReaderDataMsg) {
reader.Unlock()
}
func (reader *tcpReader) isProtocolIdentified() bool {
return reader.parent.protocol != nil
}
func (reader *tcpReader) Read(p []byte) (int, error) {
if reader.exhaustBuffer {
l := copy(p, reader.buffer)
reader.exhaustBuffer = false
return l, nil
}
var msg api.TcpReaderDataMsg
ok := true
@ -101,6 +109,9 @@ func (reader *tcpReader) Read(p []byte) (int, error) {
}
l := copy(p, reader.data)
if !reader.isProtocolIdentified() {
reader.buffer = append(reader.buffer, reader.data...)
}
reader.data = reader.data[l:]
reader.progress.Feed(l)
@ -142,7 +153,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -6,7 +6,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
)
@ -22,10 +21,10 @@ type tcpReassemblyStream struct {
fsmerr bool
optchecker reassembly.TCPOptionCheck
isDNS bool
tcpStream api.TcpStream
tcpStream *tcpStream
}
func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream api.TcpStream) ReassemblyStream {
func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream *tcpStream) ReassemblyStream {
return &tcpReassemblyStream{
ident: ident,
tcpState: reassembly.NewTCPSimpleFSM(fsmOptions),
@ -145,17 +144,10 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas
// This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp
stream := t.tcpStream.(*tcpStream)
if dir == reassembly.TCPDirClientToServer {
for i := range stream.getClients() {
reader := stream.getClient(i)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
t.tcpStream.client.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
} else {
for i := range stream.getServers() {
reader := stream.getServer(i)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
t.tcpStream.server.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
}
}
@ -163,7 +155,7 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas
func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() {
t.tcpStream.(*tcpStream).close()
t.tcpStream.close()
}
// do not remove the connection to allow last ACK
return false

View File

@ -13,25 +13,25 @@ import (
* In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel.
*/
type tcpStream struct {
id int64
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
clients []*tcpReader
servers []*tcpReader
origin api.Capture
reqResMatchers []api.RequestResponseMatcher
createdAt time.Time
streamsMap api.TcpStreamMap
id int64
isClosed bool
protocol *api.Protocol
isTapTarget bool
client *tcpReader
server *tcpReader
origin api.Capture
counterPairs []*api.CounterPair
reqResMatchers []api.RequestResponseMatcher
createdAt time.Time
streamsMap api.TcpStreamMap
sync.Mutex
}
func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream {
return &tcpStream{
isTapTarget: isTapTarget,
protoIdentifier: &api.ProtoIdentifier{},
streamsMap: streamsMap,
origin: capture,
isTapTarget: isTapTarget,
streamsMap: streamsMap,
origin: capture,
}
}
@ -55,38 +55,12 @@ func (t *tcpStream) close() {
t.streamsMap.Delete(t.id)
for i := range t.clients {
reader := t.clients[i]
reader.close()
}
for i := range t.servers {
reader := t.servers[i]
reader.close()
}
t.client.close()
t.server.close()
}
func (t *tcpStream) addClient(reader *tcpReader) {
t.clients = append(t.clients, reader)
}
func (t *tcpStream) addServer(reader *tcpReader) {
t.servers = append(t.servers, reader)
}
func (t *tcpStream) getClients() []*tcpReader {
return t.clients
}
func (t *tcpStream) getServers() []*tcpReader {
return t.servers
}
func (t *tcpStream) getClient(index int) *tcpReader {
return t.clients[index]
}
func (t *tcpStream) getServer(index int) *tcpReader {
return t.servers[index]
func (t *tcpStream) addCounterPair(counterPair *api.CounterPair) {
t.counterPairs = append(t.counterPairs, counterPair)
}
func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) {
@ -94,37 +68,17 @@ func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) {
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {
t.Lock()
defer t.Unlock()
if t.protoIdentifier.IsClosedOthers {
return
}
t.protoIdentifier.Protocol = protocol
for i := range t.clients {
reader := t.clients[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.close()
}
}
for i := range t.servers {
reader := t.servers[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.close()
}
}
t.protoIdentifier.IsClosedOthers = true
t.protocol = protocol
t.client.buffer = []byte{}
t.server.buffer = []byte{}
}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -62,62 +62,56 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay
reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream)
if stream.GetIsTapTarget() {
stream.setId(factory.streamsMap.NextId())
for i, extension := range extensions {
reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
stream.addReqResMatcher(reqResMatcher)
for _, extension := range extensions {
counterPair := &api.CounterPair{
Request: 0,
Response: 0,
}
stream.addClient(
NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: srcIp,
DstIP: dstIp,
SrcPort: srcPort,
DstPort: dstPort,
},
time.Time{},
stream,
true,
props.isOutgoing,
extension,
factory.emitter,
counterPair,
reqResMatcher,
),
)
stream.addServer(
NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: net.Dst().String(),
DstIP: net.Src().String(),
SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(),
},
time.Time{},
stream,
false,
props.isOutgoing,
extension,
factory.emitter,
counterPair,
reqResMatcher,
),
)
stream.addCounterPair(counterPair)
factory.streamsMap.Store(stream.getId(), stream)
factory.wg.Add(2)
go stream.getClient(i).run(filteringOptions, &factory.wg)
go stream.getServer(i).run(filteringOptions, &factory.wg)
reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
stream.addReqResMatcher(reqResMatcher)
}
stream.client = NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: srcIp,
DstIP: dstIp,
SrcPort: srcPort,
DstPort: dstPort,
},
time.Time{},
stream,
true,
props.isOutgoing,
factory.emitter,
)
stream.server = NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: net.Dst().String(),
DstIP: net.Src().String(),
SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(),
},
time.Time{},
stream,
false,
props.isOutgoing,
factory.emitter,
)
factory.streamsMap.Store(stream.getId(), stream)
factory.wg.Add(2)
go stream.client.run(filteringOptions, &factory.wg)
go stream.server.run(filteringOptions, &factory.wg)
}
return reassemblyStream
}

View File

@ -55,7 +55,7 @@ func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() {
return true
}
if stream.protoIdentifier.Protocol == nil {
if stream.protocol == nil {
if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) {
stream.close()
diagnose.AppStats.IncDroppedTcpStreams()

View File

@ -171,8 +171,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, k
}
stream := &tlsStream{
reader: reader,
protoIdentifier: &api.ProtoIdentifier{},
reader: reader,
}
streamsMap.Store(streamsMap.NextId(), stream)

View File

@ -87,7 +87,3 @@ func (r *tlsReader) GetEmitter() api.Emitter {
func (r *tlsReader) GetIsClosed() bool {
return false
}
func (r *tlsReader) GetExtension() *api.Extension {
return r.extension
}

View File

@ -3,20 +3,20 @@ package tlstapper
import "github.com/up9inc/mizu/tap/api"
type tlsStream struct {
reader *tlsReader
protoIdentifier *api.ProtoIdentifier
reader *tlsReader
protocol *api.Protocol
}
func (t *tlsStream) GetOrigin() api.Capture {
return api.Ebpf
}
func (t *tlsStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
func (t *tlsStream) GetProtocol() *api.Protocol {
return t.protocol
}
func (t *tlsStream) SetProtocol(protocol *api.Protocol) {
t.protoIdentifier.Protocol = protocol
t.protocol = protocol
}
func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher {