Spawn only two Goroutines per TCP stream (#1062)

* Spawn only two Goroutines per TCP stream

* Fix the linter error

* Use `isProtocolIdentified` method instead

* Fix the `Read` method of `tcpReader`

* Remove unnecessary `append`

* Copy to buffer only a message is received

* Remove `exhaustBuffer` field and add `rewind` function

* Rename `buffer` field to `pastData`

* Update tap/tcp_reader.go

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>

* Use `copy` instead of assignment

* No lint

* #run_acceptance_tests

* Fix `rewind` #run_acceptance_tests

* Fix the buffering algorithm #run_acceptance_tests

* Add `TODO`

* Fix the problems in AMQP and Kafka #run_acceptance_tests

* Use `*bytes.Buffer` instead of `[]api.TcpReaderDataMsg` #run_acceptance_tests

* Have a single `*bytes.Buffer`

* Revert "Have a single `*bytes.Buffer`"

This reverts commit fad96a288a.

* Revert "Use `*bytes.Buffer` instead of `[]api.TcpReaderDataMsg` #run_acceptance_tests"

This reverts commit 0fc70bffe2.

* Fix the early timing out issue #run_acceptance_tests

* Remove `NewBytes()` method

* Update the `NewTcpReader` method signature #run_acceptance_tests

* #run_acceptance_tests

* #run_acceptance_tests

* #run_acceptance_tests

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>
This commit is contained in:
M. Mert Yıldıran 2022-05-16 06:06:36 -07:00 committed by GitHub
parent 5c012641a5
commit bfa834e840
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 190 additions and 253 deletions

View File

@ -104,11 +104,6 @@ type OutputChannelItem struct {
Namespace string Namespace string
} }
type ProtoIdentifier struct {
Protocol *Protocol
IsClosedOthers bool
}
type ReadProgress struct { type ReadProgress struct {
readBytes int readBytes int
lastCurrent int lastCurrent int
@ -123,6 +118,11 @@ func (p *ReadProgress) Current() (n int) {
return p.lastCurrent return p.lastCurrent
} }
func (p *ReadProgress) Reset() {
p.readBytes = 0
p.lastCurrent = 0
}
type Dissector interface { type Dissector interface {
Register(*Extension) Register(*Extension)
Ping() Ping()
@ -419,13 +419,12 @@ type TcpReader interface {
GetCaptureTime() time.Time GetCaptureTime() time.Time
GetEmitter() Emitter GetEmitter() Emitter
GetIsClosed() bool GetIsClosed() bool
GetExtension() *Extension
} }
type TcpStream interface { type TcpStream interface {
SetProtocol(protocol *Protocol) SetProtocol(protocol *Protocol)
GetOrigin() Capture GetOrigin() Capture
GetProtoIdentifier() *ProtoIdentifier GetProtocol() *Protocol
GetReqResMatchers() []RequestResponseMatcher GetReqResMatchers() []RequestResponseMatcher
GetIsTapTarget() bool GetIsTapTarget() bool
GetIsClosed() bool GetIsClosed() bool

View File

@ -75,14 +75,14 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
var lastMethodFrameMessage Message var lastMethodFrameMessage Message
for { for {
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &protocol { if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }
frame, err := r.ReadFrame() frame, err := r.ReadFrame()
if err == io.EOF { if err == io.EOF {
// We must read until we see an EOF... very important! // We must read until we see an EOF... very important!
return nil return err
} }
switch f := frame.(type) { switch f := frame.(type) {

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool { func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed return reader.isClosed
} }
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
) )
type tcpStream struct { type tcpStream struct {
isClosed bool isClosed bool
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
isTapTarget bool isTapTarget bool
origin api.Capture origin api.Capture
reqResMatchers []api.RequestResponseMatcher reqResMatchers []api.RequestResponseMatcher
sync.Mutex sync.Mutex
} }
func NewTcpStream(capture api.Capture) api.TcpStream { func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{ return &tcpStream{
origin: capture, origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
} }
} }
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin return t.origin
} }
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -144,7 +144,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
http2Assembler = createHTTP2Assembler(b) http2Assembler = createHTTP2Assembler(b)
} }
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &http11protocol { if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &http11protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }
@ -200,7 +200,7 @@ func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.
} }
} }
if reader.GetParent().GetProtoIdentifier().Protocol == nil { if reader.GetParent().GetProtocol() == nil {
return err return err
} }

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool { func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed return reader.isClosed
} }
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
) )
type tcpStream struct { type tcpStream struct {
isClosed bool isClosed bool
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
isTapTarget bool isTapTarget bool
origin api.Capture origin api.Capture
reqResMatchers []api.RequestResponseMatcher reqResMatchers []api.RequestResponseMatcher
sync.Mutex sync.Mutex
} }
func NewTcpStream(capture api.Capture) api.TcpStream { func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{ return &tcpStream{
origin: capture, origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
} }
} }
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin return t.origin
} }
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -38,7 +38,7 @@ func (d dissecting) Ping() {
func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher) reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher)
for { for {
if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &_protocol { if reader.GetParent().GetProtocol() != nil && reader.GetParent().GetProtocol() != &_protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool { func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed return reader.isClosed
} }
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
) )
type tcpStream struct { type tcpStream struct {
isClosed bool isClosed bool
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
isTapTarget bool isTapTarget bool
origin api.Capture origin api.Capture
reqResMatchers []api.RequestResponseMatcher reqResMatchers []api.RequestResponseMatcher
sync.Mutex sync.Mutex
} }
func NewTcpStream(capture api.Capture) api.TcpStream { func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{ return &tcpStream{
origin: capture, origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
} }
} }
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin return t.origin
} }
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -78,7 +78,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool { func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed return reader.isClosed
} }
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -7,18 +7,17 @@ import (
) )
type tcpStream struct { type tcpStream struct {
isClosed bool isClosed bool
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
isTapTarget bool isTapTarget bool
origin api.Capture origin api.Capture
reqResMatchers []api.RequestResponseMatcher reqResMatchers []api.RequestResponseMatcher
sync.Mutex sync.Mutex
} }
func NewTcpStream(capture api.Capture) api.TcpStream { func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{ return &tcpStream{
origin: capture, origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
} }
} }
@ -28,8 +27,8 @@ func (t *tcpStream) GetOrigin() api.Capture {
return t.origin return t.origin
} }
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -3,11 +3,9 @@ package tap
import ( import (
"bufio" "bufio"
"io" "io"
"io/ioutil"
"sync" "sync"
"time" "time"
"github.com/up9inc/mizu/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
) )
@ -17,50 +15,48 @@ import (
* Implements io.Reader interface (Read) * Implements io.Reader interface (Read)
*/ */
type tcpReader struct { type tcpReader struct {
ident string ident string
tcpID *api.TcpID tcpID *api.TcpID
isClosed bool isClosed bool
isClient bool isClient bool
isOutgoing bool isOutgoing bool
msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload
data []byte msgBuffer []api.TcpReaderDataMsg
progress *api.ReadProgress msgBufferMaster []api.TcpReaderDataMsg
captureTime time.Time data []byte
parent *tcpStream progress *api.ReadProgress
packetsSeen uint captureTime time.Time
extension *api.Extension parent *tcpStream
emitter api.Emitter emitter api.Emitter
counterPair *api.CounterPair counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher reqResMatcher api.RequestResponseMatcher
sync.Mutex sync.Mutex
} }
func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent *tcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) *tcpReader { func NewTcpReader(ident string, tcpId *api.TcpID, parent *tcpStream, isClient bool, isOutgoing bool, emitter api.Emitter) *tcpReader {
return &tcpReader{ return &tcpReader{
msgQueue: msgQueue, msgQueue: make(chan api.TcpReaderDataMsg),
progress: progress, progress: &api.ReadProgress{},
ident: ident, ident: ident,
tcpID: tcpId, tcpID: tcpId,
captureTime: captureTime, parent: parent,
parent: parent, isClient: isClient,
isClient: isClient, isOutgoing: isOutgoing,
isOutgoing: isOutgoing, emitter: emitter,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
} }
} }
func (reader *tcpReader) run(options *api.TrafficFilteringOptions, wg *sync.WaitGroup) { func (reader *tcpReader) run(options *api.TrafficFilteringOptions, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
b := bufio.NewReader(reader) for i, extension := range extensions {
err := reader.extension.Dissector.Dissect(b, reader, options) reader.reqResMatcher = reader.parent.reqResMatchers[i]
if err != nil { reader.counterPair = reader.parent.counterPairs[i]
_, err = io.Copy(ioutil.Discard, reader) b := bufio.NewReader(reader)
if err != nil { extension.Dissector.Dissect(b, reader, options) //nolint
logger.Log.Errorf("%v", err) if reader.isProtocolIdentified() {
break
} }
reader.rewind()
} }
} }
@ -81,21 +77,56 @@ func (reader *tcpReader) sendMsgIfNotClosed(msg api.TcpReaderDataMsg) {
reader.Unlock() reader.Unlock()
} }
func (reader *tcpReader) isProtocolIdentified() bool {
return reader.parent.protocol != nil
}
func (reader *tcpReader) rewind() {
// Reset the data and msgBuffer from the master record
reader.data = make([]byte, 0)
reader.msgBuffer = make([]api.TcpReaderDataMsg, len(reader.msgBufferMaster))
copy(reader.msgBuffer, reader.msgBufferMaster)
// Reset the read progress
reader.progress.Reset()
}
func (reader *tcpReader) populateData(msg api.TcpReaderDataMsg) {
reader.data = msg.GetBytes()
reader.captureTime = msg.GetTimestamp()
}
func (reader *tcpReader) Read(p []byte) (int, error) { func (reader *tcpReader) Read(p []byte) (int, error) {
var msg api.TcpReaderDataMsg var msg api.TcpReaderDataMsg
for len(reader.msgBuffer) > 0 && len(reader.data) == 0 {
// Pop first message
if len(reader.msgBuffer) > 1 {
msg, reader.msgBuffer = reader.msgBuffer[0], reader.msgBuffer[1:]
} else {
msg = reader.msgBuffer[0]
reader.msgBuffer = make([]api.TcpReaderDataMsg, 0)
}
// Get the bytes
reader.populateData(msg)
}
ok := true ok := true
for ok && len(reader.data) == 0 { for ok && len(reader.data) == 0 {
msg, ok = <-reader.msgQueue msg, ok = <-reader.msgQueue
if msg != nil { if msg != nil {
reader.data = msg.GetBytes() reader.populateData(msg)
reader.captureTime = msg.GetTimestamp()
}
if len(reader.data) > 0 { if !reader.isProtocolIdentified() {
reader.packetsSeen += 1 reader.msgBufferMaster = append(
reader.msgBufferMaster,
msg,
)
}
} }
} }
if !ok || len(reader.data) == 0 { if !ok || len(reader.data) == 0 {
return 0, io.EOF return 0, io.EOF
} }
@ -142,7 +173,3 @@ func (reader *tcpReader) GetEmitter() api.Emitter {
func (reader *tcpReader) GetIsClosed() bool { func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed return reader.isClosed
} }
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -6,7 +6,6 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly" "github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/diagnose"
) )
@ -16,10 +15,10 @@ type tcpReassemblyStream struct {
fsmerr bool fsmerr bool
optchecker reassembly.TCPOptionCheck optchecker reassembly.TCPOptionCheck
isDNS bool isDNS bool
tcpStream api.TcpStream tcpStream *tcpStream
} }
func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream api.TcpStream) reassembly.Stream { func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream *tcpStream) reassembly.Stream {
return &tcpReassemblyStream{ return &tcpReassemblyStream{
ident: ident, ident: ident,
tcpState: reassembly.NewTCPSimpleFSM(fsmOptions), tcpState: reassembly.NewTCPSimpleFSM(fsmOptions),
@ -139,17 +138,10 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas
// This channel is read by an tcpReader object // This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount() diagnose.AppStats.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp timestamp := ac.GetCaptureInfo().Timestamp
stream := t.tcpStream.(*tcpStream)
if dir == reassembly.TCPDirClientToServer { if dir == reassembly.TCPDirClientToServer {
for i := range stream.getClients() { t.tcpStream.client.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
reader := stream.getClient(i)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
} else { } else {
for i := range stream.getServers() { t.tcpStream.server.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
reader := stream.getServer(i)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
} }
} }
} }
@ -157,7 +149,7 @@ func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reas
func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() { if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() {
t.tcpStream.(*tcpStream).close() t.tcpStream.close()
} }
// do not remove the connection to allow last ACK // do not remove the connection to allow last ACK
return false return false

View File

@ -13,25 +13,26 @@ import (
* In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel. * In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel.
*/ */
type tcpStream struct { type tcpStream struct {
id int64 id int64
isClosed bool isClosed bool
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
isTapTarget bool isTapTarget bool
clients []*tcpReader client *tcpReader
servers []*tcpReader server *tcpReader
origin api.Capture origin api.Capture
reqResMatchers []api.RequestResponseMatcher counterPairs []*api.CounterPair
createdAt time.Time reqResMatchers []api.RequestResponseMatcher
streamsMap api.TcpStreamMap createdAt time.Time
streamsMap api.TcpStreamMap
sync.Mutex sync.Mutex
} }
func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream { func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) *tcpStream {
return &tcpStream{ return &tcpStream{
isTapTarget: isTapTarget, isTapTarget: isTapTarget,
protoIdentifier: &api.ProtoIdentifier{}, streamsMap: streamsMap,
streamsMap: streamsMap, origin: capture,
origin: capture, createdAt: time.Now(),
} }
} }
@ -55,38 +56,12 @@ func (t *tcpStream) close() {
t.streamsMap.Delete(t.id) t.streamsMap.Delete(t.id)
for i := range t.clients { t.client.close()
reader := t.clients[i] t.server.close()
reader.close()
}
for i := range t.servers {
reader := t.servers[i]
reader.close()
}
} }
func (t *tcpStream) addClient(reader *tcpReader) { func (t *tcpStream) addCounterPair(counterPair *api.CounterPair) {
t.clients = append(t.clients, reader) t.counterPairs = append(t.counterPairs, counterPair)
}
func (t *tcpStream) addServer(reader *tcpReader) {
t.servers = append(t.servers, reader)
}
func (t *tcpStream) getClients() []*tcpReader {
return t.clients
}
func (t *tcpStream) getServers() []*tcpReader {
return t.servers
}
func (t *tcpStream) getClient(index int) *tcpReader {
return t.clients[index]
}
func (t *tcpStream) getServer(index int) *tcpReader {
return t.servers[index]
} }
func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) { func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) {
@ -94,37 +69,19 @@ func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) {
} }
func (t *tcpStream) SetProtocol(protocol *api.Protocol) { func (t *tcpStream) SetProtocol(protocol *api.Protocol) {
t.Lock() t.protocol = protocol
defer t.Unlock()
if t.protoIdentifier.IsClosedOthers { // Clean the buffers
return t.client.msgBufferMaster = make([]api.TcpReaderDataMsg, 0)
} t.server.msgBufferMaster = make([]api.TcpReaderDataMsg, 0)
t.protoIdentifier.Protocol = protocol
for i := range t.clients {
reader := t.clients[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.close()
}
}
for i := range t.servers {
reader := t.servers[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.close()
}
}
t.protoIdentifier.IsClosedOthers = true
} }
func (t *tcpStream) GetOrigin() api.Capture { func (t *tcpStream) GetOrigin() api.Capture {
return t.origin return t.origin
} }
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tcpStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tcpStream) GetReqResMatchers() []api.RequestResponseMatcher {

View File

@ -3,7 +3,6 @@ package tap
import ( import (
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/up9inc/mizu/logger" "github.com/up9inc/mizu/logger"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
@ -62,62 +61,50 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *lay
reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream) reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream)
if stream.GetIsTapTarget() { if stream.GetIsTapTarget() {
stream.setId(factory.streamsMap.NextId()) stream.setId(factory.streamsMap.NextId())
for i, extension := range extensions { for _, extension := range extensions {
reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
stream.addReqResMatcher(reqResMatcher)
counterPair := &api.CounterPair{ counterPair := &api.CounterPair{
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
stream.addClient( stream.addCounterPair(counterPair)
NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: srcIp,
DstIP: dstIp,
SrcPort: srcPort,
DstPort: dstPort,
},
time.Time{},
stream,
true,
props.isOutgoing,
extension,
factory.emitter,
counterPair,
reqResMatcher,
),
)
stream.addServer(
NewTcpReader(
make(chan api.TcpReaderDataMsg),
&api.ReadProgress{},
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: net.Dst().String(),
DstIP: net.Src().String(),
SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(),
},
time.Time{},
stream,
false,
props.isOutgoing,
extension,
factory.emitter,
counterPair,
reqResMatcher,
),
)
factory.streamsMap.Store(stream.getId(), stream) reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
stream.addReqResMatcher(reqResMatcher)
factory.wg.Add(2)
go stream.getClient(i).run(filteringOptions, &factory.wg)
go stream.getServer(i).run(filteringOptions, &factory.wg)
} }
stream.client = NewTcpReader(
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: srcIp,
DstIP: dstIp,
SrcPort: srcPort,
DstPort: dstPort,
},
stream,
true,
props.isOutgoing,
factory.emitter,
)
stream.server = NewTcpReader(
fmt.Sprintf("%s %s", net, transport),
&api.TcpID{
SrcIP: net.Dst().String(),
DstIP: net.Src().String(),
SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(),
},
stream,
false,
props.isOutgoing,
factory.emitter,
)
factory.streamsMap.Store(stream.getId(), stream)
factory.wg.Add(2)
go stream.client.run(filteringOptions, &factory.wg)
go stream.server.run(filteringOptions, &factory.wg)
} }
return reassemblyStream return reassemblyStream
} }

View File

@ -57,7 +57,7 @@ func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() {
return true return true
} }
if stream.protoIdentifier.Protocol == nil { if stream.protocol == nil {
if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) { if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) {
stream.close() stream.close()
diagnose.AppStats.IncDroppedTcpStreams() diagnose.AppStats.IncDroppedTcpStreams()

View File

@ -188,8 +188,7 @@ func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, address *addressPair, key
} }
stream := &tlsStream{ stream := &tlsStream{
reader: reader, reader: reader,
protoIdentifier: &api.ProtoIdentifier{},
} }
streamsMap.Store(streamsMap.NextId(), stream) streamsMap.Store(streamsMap.NextId(), stream)

View File

@ -94,7 +94,3 @@ func (r *tlsReader) GetEmitter() api.Emitter {
func (r *tlsReader) GetIsClosed() bool { func (r *tlsReader) GetIsClosed() bool {
return false return false
} }
func (r *tlsReader) GetExtension() *api.Extension {
return r.extension
}

View File

@ -3,20 +3,20 @@ package tlstapper
import "github.com/up9inc/mizu/tap/api" import "github.com/up9inc/mizu/tap/api"
type tlsStream struct { type tlsStream struct {
reader *tlsReader reader *tlsReader
protoIdentifier *api.ProtoIdentifier protocol *api.Protocol
} }
func (t *tlsStream) GetOrigin() api.Capture { func (t *tlsStream) GetOrigin() api.Capture {
return api.Ebpf return api.Ebpf
} }
func (t *tlsStream) GetProtoIdentifier() *api.ProtoIdentifier { func (t *tlsStream) GetProtocol() *api.Protocol {
return t.protoIdentifier return t.protocol
} }
func (t *tlsStream) SetProtocol(protocol *api.Protocol) { func (t *tlsStream) SetProtocol(protocol *api.Protocol) {
t.protoIdentifier.Protocol = protocol t.protocol = protocol
} }
func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher { func (t *tlsStream) GetReqResMatchers() []api.RequestResponseMatcher {