Refactor tap module to achieve synchronously closing other protocol dissectors upon identification (#1026)

* Remove `tcpStreamWrapper` struct

* Refactor `tap` module and move some of the code to `tap/api` module

* Move `TrafficFilteringOptions` struct to `shared` module

* Change the `Dissect` method signature to have `*TcpReader` as an argument

* Add `CloseOtherProtocolDissectors` method and use it to synchronously close the other protocol dissectors

* Run `go mod tidy` in `cli` module

* Rename `SuperIdentifier` struct to `ProtoIdentifier`

* Remove `SuperTimer` struct

* Bring back `CloseTimedoutTcpStreamChannels` method

* Run `go mod tidy` everywhere

* Remove `GOGC` environment variable from tapper

* Fix the tests

* Bring back `debug.FreeOSMemory()` call

* Make `CloseOtherProtocolDissectors` method mutexed

* Revert "Remove `GOGC` environment variable from tapper"

This reverts commit cfc2484bbb.

* Bring back the removed `checksum`, `nooptcheck` and `ignorefsmerr` flags

* Define a bunch of interfaces and don't export any new structs from `tap/api`

* Keep the interfaces in `tap/api` but move the structs to `tap/tcp`

* Fix the unit tests by depending on `github.com/up9inc/mizu/tap`

* Use the modified `tlsEmitter`

* Define `TlsChunk` interface and make `tlsReader` implement `TcpReader`

* Remove unused fields in `tlsReader`

* Define `ReassemblyStream` interface and separate `gopacket` specififc fields to `tcpReassemblyStream` struct

Such that make `tap/api` don't depend on `gopacket`

* Remove the unused fields

* Make `tlsPoller` implement `TcpStream` interface and remove the call to `NewTcpStreamDummy` method

* Remove unused fields from `tlsPoller`

* Remove almost all of the setter methods in `TcpReader` and `TcpStream` interface and remove `TlsChunk` interface

* Revert "Revert "Remove `GOGC` environment variable from tapper""

This reverts commit ab2b9a803b.

* Revert "Bring back `debug.FreeOSMemory()` call"

This reverts commit 1cce863bbb.

* Remove excess comment

* Fix acceptance tests (`logger` module) #run_acceptance_tests

* Bring back `github.com/patrickmn/go-cache`

* Fix `NewTcpStream` method signature

* Put `tcpReader` and `tcpStream` mocks into protocol dissectors to remove `github.com/up9inc/mizu/tap` dependency

* Fix AMQP tests

* Revert 960ba644cd

* Revert `go.mod` and `go.sum` files in protocol dissectors

* Fix the comment position

* Revert `AppStatsInst` change

* Fix indent

* Fix CLI build

* Fix linter error

* Fix error msg

* Revert some of the changes in `chunk.go`
This commit is contained in:
M. Mert Yıldıran 2022-04-28 07:19:14 -07:00 committed by GitHub
parent ed9e162af0
commit d3e6a69d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 1445 additions and 1369 deletions

View File

@ -10,7 +10,6 @@ const (
ValidationRulesFileName = "validation-rules.yaml" ValidationRulesFileName = "validation-rules.yaml"
ContractFileName = "contract-oas.yaml" ContractFileName = "contract-oas.yaml"
ConfigFileName = "mizu-config.json" ConfigFileName = "mizu-config.json"
GoGCEnvVar = "GOGC"
DefaultApiServerPort = 8899 DefaultApiServerPort = 8899
LogLevelEnvVar = "LOG_LEVEL" LogLevelEnvVar = "LOG_LEVEL"
MizuAgentImageRepo = "docker.io/up9inc/mizu" MizuAgentImageRepo = "docker.io/up9inc/mizu"

View File

@ -768,7 +768,6 @@ func (provider *Provider) ApplyMizuTapperDaemonSet(ctx context.Context, namespac
agentContainer.WithEnv( agentContainer.WithEnv(
applyconfcore.EnvVar().WithName(shared.LogLevelEnvVar).WithValue(logLevel.String()), applyconfcore.EnvVar().WithName(shared.LogLevelEnvVar).WithValue(logLevel.String()),
applyconfcore.EnvVar().WithName(shared.HostModeEnvVar).WithValue("1"), applyconfcore.EnvVar().WithName(shared.HostModeEnvVar).WithValue("1"),
applyconfcore.EnvVar().WithName(shared.GoGCEnvVar).WithValue("12800"),
applyconfcore.EnvVar().WithName(shared.MizuFilteringOptionsEnvVar).WithValue(string(mizuApiFilteringOptionsJsonStr)), applyconfcore.EnvVar().WithName(shared.MizuFilteringOptionsEnvVar).WithValue(string(mizuApiFilteringOptionsJsonStr)),
) )
agentContainer.WithEnv( agentContainer.WithEnv(

View File

@ -104,11 +104,7 @@ type OutputChannelItem struct {
Namespace string Namespace string
} }
type SuperTimer struct { type ProtoIdentifier struct {
CaptureTime time.Time
}
type SuperIdentifier struct {
Protocol *Protocol Protocol *Protocol
IsClosedOthers bool IsClosedOthers bool
} }
@ -130,7 +126,7 @@ func (p *ReadProgress) Current() (n int) {
type Dissector interface { type Dissector interface {
Register(*Extension) Register(*Extension)
Ping() Ping()
Dissect(b *bufio.Reader, progress *ReadProgress, capture Capture, isClient bool, tcpID *TcpID, counterPair *CounterPair, superTimer *SuperTimer, superIdentifier *SuperIdentifier, emitter Emitter, options *TrafficFilteringOptions, reqResMatcher RequestResponseMatcher) error Dissect(b *bufio.Reader, reader TcpReader, options *TrafficFilteringOptions) error
Analyze(item *OutputChannelItem, resolvedSource string, resolvedDestination string, namespace string) *Entry Analyze(item *OutputChannelItem, resolvedSource string, resolvedDestination string, namespace string) *Entry
Summarize(entry *Entry) *BaseEntry Summarize(entry *Entry) *BaseEntry
Represent(request map[string]interface{}, response map[string]interface{}) (object []byte, err error) Represent(request map[string]interface{}, response map[string]interface{}) (object []byte, err error)
@ -406,3 +402,39 @@ func (r *HTTPResponseWrapper) MarshalJSON() ([]byte, error) {
Response: r.Response, Response: r.Response,
}) })
} }
type TcpReaderDataMsg interface {
GetBytes() []byte
GetTimestamp() time.Time
}
type TcpReader interface {
Read(p []byte) (int, error)
GetReqResMatcher() RequestResponseMatcher
GetIsClient() bool
GetReadProgress() *ReadProgress
GetParent() TcpStream
GetTcpID() *TcpID
GetCounterPair() *CounterPair
GetCaptureTime() time.Time
GetEmitter() Emitter
GetIsClosed() bool
GetExtension() *Extension
}
type TcpStream interface {
SetProtocol(protocol *Protocol)
GetOrigin() Capture
GetProtoIdentifier() *ProtoIdentifier
GetReqResMatcher() RequestResponseMatcher
GetIsTapTarget() bool
GetIsClosed() bool
}
type TcpStreamMap interface {
Range(f func(key, value interface{}) bool)
Store(key, value interface{})
Delete(key interface{})
NextId() int64
CloseTimedoutTcpStreamChannels()
}

View File

@ -3,3 +3,7 @@ module github.com/up9inc/mizu/tap/api
go 1.17 go 1.17
require github.com/google/martian v2.1.0+incompatible require github.com/google/martian v2.1.0+incompatible
replace github.com/up9inc/mizu/logger v0.0.0 => ../../logger
replace github.com/up9inc/mizu/shared v0.0.0 => ../../shared

View File

@ -22,7 +22,7 @@ type Cleaner struct {
connectionTimeout time.Duration connectionTimeout time.Duration
stats CleanerStats stats CleanerStats
statsMutex sync.Mutex statsMutex sync.Mutex
streamsMap *tcpStreamMap streamsMap api.TcpStreamMap
} }
func (cl *Cleaner) clean() { func (cl *Cleaner) clean() {
@ -33,8 +33,8 @@ func (cl *Cleaner) clean() {
flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout))
cl.assemblerMutex.Unlock() cl.assemblerMutex.Unlock()
cl.streamsMap.streams.Range(func(k, v interface{}) bool { cl.streamsMap.Range(func(k, v interface{}) bool {
reqResMatcher := v.(*tcpStreamWrapper).reqResMatcher reqResMatcher := v.(api.TcpStream).GetReqResMatcher()
if reqResMatcher == nil { if reqResMatcher == nil {
return true return true
} }

View File

@ -15,3 +15,5 @@ require (
) )
replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api
replace github.com/up9inc/mizu/logger v0.0.0 => ../../../logger

View File

@ -39,17 +39,17 @@ func (d dissecting) Ping() {
const amqpRequest string = "amqp_request" const amqpRequest string = "amqp_request"
func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter, options *api.TrafficFilteringOptions, _reqResMatcher api.RequestResponseMatcher) error { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
r := AmqpReader{b} r := AmqpReader{b}
var remaining int var remaining int
var header *HeaderFrame var header *HeaderFrame
connectionInfo := &api.ConnectionInfo{ connectionInfo := &api.ConnectionInfo{
ClientIP: tcpID.SrcIP, ClientIP: reader.GetTcpID().SrcIP,
ClientPort: tcpID.SrcPort, ClientPort: reader.GetTcpID().SrcPort,
ServerIP: tcpID.DstIP, ServerIP: reader.GetTcpID().DstIP,
ServerPort: tcpID.DstPort, ServerPort: reader.GetTcpID().DstPort,
IsOutgoing: true, IsOutgoing: true,
} }
@ -75,7 +75,7 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
var lastMethodFrameMessage Message var lastMethodFrameMessage Message
for { for {
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &protocol { if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }
@ -112,12 +112,12 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
switch lastMethodFrameMessage.(type) { switch lastMethodFrameMessage.(type) {
case *BasicPublish: case *BasicPublish:
eventBasicPublish.Body = f.Body eventBasicPublish.Body = f.Body
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventBasicPublish, amqpRequest, basicMethodMap[40], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *BasicDeliver: case *BasicDeliver:
eventBasicDeliver.Body = f.Body eventBasicDeliver.Body = f.Body
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventBasicDeliver, amqpRequest, basicMethodMap[60], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
} }
case *MethodFrame: case *MethodFrame:
@ -137,8 +137,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
NoWait: m.NoWait, NoWait: m.NoWait,
Arguments: m.Arguments, Arguments: m.Arguments,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventQueueBind, amqpRequest, queueMethodMap[20], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *BasicConsume: case *BasicConsume:
eventBasicConsume := &BasicConsume{ eventBasicConsume := &BasicConsume{
@ -150,8 +150,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
NoWait: m.NoWait, NoWait: m.NoWait,
Arguments: m.Arguments, Arguments: m.Arguments,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventBasicConsume, amqpRequest, basicMethodMap[20], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *BasicDeliver: case *BasicDeliver:
eventBasicDeliver.ConsumerTag = m.ConsumerTag eventBasicDeliver.ConsumerTag = m.ConsumerTag
@ -170,8 +170,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
NoWait: m.NoWait, NoWait: m.NoWait,
Arguments: m.Arguments, Arguments: m.Arguments,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventQueueDeclare, amqpRequest, queueMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *ExchangeDeclare: case *ExchangeDeclare:
eventExchangeDeclare := &ExchangeDeclare{ eventExchangeDeclare := &ExchangeDeclare{
@ -184,8 +184,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
NoWait: m.NoWait, NoWait: m.NoWait,
Arguments: m.Arguments, Arguments: m.Arguments,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventExchangeDeclare, amqpRequest, exchangeMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *ConnectionStart: case *ConnectionStart:
eventConnectionStart := &ConnectionStart{ eventConnectionStart := &ConnectionStart{
@ -195,8 +195,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
Mechanisms: m.Mechanisms, Mechanisms: m.Mechanisms,
Locales: m.Locales, Locales: m.Locales,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventConnectionStart, amqpRequest, connectionMethodMap[10], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
case *ConnectionClose: case *ConnectionClose:
eventConnectionClose := &ConnectionClose{ eventConnectionClose := &ConnectionClose{
@ -205,8 +205,8 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
ClassId: m.ClassId, ClassId: m.ClassId,
MethodId: m.MethodId, MethodId: m.MethodId,
} }
superIdentifier.Protocol = &protocol reader.GetParent().SetProtocol(&protocol)
emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, superTimer.CaptureTime, progress.Current(), emitter, capture) emitAMQP(*eventConnectionClose, amqpRequest, connectionMethodMap[50], connectionInfo, reader.GetCaptureTime(), reader.GetReadProgress().Current(), reader.GetEmitter(), reader.GetParent().GetOrigin())
} }
default: default:

View File

@ -106,7 +106,6 @@ func TestDissect(t *testing.T) {
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
superIdentifier := &api.SuperIdentifier{}
// Request // Request
pathClient := _path pathClient := _path
@ -122,7 +121,21 @@ func TestDissect(t *testing.T) {
DstPort: "2", DstPort: "2",
} }
reqResMatcher := dissector.NewResponseRequestMatcher() reqResMatcher := dissector.NewResponseRequestMatcher()
err = dissector.Dissect(bufferClient, &api.ReadProgress{}, api.Pcap, true, tcpIDClient, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) stream := NewTcpStream(api.Pcap)
reader := NewTcpReader(
&api.ReadProgress{},
"",
tcpIDClient,
time.Time{},
stream,
true,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferClient, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
panic(err) panic(err)
} }
@ -140,7 +153,20 @@ func TestDissect(t *testing.T) {
SrcPort: "2", SrcPort: "2",
DstPort: "1", DstPort: "1",
} }
err = dissector.Dissect(bufferServer, &api.ReadProgress{}, api.Pcap, false, tcpIDServer, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) reader = NewTcpReader(
&api.ReadProgress{},
"",
tcpIDServer,
time.Time{},
stream,
false,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferServer, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
panic(err) panic(err)
} }

View File

@ -0,0 +1,84 @@
package amqp
import (
"sync"
"time"
"github.com/up9inc/mizu/tap/api"
)
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
progress *api.ReadProgress
captureTime time.Time
parent api.TcpStream
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpReader(progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader {
return &tcpReader{
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
}
}
func (reader *tcpReader) Read(p []byte) (int, error) {
return 0, nil
}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher
}
func (reader *tcpReader) GetIsClient() bool {
return reader.isClient
}
func (reader *tcpReader) GetReadProgress() *api.ReadProgress {
return reader.progress
}
func (reader *tcpReader) GetParent() api.TcpStream {
return reader.parent
}
func (reader *tcpReader) GetTcpID() *api.TcpID {
return reader.tcpID
}
func (reader *tcpReader) GetCounterPair() *api.CounterPair {
return reader.counterPair
}
func (reader *tcpReader) GetCaptureTime() time.Time {
return reader.captureTime
}
func (reader *tcpReader) GetEmitter() api.Emitter {
return reader.emitter
}
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -0,0 +1,45 @@
package amqp
import (
"sync"
"github.com/up9inc/mizu/tap/api"
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
}
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
}
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
return t.reqResMatcher
}
func (t *tcpStream) GetIsTapTarget() bool {
return t.isTapTarget
}
func (t *tcpStream) GetIsClosed() bool {
return t.isClosed
}

View File

@ -18,3 +18,5 @@ require (
) )
replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api
replace github.com/up9inc/mizu/logger v0.0.0 => ../../../logger

View File

@ -8,6 +8,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
) )
@ -47,7 +48,7 @@ func replaceForwardedFor(item *api.OutputChannelItem) {
item.ConnectionInfo.ClientPort = "" item.ConnectionInfo.ClientPort = ""
} }
func handleHTTP2Stream(http2Assembler *Http2Assembler, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, superTimer *api.SuperTimer, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) error { func handleHTTP2Stream(http2Assembler *Http2Assembler, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, captureTime time.Time, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) error {
streamID, messageHTTP1, isGrpc, err := http2Assembler.readMessage() streamID, messageHTTP1, isGrpc, err := http2Assembler.readMessage()
if err != nil { if err != nil {
return err return err
@ -66,7 +67,7 @@ func handleHTTP2Stream(http2Assembler *Http2Assembler, progress *api.ReadProgres
streamID, streamID,
"HTTP2", "HTTP2",
) )
item = reqResMatcher.registerRequest(ident, &messageHTTP1, superTimer.CaptureTime, progress.Current(), messageHTTP1.ProtoMinor) item = reqResMatcher.registerRequest(ident, &messageHTTP1, captureTime, progress.Current(), messageHTTP1.ProtoMinor)
if item != nil { if item != nil {
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
ClientIP: tcpID.SrcIP, ClientIP: tcpID.SrcIP,
@ -86,7 +87,7 @@ func handleHTTP2Stream(http2Assembler *Http2Assembler, progress *api.ReadProgres
streamID, streamID,
"HTTP2", "HTTP2",
) )
item = reqResMatcher.registerResponse(ident, &messageHTTP1, superTimer.CaptureTime, progress.Current(), messageHTTP1.ProtoMinor) item = reqResMatcher.registerResponse(ident, &messageHTTP1, captureTime, progress.Current(), messageHTTP1.ProtoMinor)
if item != nil { if item != nil {
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
ClientIP: tcpID.DstIP, ClientIP: tcpID.DstIP,
@ -111,7 +112,7 @@ func handleHTTP2Stream(http2Assembler *Http2Assembler, progress *api.ReadProgres
return nil return nil
} }
func handleHTTP1ClientStream(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) (switchingProtocolsHTTP2 bool, req *http.Request, err error) { func handleHTTP1ClientStream(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) (switchingProtocolsHTTP2 bool, req *http.Request, err error) {
req, err = http.ReadRequest(b) req, err = http.ReadRequest(b)
if err != nil { if err != nil {
return return
@ -139,7 +140,7 @@ func handleHTTP1ClientStream(b *bufio.Reader, progress *api.ReadProgress, captur
requestCounter, requestCounter,
"HTTP1", "HTTP1",
) )
item := reqResMatcher.registerRequest(ident, req, superTimer.CaptureTime, progress.Current(), req.ProtoMinor) item := reqResMatcher.registerRequest(ident, req, captureTime, progress.Current(), req.ProtoMinor)
if item != nil { if item != nil {
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
ClientIP: tcpID.SrcIP, ClientIP: tcpID.SrcIP,
@ -154,7 +155,7 @@ func handleHTTP1ClientStream(b *bufio.Reader, progress *api.ReadProgress, captur
return return
} }
func handleHTTP1ServerStream(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) (switchingProtocolsHTTP2 bool, err error) { func handleHTTP1ServerStream(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, emitter api.Emitter, options *api.TrafficFilteringOptions, reqResMatcher *requestResponseMatcher) (switchingProtocolsHTTP2 bool, err error) {
var res *http.Response var res *http.Response
res, err = http.ReadResponse(b, nil) res, err = http.ReadResponse(b, nil)
if err != nil { if err != nil {
@ -183,7 +184,7 @@ func handleHTTP1ServerStream(b *bufio.Reader, progress *api.ReadProgress, captur
responseCounter, responseCounter,
"HTTP1", "HTTP1",
) )
item := reqResMatcher.registerResponse(ident, res, superTimer.CaptureTime, progress.Current(), res.ProtoMinor) item := reqResMatcher.registerResponse(ident, res, captureTime, progress.Current(), res.ProtoMinor)
if item != nil { if item != nil {
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
ClientIP: tcpID.DstIP, ClientIP: tcpID.DstIP,

View File

@ -86,15 +86,15 @@ func (d dissecting) Ping() {
log.Printf("pong %s", http11protocol.Name) log.Printf("pong %s", http11protocol.Name)
} }
func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter, options *api.TrafficFilteringOptions, _reqResMatcher api.RequestResponseMatcher) error { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
reqResMatcher := _reqResMatcher.(*requestResponseMatcher) reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher)
var err error var err error
isHTTP2, _ := checkIsHTTP2Connection(b, isClient) isHTTP2, _ := checkIsHTTP2Connection(b, reader.GetIsClient())
var http2Assembler *Http2Assembler var http2Assembler *Http2Assembler
if isHTTP2 { if isHTTP2 {
err = prepareHTTP2Connection(b, isClient) err = prepareHTTP2Connection(b, reader.GetIsClient())
if err != nil { if err != nil {
return err return err
} }
@ -105,74 +105,74 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
for { for {
if switchingProtocolsHTTP2 { if switchingProtocolsHTTP2 {
switchingProtocolsHTTP2 = false switchingProtocolsHTTP2 = false
isHTTP2, err = checkIsHTTP2Connection(b, isClient) isHTTP2, err = checkIsHTTP2Connection(b, reader.GetIsClient())
if err != nil { if err != nil {
break break
} }
err = prepareHTTP2Connection(b, isClient) err = prepareHTTP2Connection(b, reader.GetIsClient())
if err != nil { if err != nil {
break break
} }
http2Assembler = createHTTP2Assembler(b) http2Assembler = createHTTP2Assembler(b)
} }
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &http11protocol { if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &http11protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }
if isHTTP2 { if isHTTP2 {
err = handleHTTP2Stream(http2Assembler, progress, capture, tcpID, superTimer, emitter, options, reqResMatcher) err = handleHTTP2Stream(http2Assembler, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher)
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
break break
} else if err != nil { } else if err != nil {
continue continue
} }
superIdentifier.Protocol = &http11protocol reader.GetParent().SetProtocol(&http11protocol)
} else if isClient { } else if reader.GetIsClient() {
var req *http.Request var req *http.Request
switchingProtocolsHTTP2, req, err = handleHTTP1ClientStream(b, progress, capture, tcpID, counterPair, superTimer, emitter, options, reqResMatcher) switchingProtocolsHTTP2, req, err = handleHTTP1ClientStream(b, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher)
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
break break
} else if err != nil { } else if err != nil {
continue continue
} }
superIdentifier.Protocol = &http11protocol reader.GetParent().SetProtocol(&http11protocol)
// In case of an HTTP2 upgrade, duplicate the HTTP1 request into HTTP2 with stream ID 1 // In case of an HTTP2 upgrade, duplicate the HTTP1 request into HTTP2 with stream ID 1
if switchingProtocolsHTTP2 { if switchingProtocolsHTTP2 {
ident := fmt.Sprintf( ident := fmt.Sprintf(
"%s_%s_%s_%s_1_%s", "%s_%s_%s_%s_1_%s",
tcpID.SrcIP, reader.GetTcpID().SrcIP,
tcpID.DstIP, reader.GetTcpID().DstIP,
tcpID.SrcPort, reader.GetTcpID().SrcPort,
tcpID.DstPort, reader.GetTcpID().DstPort,
"HTTP2", "HTTP2",
) )
item := reqResMatcher.registerRequest(ident, req, superTimer.CaptureTime, progress.Current(), req.ProtoMinor) item := reqResMatcher.registerRequest(ident, req, reader.GetCaptureTime(), reader.GetReadProgress().Current(), req.ProtoMinor)
if item != nil { if item != nil {
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
ClientIP: tcpID.SrcIP, ClientIP: reader.GetTcpID().SrcIP,
ClientPort: tcpID.SrcPort, ClientPort: reader.GetTcpID().SrcPort,
ServerIP: tcpID.DstIP, ServerIP: reader.GetTcpID().DstIP,
ServerPort: tcpID.DstPort, ServerPort: reader.GetTcpID().DstPort,
IsOutgoing: true, IsOutgoing: true,
} }
item.Capture = capture item.Capture = reader.GetParent().GetOrigin()
filterAndEmit(item, emitter, options) filterAndEmit(item, reader.GetEmitter(), options)
} }
} }
} else { } else {
switchingProtocolsHTTP2, err = handleHTTP1ServerStream(b, progress, capture, tcpID, counterPair, superTimer, emitter, options, reqResMatcher) switchingProtocolsHTTP2, err = handleHTTP1ServerStream(b, reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), options, reqResMatcher)
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
break break
} else if err != nil { } else if err != nil {
continue continue
} }
superIdentifier.Protocol = &http11protocol reader.GetParent().SetProtocol(&http11protocol)
} }
} }
if superIdentifier.Protocol == nil { if reader.GetParent().GetProtoIdentifier().Protocol == nil {
return err return err
} }

View File

@ -108,7 +108,6 @@ func TestDissect(t *testing.T) {
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
superIdentifier := &api.SuperIdentifier{}
// Request // Request
pathClient := _path pathClient := _path
@ -124,7 +123,21 @@ func TestDissect(t *testing.T) {
DstPort: "2", DstPort: "2",
} }
reqResMatcher := dissector.NewResponseRequestMatcher() reqResMatcher := dissector.NewResponseRequestMatcher()
err = dissector.Dissect(bufferClient, &api.ReadProgress{}, api.Pcap, true, tcpIDClient, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) stream := NewTcpStream(api.Pcap)
reader := NewTcpReader(
&api.ReadProgress{},
"",
tcpIDClient,
time.Time{},
stream,
true,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferClient, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
panic(err) panic(err)
} }
@ -142,7 +155,20 @@ func TestDissect(t *testing.T) {
SrcPort: "2", SrcPort: "2",
DstPort: "1", DstPort: "1",
} }
err = dissector.Dissect(bufferServer, &api.ReadProgress{}, api.Pcap, false, tcpIDServer, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) reader = NewTcpReader(
&api.ReadProgress{},
"",
tcpIDServer,
time.Time{},
stream,
false,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferServer, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
panic(err) panic(err)
} }

View File

@ -0,0 +1,84 @@
package http
import (
"sync"
"time"
"github.com/up9inc/mizu/tap/api"
)
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
progress *api.ReadProgress
captureTime time.Time
parent api.TcpStream
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpReader(progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader {
return &tcpReader{
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
}
}
func (reader *tcpReader) Read(p []byte) (int, error) {
return 0, nil
}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher
}
func (reader *tcpReader) GetIsClient() bool {
return reader.isClient
}
func (reader *tcpReader) GetReadProgress() *api.ReadProgress {
return reader.progress
}
func (reader *tcpReader) GetParent() api.TcpStream {
return reader.parent
}
func (reader *tcpReader) GetTcpID() *api.TcpID {
return reader.tcpID
}
func (reader *tcpReader) GetCounterPair() *api.CounterPair {
return reader.counterPair
}
func (reader *tcpReader) GetCaptureTime() time.Time {
return reader.captureTime
}
func (reader *tcpReader) GetEmitter() api.Emitter {
return reader.emitter
}
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -0,0 +1,45 @@
package http
import (
"sync"
"github.com/up9inc/mizu/tap/api"
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
}
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
}
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
return t.reqResMatcher
}
func (t *tcpStream) GetIsTapTarget() bool {
return t.isTapTarget
}
func (t *tcpStream) GetIsClosed() bool {
return t.isClosed
}

View File

@ -22,3 +22,5 @@ require (
) )
replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api
replace github.com/up9inc/mizu/logger v0.0.0 => ../../../logger

View File

@ -35,25 +35,25 @@ func (d dissecting) Ping() {
log.Printf("pong %s", _protocol.Name) log.Printf("pong %s", _protocol.Name)
} }
func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter, options *api.TrafficFilteringOptions, _reqResMatcher api.RequestResponseMatcher) error { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
reqResMatcher := _reqResMatcher.(*requestResponseMatcher) reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher)
for { for {
if superIdentifier.Protocol != nil && superIdentifier.Protocol != &_protocol { if reader.GetParent().GetProtoIdentifier().Protocol != nil && reader.GetParent().GetProtoIdentifier().Protocol != &_protocol {
return errors.New("Identified by another protocol") return errors.New("Identified by another protocol")
} }
if isClient { if reader.GetIsClient() {
_, _, err := ReadRequest(b, tcpID, counterPair, superTimer, reqResMatcher) _, _, err := ReadRequest(b, reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reqResMatcher)
if err != nil { if err != nil {
return err return err
} }
superIdentifier.Protocol = &_protocol reader.GetParent().SetProtocol(&_protocol)
} else { } else {
err := ReadResponse(b, capture, tcpID, counterPair, superTimer, emitter, reqResMatcher) err := ReadResponse(b, reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), reqResMatcher)
if err != nil { if err != nil {
return err return err
} }
superIdentifier.Protocol = &_protocol reader.GetParent().SetProtocol(&_protocol)
} }
} }
} }

View File

@ -106,7 +106,6 @@ func TestDissect(t *testing.T) {
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
superIdentifier := &api.SuperIdentifier{}
// Request // Request
pathClient := _path pathClient := _path
@ -123,7 +122,21 @@ func TestDissect(t *testing.T) {
} }
reqResMatcher := dissector.NewResponseRequestMatcher() reqResMatcher := dissector.NewResponseRequestMatcher()
reqResMatcher.SetMaxTry(10) reqResMatcher.SetMaxTry(10)
err = dissector.Dissect(bufferClient, &api.ReadProgress{}, api.Pcap, true, tcpIDClient, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) stream := NewTcpStream(api.Pcap)
reader := NewTcpReader(
&api.ReadProgress{},
"",
tcpIDClient,
time.Time{},
stream,
true,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferClient, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println(err) log.Println(err)
} }
@ -141,7 +154,20 @@ func TestDissect(t *testing.T) {
SrcPort: "2", SrcPort: "2",
DstPort: "1", DstPort: "1",
} }
err = dissector.Dissect(bufferServer, &api.ReadProgress{}, api.Pcap, false, tcpIDServer, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) reader = NewTcpReader(
&api.ReadProgress{},
"",
tcpIDServer,
time.Time{},
stream,
false,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferServer, reader, options)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println(err) log.Println(err)
} }

View File

@ -19,7 +19,7 @@ type Request struct {
CaptureTime time.Time `json:"captureTime"` CaptureTime time.Time `json:"captureTime"`
} }
func ReadRequest(r io.Reader, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, reqResMatcher *requestResponseMatcher) (apiKey ApiKey, apiVersion int16, err error) { func ReadRequest(r io.Reader, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, reqResMatcher *requestResponseMatcher) (apiKey ApiKey, apiVersion int16, err error) {
d := &decoder{reader: r, remain: 4} d := &decoder{reader: r, remain: 4}
size := d.readInt32() size := d.readInt32()
@ -206,7 +206,7 @@ func ReadRequest(r io.Reader, tcpID *api.TcpID, counterPair *api.CounterPair, su
ApiVersion: apiVersion, ApiVersion: apiVersion,
CorrelationID: correlationID, CorrelationID: correlationID,
ClientID: clientID, ClientID: clientID,
CaptureTime: superTimer.CaptureTime, CaptureTime: captureTime,
Payload: payload, Payload: payload,
} }

View File

@ -16,7 +16,7 @@ type Response struct {
CaptureTime time.Time `json:"captureTime"` CaptureTime time.Time `json:"captureTime"`
} }
func ReadResponse(r io.Reader, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter, reqResMatcher *requestResponseMatcher) (err error) { func ReadResponse(r io.Reader, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, emitter api.Emitter, reqResMatcher *requestResponseMatcher) (err error) {
d := &decoder{reader: r, remain: 4} d := &decoder{reader: r, remain: 4}
size := d.readInt32() size := d.readInt32()
@ -43,7 +43,7 @@ func ReadResponse(r io.Reader, capture api.Capture, tcpID *api.TcpID, counterPai
Size: size, Size: size,
CorrelationID: correlationID, CorrelationID: correlationID,
Payload: payload, Payload: payload,
CaptureTime: superTimer.CaptureTime, CaptureTime: captureTime,
} }
key := fmt.Sprintf( key := fmt.Sprintf(

View File

@ -0,0 +1,84 @@
package kafka
import (
"sync"
"time"
"github.com/up9inc/mizu/tap/api"
)
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
progress *api.ReadProgress
captureTime time.Time
parent api.TcpStream
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpReader(progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader {
return &tcpReader{
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
}
}
func (reader *tcpReader) Read(p []byte) (int, error) {
return 0, nil
}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher
}
func (reader *tcpReader) GetIsClient() bool {
return reader.isClient
}
func (reader *tcpReader) GetReadProgress() *api.ReadProgress {
return reader.progress
}
func (reader *tcpReader) GetParent() api.TcpStream {
return reader.parent
}
func (reader *tcpReader) GetTcpID() *api.TcpID {
return reader.tcpID
}
func (reader *tcpReader) GetCounterPair() *api.CounterPair {
return reader.counterPair
}
func (reader *tcpReader) GetCaptureTime() time.Time {
return reader.captureTime
}
func (reader *tcpReader) GetEmitter() api.Emitter {
return reader.emitter
}
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -0,0 +1,45 @@
package kafka
import (
"sync"
"github.com/up9inc/mizu/tap/api"
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
}
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
}
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
return t.reqResMatcher
}
func (t *tcpStream) GetIsTapTarget() bool {
return t.isTapTarget
}
func (t *tcpStream) GetIsClosed() bool {
return t.isClosed
}

View File

@ -15,3 +15,5 @@ require (
) )
replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api replace github.com/up9inc/mizu/tap/api v0.0.0 => ../../api
replace github.com/up9inc/mizu/logger v0.0.0 => ../../../logger

View File

@ -2,11 +2,12 @@ package redis
import ( import (
"fmt" "fmt"
"time"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
) )
func handleClientStream(progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter, request *RedisPacket, reqResMatcher *requestResponseMatcher) error { func handleClientStream(progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, emitter api.Emitter, request *RedisPacket, reqResMatcher *requestResponseMatcher) error {
counterPair.Lock() counterPair.Lock()
counterPair.Request++ counterPair.Request++
requestCounter := counterPair.Request requestCounter := counterPair.Request
@ -21,7 +22,7 @@ func handleClientStream(progress *api.ReadProgress, capture api.Capture, tcpID *
requestCounter, requestCounter,
) )
item := reqResMatcher.registerRequest(ident, request, superTimer.CaptureTime, progress.Current()) item := reqResMatcher.registerRequest(ident, request, captureTime, progress.Current())
if item != nil { if item != nil {
item.Capture = capture item.Capture = capture
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{
@ -36,7 +37,7 @@ func handleClientStream(progress *api.ReadProgress, capture api.Capture, tcpID *
return nil return nil
} }
func handleServerStream(progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, emitter api.Emitter, response *RedisPacket, reqResMatcher *requestResponseMatcher) error { func handleServerStream(progress *api.ReadProgress, capture api.Capture, tcpID *api.TcpID, counterPair *api.CounterPair, captureTime time.Time, emitter api.Emitter, response *RedisPacket, reqResMatcher *requestResponseMatcher) error {
counterPair.Lock() counterPair.Lock()
counterPair.Response++ counterPair.Response++
responseCounter := counterPair.Response responseCounter := counterPair.Response
@ -51,7 +52,7 @@ func handleServerStream(progress *api.ReadProgress, capture api.Capture, tcpID *
responseCounter, responseCounter,
) )
item := reqResMatcher.registerResponse(ident, response, superTimer.CaptureTime, progress.Current()) item := reqResMatcher.registerResponse(ident, response, captureTime, progress.Current())
if item != nil { if item != nil {
item.Capture = capture item.Capture = capture
item.ConnectionInfo = &api.ConnectionInfo{ item.ConnectionInfo = &api.ConnectionInfo{

View File

@ -34,8 +34,8 @@ func (d dissecting) Ping() {
log.Printf("pong %s", protocol.Name) log.Printf("pong %s", protocol.Name)
} }
func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture api.Capture, isClient bool, tcpID *api.TcpID, counterPair *api.CounterPair, superTimer *api.SuperTimer, superIdentifier *api.SuperIdentifier, emitter api.Emitter, options *api.TrafficFilteringOptions, _reqResMatcher api.RequestResponseMatcher) error { func (d dissecting) Dissect(b *bufio.Reader, reader api.TcpReader, options *api.TrafficFilteringOptions) error {
reqResMatcher := _reqResMatcher.(*requestResponseMatcher) reqResMatcher := reader.GetReqResMatcher().(*requestResponseMatcher)
is := &RedisInputStream{ is := &RedisInputStream{
Reader: b, Reader: b,
Buf: make([]byte, 8192), Buf: make([]byte, 8192),
@ -47,10 +47,10 @@ func (d dissecting) Dissect(b *bufio.Reader, progress *api.ReadProgress, capture
return err return err
} }
if isClient { if reader.GetIsClient() {
err = handleClientStream(progress, capture, tcpID, counterPair, superTimer, emitter, redisPacket, reqResMatcher) err = handleClientStream(reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), redisPacket, reqResMatcher)
} else { } else {
err = handleServerStream(progress, capture, tcpID, counterPair, superTimer, emitter, redisPacket, reqResMatcher) err = handleServerStream(reader.GetReadProgress(), reader.GetParent().GetOrigin(), reader.GetTcpID(), reader.GetCounterPair(), reader.GetCaptureTime(), reader.GetEmitter(), redisPacket, reqResMatcher)
} }
if err != nil { if err != nil {

View File

@ -107,7 +107,6 @@ func TestDissect(t *testing.T) {
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
superIdentifier := &api.SuperIdentifier{}
// Request // Request
pathClient := _path pathClient := _path
@ -123,7 +122,21 @@ func TestDissect(t *testing.T) {
DstPort: "2", DstPort: "2",
} }
reqResMatcher := dissector.NewResponseRequestMatcher() reqResMatcher := dissector.NewResponseRequestMatcher()
err = dissector.Dissect(bufferClient, &api.ReadProgress{}, api.Pcap, true, tcpIDClient, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) stream := NewTcpStream(api.Pcap)
reader := NewTcpReader(
&api.ReadProgress{},
"",
tcpIDClient,
time.Time{},
stream,
true,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferClient, reader, options)
if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println(err) log.Println(err)
} }
@ -141,7 +154,20 @@ func TestDissect(t *testing.T) {
SrcPort: "2", SrcPort: "2",
DstPort: "1", DstPort: "1",
} }
err = dissector.Dissect(bufferServer, &api.ReadProgress{}, api.Pcap, false, tcpIDServer, counterPair, &api.SuperTimer{}, superIdentifier, emitter, options, reqResMatcher) reader = NewTcpReader(
&api.ReadProgress{},
"",
tcpIDServer,
time.Time{},
stream,
false,
false,
nil,
emitter,
counterPair,
reqResMatcher,
)
err = dissector.Dissect(bufferServer, reader, options)
if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF { if err != nil && reflect.TypeOf(err) != reflect.TypeOf(&ConnectError{}) && err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println(err) log.Println(err)
} }

View File

@ -0,0 +1,84 @@
package redis
import (
"sync"
"time"
"github.com/up9inc/mizu/tap/api"
)
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
progress *api.ReadProgress
captureTime time.Time
parent api.TcpStream
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpReader(progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader {
return &tcpReader{
progress: progress,
ident: ident,
tcpID: tcpId,
captureTime: captureTime,
parent: parent,
isClient: isClient,
isOutgoing: isOutgoing,
extension: extension,
emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
}
}
func (reader *tcpReader) Read(p []byte) (int, error) {
return 0, nil
}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher
}
func (reader *tcpReader) GetIsClient() bool {
return reader.isClient
}
func (reader *tcpReader) GetReadProgress() *api.ReadProgress {
return reader.progress
}
func (reader *tcpReader) GetParent() api.TcpStream {
return reader.parent
}
func (reader *tcpReader) GetTcpID() *api.TcpID {
return reader.tcpID
}
func (reader *tcpReader) GetCounterPair() *api.CounterPair {
return reader.counterPair
}
func (reader *tcpReader) GetCaptureTime() time.Time {
return reader.captureTime
}
func (reader *tcpReader) GetEmitter() api.Emitter {
return reader.emitter
}
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -0,0 +1,45 @@
package redis
import (
"sync"
"github.com/up9inc/mizu/tap/api"
)
type tcpStream struct {
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func NewTcpStream(capture api.Capture) api.TcpStream {
return &tcpStream{
origin: capture,
protoIdentifier: &api.ProtoIdentifier{},
}
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
}
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
return t.reqResMatcher
}
func (t *tcpStream) GetIsTapTarget() bool {
return t.isTapTarget
}
func (t *tcpStream) GetIsClosed() bool {
return t.isClosed
}

View File

@ -7,7 +7,6 @@ require (
github.com/go-errors/errors v1.4.2 github.com/go-errors/errors v1.4.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/up9inc/mizu/logger v0.0.0 github.com/up9inc/mizu/logger v0.0.0
github.com/up9inc/mizu/shared v0.0.0
github.com/up9inc/mizu/tap/api v0.0.0 github.com/up9inc/mizu/tap/api v0.0.0
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74
k8s.io/api v0.23.3 k8s.io/api v0.23.3
@ -16,7 +15,6 @@ require (
require ( require (
github.com/go-logr/logr v1.2.2 // indirect github.com/go-logr/logr v1.2.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.2.0 // indirect
github.com/google/go-cmp v0.5.7 // indirect github.com/google/go-cmp v0.5.7 // indirect
github.com/google/gofuzz v1.2.0 // indirect github.com/google/gofuzz v1.2.0 // indirect
github.com/google/martian v2.1.0+incompatible // indirect github.com/google/martian v2.1.0+incompatible // indirect
@ -29,12 +27,12 @@ require (
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
k8s.io/apimachinery v0.23.3 // indirect k8s.io/apimachinery v0.23.3 // indirect
k8s.io/klog/v2 v2.40.1 // indirect k8s.io/klog/v2 v2.40.1 // indirect
k8s.io/utils v0.0.0-20220127004650-9b3446523e65 // indirect k8s.io/utils v0.0.0-20220127004650-9b3446523e65 // indirect
sigs.k8s.io/json v0.0.0-20211208200746-9f7c6b3444d2 // indirect sigs.k8s.io/json v0.0.0-20211208200746-9f7c6b3444d2 // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.2.1 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.1 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
) )
replace github.com/up9inc/mizu/logger v0.0.0 => ../logger replace github.com/up9inc/mizu/logger v0.0.0 => ../logger

File diff suppressed because it is too large Load Diff

View File

@ -181,7 +181,7 @@ func initializePacketSources() error {
return err return err
} }
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*tcpStreamMap, *tcpAssembler) { func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (api.TcpStreamMap, *tcpAssembler) {
streamsMap := NewTcpStreamMap() streamsMap := NewTcpStreamMap()
diagnose.InitializeErrorsMap(*debug, *verbose, *quiet) diagnose.InitializeErrorsMap(*debug, *verbose, *quiet)
@ -198,8 +198,8 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI
return streamsMap, assembler return streamsMap, assembler
} }
func startPassiveTapper(streamsMap *tcpStreamMap, assembler *tcpAssembler) { func startPassiveTapper(streamsMap api.TcpStreamMap, assembler *tcpAssembler) {
go streamsMap.closeTimedoutTcpStreamChannels() go streamsMap.CloseTimedoutTcpStreamChannels()
diagnose.AppStats.SetStartTime(time.Now()) diagnose.AppStats.SetStartTime(time.Now())

View File

@ -4,19 +4,24 @@ import (
"os" "os"
"strconv" "strconv"
"time" "time"
"github.com/up9inc/mizu/logger"
) )
const ( const (
MemoryProfilingEnabledEnvVarName = "MEMORY_PROFILING_ENABLED" MemoryProfilingEnabledEnvVarName = "MEMORY_PROFILING_ENABLED"
MemoryProfilingDumpPath = "MEMORY_PROFILING_DUMP_PATH" MemoryProfilingDumpPath = "MEMORY_PROFILING_DUMP_PATH"
MemoryProfilingTimeIntervalSeconds = "MEMORY_PROFILING_TIME_INTERVAL" MemoryProfilingTimeIntervalSeconds = "MEMORY_PROFILING_TIME_INTERVAL"
MaxBufferedPagesTotalEnvVarName = "MAX_BUFFERED_PAGES_TOTAL" MaxBufferedPagesTotalEnvVarName = "MAX_BUFFERED_PAGES_TOTAL"
MaxBufferedPagesPerConnectionEnvVarName = "MAX_BUFFERED_PAGES_PER_CONNECTION" MaxBufferedPagesPerConnectionEnvVarName = "MAX_BUFFERED_PAGES_PER_CONNECTION"
TcpStreamChannelTimeoutMsEnvVarName = "TCP_STREAM_CHANNEL_TIMEOUT_MS" MaxBufferedPagesTotalDefaultValue = 5000
CloseTimedoutTcpChannelsIntervalMsEnvVar = "CLOSE_TIMEDOUT_TCP_STREAM_CHANNELS_INTERVAL_MS" MaxBufferedPagesPerConnectionDefaultValue = 5000
MaxBufferedPagesTotalDefaultValue = 5000 TcpStreamChannelTimeoutMsEnvVarName = "TCP_STREAM_CHANNEL_TIMEOUT_MS"
MaxBufferedPagesPerConnectionDefaultValue = 5000 TcpStreamChannelTimeoutMsDefaultValue = 10000
TcpStreamChannelTimeoutMsDefaultValue = 10000 CloseTimedoutTcpChannelsIntervalMsEnvVarName = "CLOSE_TIMEDOUT_TCP_STREAM_CHANNELS_INTERVAL_MS"
CloseTimedoutTcpChannelsIntervalMsDefaultValue = 1000
CloseTimedoutTcpChannelsIntervalMsMinValue = 10
CloseTimedoutTcpChannelsIntervalMsMaxValue = 10000
) )
func GetMaxBufferedPagesTotal() int { func GetMaxBufferedPagesTotal() int {
@ -35,6 +40,10 @@ func GetMaxBufferedPagesPerConnection() int {
return valueFromEnv return valueFromEnv
} }
func GetMemoryProfilingEnabled() bool {
return os.Getenv(MemoryProfilingEnabledEnvVarName) == "1"
}
func GetTcpChannelTimeoutMs() time.Duration { func GetTcpChannelTimeoutMs() time.Duration {
valueFromEnv, err := strconv.Atoi(os.Getenv(TcpStreamChannelTimeoutMsEnvVarName)) valueFromEnv, err := strconv.Atoi(os.Getenv(TcpStreamChannelTimeoutMsEnvVarName))
if err != nil { if err != nil {
@ -43,6 +52,25 @@ func GetTcpChannelTimeoutMs() time.Duration {
return time.Duration(valueFromEnv) * time.Millisecond return time.Duration(valueFromEnv) * time.Millisecond
} }
func GetMemoryProfilingEnabled() bool { func GetCloseTimedoutTcpChannelsInterval() time.Duration {
return os.Getenv(MemoryProfilingEnabledEnvVarName) == "1" defaultDuration := CloseTimedoutTcpChannelsIntervalMsDefaultValue * time.Millisecond
rangeMin := CloseTimedoutTcpChannelsIntervalMsMinValue
rangeMax := CloseTimedoutTcpChannelsIntervalMsMaxValue
closeTimedoutTcpChannelsIntervalMsStr := os.Getenv(CloseTimedoutTcpChannelsIntervalMsEnvVarName)
if closeTimedoutTcpChannelsIntervalMsStr == "" {
return defaultDuration
} else {
closeTimedoutTcpChannelsIntervalMs, err := strconv.Atoi(closeTimedoutTcpChannelsIntervalMsStr)
if err != nil {
logger.Log.Warningf("Error parsing environment variable %s: %v\n", CloseTimedoutTcpChannelsIntervalMsEnvVarName, err)
return defaultDuration
} else {
if closeTimedoutTcpChannelsIntervalMs < rangeMin || closeTimedoutTcpChannelsIntervalMs > rangeMax {
logger.Log.Warningf("The value of environment variable %s is not in acceptable range: %d - %d\n", CloseTimedoutTcpChannelsIntervalMsEnvVarName, rangeMin, rangeMax)
return defaultDuration
} else {
return time.Duration(closeTimedoutTcpChannelsIntervalMs) * time.Millisecond
}
}
}
} }

View File

@ -36,7 +36,7 @@ func (c *context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo return c.CaptureInfo
} }
func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *tcpStreamMap, opts *TapOpts) *tcpAssembler { func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpAssembler {
var emitter api.Emitter = &api.Emitting{ var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats, AppStats: &diagnose.AppStats,
OutputChannel: outputItems, OutputChannel: outputItems,
@ -82,12 +82,6 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
if tcp != nil { if tcp != nil {
diagnose.AppStats.IncTcpPacketsCount() diagnose.AppStats.IncTcpPacketsCount()
tcp := tcp.(*layers.TCP) tcp := tcp.(*layers.TCP)
if *checksum {
err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer())
if err != nil {
logger.Log.Fatalf("Failed to set network layer for checksum: %s", err)
}
}
c := context{ c := context{
CaptureInfo: packet.Metadata().CaptureInfo, CaptureInfo: packet.Metadata().CaptureInfo,

View File

@ -11,22 +11,9 @@ import (
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
) )
type tcpReaderDataMsg struct { /* TcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses.
bytes []byte
timestamp time.Time
}
type ConnectionInfo struct {
ClientIP string
ClientPort string
ServerIP string
ServerPort string
IsOutgoing bool
}
/* tcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses.
* The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection. * The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection.
* An tcpReader object is unidirectional: it parses either a client stream or a server stream. * An TcpReader object is unidirectional: it parses either a client stream or a server stream.
* Implements io.Reader interface (Read) * Implements io.Reader interface (Read)
*/ */
type tcpReader struct { type tcpReader struct {
@ -35,11 +22,11 @@ type tcpReader struct {
isClosed bool isClosed bool
isClient bool isClient bool
isOutgoing bool isOutgoing bool
msgQueue chan tcpReaderDataMsg // Channel of captured reassembled tcp payload msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload
data []byte data []byte
progress *api.ReadProgress progress *api.ReadProgress
superTimer *api.SuperTimer captureTime time.Time
parent *tcpStream parent api.TcpStream
packetsSeen uint packetsSeen uint
extension *api.Extension extension *api.Extension
emitter api.Emitter emitter api.Emitter
@ -48,47 +35,114 @@ type tcpReader struct {
sync.Mutex sync.Mutex
} }
func (h *tcpReader) Read(p []byte) (int, error) { func NewTcpReader(msgQueue chan api.TcpReaderDataMsg, progress *api.ReadProgress, ident string, tcpId *api.TcpID, captureTime time.Time, parent api.TcpStream, isClient bool, isOutgoing bool, extension *api.Extension, emitter api.Emitter, counterPair *api.CounterPair, reqResMatcher api.RequestResponseMatcher) api.TcpReader {
var msg tcpReaderDataMsg return &tcpReader{
msgQueue: msgQueue,
ok := true progress: progress,
for ok && len(h.data) == 0 { ident: ident,
msg, ok = <-h.msgQueue tcpID: tcpId,
h.data = msg.bytes captureTime: captureTime,
parent: parent,
h.superTimer.CaptureTime = msg.timestamp isClient: isClient,
if len(h.data) > 0 { isOutgoing: isOutgoing,
h.packetsSeen += 1 extension: extension,
} emitter: emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
} }
if !ok || len(h.data) == 0 {
return 0, io.EOF
}
l := copy(p, h.data)
h.data = h.data[l:]
h.progress.Feed(l)
return l, nil
} }
func (h *tcpReader) Close() { func (reader *tcpReader) run(options *api.TrafficFilteringOptions, wg *sync.WaitGroup) {
h.Lock()
if !h.isClosed {
h.isClosed = true
close(h.msgQueue)
}
h.Unlock()
}
func (h *tcpReader) run(wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
b := bufio.NewReader(h) b := bufio.NewReader(reader)
err := h.extension.Dissector.Dissect(b, h.progress, h.parent.origin, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter, filteringOptions, h.reqResMatcher) err := reader.extension.Dissector.Dissect(b, reader, options)
if err != nil { if err != nil {
_, err = io.Copy(ioutil.Discard, b) _, err = io.Copy(ioutil.Discard, reader)
if err != nil { if err != nil {
logger.Log.Errorf("%v", err) logger.Log.Errorf("%v", err)
} }
} }
} }
func (reader *tcpReader) close() {
reader.Lock()
if !reader.isClosed {
reader.isClosed = true
close(reader.msgQueue)
}
reader.Unlock()
}
func (reader *tcpReader) sendMsgIfNotClosed(msg api.TcpReaderDataMsg) {
reader.Lock()
if !reader.isClosed {
reader.msgQueue <- msg
}
reader.Unlock()
}
func (reader *tcpReader) Read(p []byte) (int, error) {
var msg api.TcpReaderDataMsg
ok := true
for ok && len(reader.data) == 0 {
msg, ok = <-reader.msgQueue
if msg != nil {
reader.data = msg.GetBytes()
reader.captureTime = msg.GetTimestamp()
}
if len(reader.data) > 0 {
reader.packetsSeen += 1
}
}
if !ok || len(reader.data) == 0 {
return 0, io.EOF
}
l := copy(p, reader.data)
reader.data = reader.data[l:]
reader.progress.Feed(l)
return l, nil
}
func (reader *tcpReader) GetReqResMatcher() api.RequestResponseMatcher {
return reader.reqResMatcher
}
func (reader *tcpReader) GetIsClient() bool {
return reader.isClient
}
func (reader *tcpReader) GetReadProgress() *api.ReadProgress {
return reader.progress
}
func (reader *tcpReader) GetParent() api.TcpStream {
return reader.parent
}
func (reader *tcpReader) GetTcpID() *api.TcpID {
return reader.tcpID
}
func (reader *tcpReader) GetCounterPair() *api.CounterPair {
return reader.counterPair
}
func (reader *tcpReader) GetCaptureTime() time.Time {
return reader.captureTime
}
func (reader *tcpReader) GetEmitter() api.Emitter {
return reader.emitter
}
func (reader *tcpReader) GetIsClosed() bool {
return reader.isClosed
}
func (reader *tcpReader) GetExtension() *api.Extension {
return reader.extension
}

View File

@ -0,0 +1,24 @@
package tap
import (
"time"
"github.com/up9inc/mizu/tap/api"
)
type tcpReaderDataMsg struct {
bytes []byte
timestamp time.Time
}
func NewTcpReaderDataMsg(data []byte, timestamp time.Time) api.TcpReaderDataMsg {
return &tcpReaderDataMsg{data, timestamp}
}
func (dataMsg *tcpReaderDataMsg) GetBytes() []byte {
return dataMsg.bytes
}
func (dataMsg *tcpReaderDataMsg) GetTimestamp() time.Time {
return dataMsg.timestamp
}

View File

@ -0,0 +1,170 @@
package tap
import (
"encoding/binary"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
)
type ReassemblyStream interface {
Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool
ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext)
ReassemblyComplete(ac reassembly.AssemblerContext) bool
}
type tcpReassemblyStream struct {
ident string
tcpState *reassembly.TCPSimpleFSM
fsmerr bool
optchecker reassembly.TCPOptionCheck
isDNS bool
tcpStream api.TcpStream
}
func NewTcpReassemblyStream(ident string, tcp *layers.TCP, fsmOptions reassembly.TCPSimpleFSMOptions, stream api.TcpStream) ReassemblyStream {
return &tcpReassemblyStream{
ident: ident,
tcpState: reassembly.NewTCPSimpleFSM(fsmOptions),
optchecker: reassembly.NewTCPOptionCheck(),
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
tcpStream: stream,
}
}
func (t *tcpReassemblyStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
// FSM
if !t.tcpState.CheckState(tcp, dir) {
diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpState.String())
diagnose.InternalStats.RejectFsm++
if !t.fsmerr {
t.fsmerr = true
diagnose.InternalStats.RejectConnFsm++
}
if !*ignorefsmerr {
return false
}
}
// Options
err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start)
if err != nil {
diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err)
diagnose.InternalStats.RejectOpt++
if !*nooptcheck {
return false
}
}
// Checksum
accept := true
if *checksum {
c, err := tcp.ComputeChecksum()
if err != nil {
diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err)
accept = false
} else if c != 0x0 {
diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c)
accept = false
}
}
if !accept {
diagnose.InternalStats.RejectOpt++
}
*start = true
return accept
}
func (t *tcpReassemblyStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
dir, _, _, skip := sg.Info()
length, saved := sg.Lengths()
// update stats
sgStats := sg.Stats()
if skip > 0 {
diagnose.InternalStats.MissedBytes += skip
}
diagnose.InternalStats.Sz += length - saved
diagnose.InternalStats.Pkt += sgStats.Packets
if sgStats.Chunks > 1 {
diagnose.InternalStats.Reassembled++
}
diagnose.InternalStats.OutOfOrderPackets += sgStats.QueuedPackets
diagnose.InternalStats.OutOfOrderBytes += sgStats.QueuedBytes
if length > diagnose.InternalStats.BiggestChunkBytes {
diagnose.InternalStats.BiggestChunkBytes = length
}
if sgStats.Packets > diagnose.InternalStats.BiggestChunkPackets {
diagnose.InternalStats.BiggestChunkPackets = sgStats.Packets
}
if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 {
// In the original example this was handled with panic().
// I don't know what this error means or how to handle it properly.
diagnose.TapErrors.SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets)
}
diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes
diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets
if skip != -1 && skip != 0 {
// Missing bytes in stream: do not even try to parse it
return
}
data := sg.Fetch(length)
if t.isDNS {
dns := &layers.DNS{}
var decoded []gopacket.LayerType
if len(data) < 2 {
if len(data) > 0 {
sg.KeepFrom(0)
}
return
}
dnsSize := binary.BigEndian.Uint16(data[:2])
missing := int(dnsSize) - len(data[2:])
diagnose.TapErrors.Debug("dnsSize: %d, missing: %d", dnsSize, missing)
if missing > 0 {
diagnose.TapErrors.Debug("Missing some bytes: %d", missing)
sg.KeepFrom(0)
return
}
p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns)
err := p.DecodeLayers(data[2:], &decoded)
if err != nil {
diagnose.TapErrors.SilentError("DNS-parser", "Failed to decode DNS: %v", err)
} else {
diagnose.TapErrors.Debug("DNS: %s", gopacket.LayerDump(dns))
}
if len(data) > 2+int(dnsSize) {
sg.KeepFrom(2 + int(dnsSize))
}
} else if t.tcpStream.GetIsTapTarget() {
if length > 0 {
// This is where we pass the reassembled information onwards
// This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp
stream := t.tcpStream.(*tcpStream)
if dir == reassembly.TCPDirClientToServer {
for i := range stream.getClients() {
reader := stream.getClient(i).(*tcpReader)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
} else {
for i := range stream.getServers() {
reader := stream.getServer(i).(*tcpReader)
reader.sendMsgIfNotClosed(NewTcpReaderDataMsg(data, timestamp))
}
}
}
}
}
func (t *tcpReassemblyStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.tcpStream.GetIsTapTarget() && !t.tcpStream.GetIsClosed() {
t.tcpStream.(*tcpStream).close()
}
// do not remove the connection to allow last ACK
return false
}

View File

@ -1,202 +1,136 @@
package tap package tap
import ( import (
"encoding/binary"
"sync" "sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
) )
/* It's a connection (bidirectional) /* It's a connection (bidirectional)
* Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete)
* ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete)
* In our implementation, we pass information from ReassembledSG to the tcpReader through a shared channel. * In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel.
*/ */
type tcpStream struct { type tcpStream struct {
id int64 id int64
isClosed bool isClosed bool
superIdentifier *api.SuperIdentifier protoIdentifier *api.ProtoIdentifier
tcpstate *reassembly.TCPSimpleFSM
fsmerr bool
optchecker reassembly.TCPOptionCheck
net, transport gopacket.Flow
isDNS bool
isTapTarget bool isTapTarget bool
clients []tcpReader clients []api.TcpReader
servers []tcpReader servers []api.TcpReader
ident string
origin api.Capture origin api.Capture
reqResMatcher api.RequestResponseMatcher
createdAt time.Time
streamsMap api.TcpStreamMap
sync.Mutex sync.Mutex
streamsMap *tcpStreamMap
} }
func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { func NewTcpStream(isTapTarget bool, streamsMap api.TcpStreamMap, capture api.Capture) api.TcpStream {
// FSM return &tcpStream{
if !t.tcpstate.CheckState(tcp, dir) { isTapTarget: isTapTarget,
diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpstate.String()) protoIdentifier: &api.ProtoIdentifier{},
diagnose.InternalStats.RejectFsm++ streamsMap: streamsMap,
if !t.fsmerr { origin: capture,
t.fsmerr = true
diagnose.InternalStats.RejectConnFsm++
}
if !*ignorefsmerr {
return false
}
}
// Options
err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start)
if err != nil {
diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err)
diagnose.InternalStats.RejectOpt++
if !*nooptcheck {
return false
}
}
// Checksum
accept := true
if *checksum {
c, err := tcp.ComputeChecksum()
if err != nil {
diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err)
accept = false
} else if c != 0x0 {
diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c)
accept = false
}
}
if !accept {
diagnose.InternalStats.RejectOpt++
}
*start = true
return accept
}
func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
dir, _, _, skip := sg.Info()
length, saved := sg.Lengths()
// update stats
sgStats := sg.Stats()
if skip > 0 {
diagnose.InternalStats.MissedBytes += skip
}
diagnose.InternalStats.Sz += length - saved
diagnose.InternalStats.Pkt += sgStats.Packets
if sgStats.Chunks > 1 {
diagnose.InternalStats.Reassembled++
}
diagnose.InternalStats.OutOfOrderPackets += sgStats.QueuedPackets
diagnose.InternalStats.OutOfOrderBytes += sgStats.QueuedBytes
if length > diagnose.InternalStats.BiggestChunkBytes {
diagnose.InternalStats.BiggestChunkBytes = length
}
if sgStats.Packets > diagnose.InternalStats.BiggestChunkPackets {
diagnose.InternalStats.BiggestChunkPackets = sgStats.Packets
}
if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 {
// In the original example this was handled with panic().
// I don't know what this error means or how to handle it properly.
diagnose.TapErrors.SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets)
}
diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes
diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets
if skip == -1 && *allowmissinginit {
// this is allowed
} else if skip != 0 {
// Missing bytes in stream: do not even try to parse it
return
}
data := sg.Fetch(length)
if t.isDNS {
dns := &layers.DNS{}
var decoded []gopacket.LayerType
if len(data) < 2 {
if len(data) > 0 {
sg.KeepFrom(0)
}
return
}
dnsSize := binary.BigEndian.Uint16(data[:2])
missing := int(dnsSize) - len(data[2:])
diagnose.TapErrors.Debug("dnsSize: %d, missing: %d", dnsSize, missing)
if missing > 0 {
diagnose.TapErrors.Debug("Missing some bytes: %d", missing)
sg.KeepFrom(0)
return
}
p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns)
err := p.DecodeLayers(data[2:], &decoded)
if err != nil {
diagnose.TapErrors.SilentError("DNS-parser", "Failed to decode DNS: %v", err)
} else {
diagnose.TapErrors.Debug("DNS: %s", gopacket.LayerDump(dns))
}
if len(data) > 2+int(dnsSize) {
sg.KeepFrom(2 + int(dnsSize))
}
} else if t.isTapTarget {
if length > 0 {
// This is where we pass the reassembled information onwards
// This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp
if dir == reassembly.TCPDirClientToServer {
for i := range t.clients {
reader := &t.clients[i]
reader.Lock()
if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
}
reader.Unlock()
}
} else {
for i := range t.servers {
reader := &t.servers[i]
reader.Lock()
if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
}
reader.Unlock()
}
}
}
} }
} }
func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { func (t *tcpStream) getId() int64 {
if t.isTapTarget && !t.isClosed { return t.id
t.Close()
}
// do not remove the connection to allow last ACK
return false
} }
func (t *tcpStream) Close() { func (t *tcpStream) setId(id int64) {
shouldReturn := false t.id = id
}
func (t *tcpStream) close() {
t.Lock() t.Lock()
defer t.Unlock()
if t.isClosed { if t.isClosed {
shouldReturn = true
} else {
t.isClosed = true
}
t.Unlock()
if shouldReturn {
return return
} }
t.isClosed = true
t.streamsMap.Delete(t.id) t.streamsMap.Delete(t.id)
for i := range t.clients { for i := range t.clients {
reader := &t.clients[i] reader := t.clients[i]
reader.Close() reader.(*tcpReader).close()
} }
for i := range t.servers { for i := range t.servers {
reader := &t.servers[i] reader := t.servers[i]
reader.Close() reader.(*tcpReader).close()
} }
} }
func (t *tcpStream) addClient(reader api.TcpReader) {
t.clients = append(t.clients, reader)
}
func (t *tcpStream) addServer(reader api.TcpReader) {
t.servers = append(t.servers, reader)
}
func (t *tcpStream) getClients() []api.TcpReader {
return t.clients
}
func (t *tcpStream) getServers() []api.TcpReader {
return t.servers
}
func (t *tcpStream) getClient(index int) api.TcpReader {
return t.clients[index]
}
func (t *tcpStream) getServer(index int) api.TcpReader {
return t.servers[index]
}
func (t *tcpStream) SetProtocol(protocol *api.Protocol) {
t.Lock()
defer t.Unlock()
if t.protoIdentifier.IsClosedOthers {
return
}
t.protoIdentifier.Protocol = protocol
for i := range t.clients {
reader := t.clients[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.(*tcpReader).close()
}
}
for i := range t.servers {
reader := t.servers[i]
if reader.GetExtension().Protocol != t.protoIdentifier.Protocol {
reader.(*tcpReader).close()
}
}
t.protoIdentifier.IsClosedOthers = true
}
func (t *tcpStream) GetOrigin() api.Capture {
return t.origin
}
func (t *tcpStream) GetProtoIdentifier() *api.ProtoIdentifier {
return t.protoIdentifier
}
func (t *tcpStream) GetReqResMatcher() api.RequestResponseMatcher {
return t.reqResMatcher
}
func (t *tcpStream) GetIsTapTarget() bool {
return t.isTapTarget
}
func (t *tcpStream) GetIsClosed() bool {
return t.isClosed
}

View File

@ -21,19 +21,13 @@ import (
*/ */
type tcpStreamFactory struct { type tcpStreamFactory struct {
wg sync.WaitGroup wg sync.WaitGroup
Emitter api.Emitter emitter api.Emitter
streamsMap *tcpStreamMap streamsMap api.TcpStreamMap
ownIps []string ownIps []string
opts *TapOpts opts *TapOpts
} }
type tcpStreamWrapper struct { func NewTcpStreamFactory(emitter api.Emitter, streamsMap api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory {
stream *tcpStream
reqResMatcher api.RequestResponseMatcher
createdAt time.Time
}
func NewTcpStreamFactory(emitter api.Emitter, streamsMap *tcpStreamMap, opts *TapOpts) *tcpStreamFactory {
var ownIps []string var ownIps []string
if localhostIPs, err := getLocalhostIPs(); err != nil { if localhostIPs, err := getLocalhostIPs(); err != nil {
@ -46,14 +40,14 @@ func NewTcpStreamFactory(emitter api.Emitter, streamsMap *tcpStreamMap, opts *Ta
} }
return &tcpStreamFactory{ return &tcpStreamFactory{
Emitter: emitter, emitter: emitter,
streamsMap: streamsMap, streamsMap: streamsMap,
ownIps: ownIps, ownIps: ownIps,
opts: opts, opts: opts,
} }
} }
func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcpLayer *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
fsmOptions := reassembly.TCPSimpleFSMOptions{ fsmOptions := reassembly.TCPSimpleFSMOptions{
SupportMissingEstablishment: *allowmissinginit, SupportMissingEstablishment: *allowmissinginit,
} }
@ -64,78 +58,68 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort) props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
isTapTarget := props.isTapTarget isTapTarget := props.isTapTarget
stream := &tcpStream{ stream := NewTcpStream(isTapTarget, factory.streamsMap, getPacketOrigin(ac))
net: net, reassemblyStream := NewTcpReassemblyStream(fmt.Sprintf("%s:%s", net, transport), tcpLayer, fsmOptions, stream)
transport: transport, if stream.GetIsTapTarget() {
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, _stream := stream.(*tcpStream)
isTapTarget: isTapTarget, _stream.setId(factory.streamsMap.NextId())
tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions),
ident: fmt.Sprintf("%s:%s", net, transport),
optchecker: reassembly.NewTCPOptionCheck(),
superIdentifier: &api.SuperIdentifier{},
streamsMap: factory.streamsMap,
origin: getPacketOrigin(ac),
}
if stream.isTapTarget {
stream.id = factory.streamsMap.nextId()
for i, extension := range extensions { for i, extension := range extensions {
reqResMatcher := extension.Dissector.NewResponseRequestMatcher() reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
counterPair := &api.CounterPair{ counterPair := &api.CounterPair{
Request: 0, Request: 0,
Response: 0, Response: 0,
} }
stream.clients = append(stream.clients, tcpReader{ _stream.addClient(
msgQueue: make(chan tcpReaderDataMsg), NewTcpReader(
progress: &api.ReadProgress{}, make(chan api.TcpReaderDataMsg),
superTimer: &api.SuperTimer{}, &api.ReadProgress{},
ident: fmt.Sprintf("%s %s", net, transport), fmt.Sprintf("%s %s", net, transport),
tcpID: &api.TcpID{ &api.TcpID{
SrcIP: srcIp, SrcIP: srcIp,
DstIP: dstIp, DstIP: dstIp,
SrcPort: srcPort, SrcPort: srcPort,
DstPort: dstPort, DstPort: dstPort,
}, },
parent: stream, time.Time{},
isClient: true, stream,
isOutgoing: props.isOutgoing, true,
extension: extension, props.isOutgoing,
emitter: factory.Emitter, extension,
counterPair: counterPair, factory.emitter,
reqResMatcher: reqResMatcher, counterPair,
}) reqResMatcher,
stream.servers = append(stream.servers, tcpReader{ ),
msgQueue: make(chan tcpReaderDataMsg), )
progress: &api.ReadProgress{}, _stream.addServer(
superTimer: &api.SuperTimer{}, NewTcpReader(
ident: fmt.Sprintf("%s %s", net, transport), make(chan api.TcpReaderDataMsg),
tcpID: &api.TcpID{ &api.ReadProgress{},
SrcIP: net.Dst().String(), fmt.Sprintf("%s %s", net, transport),
DstIP: net.Src().String(), &api.TcpID{
SrcPort: transport.Dst().String(), SrcIP: net.Dst().String(),
DstPort: transport.Src().String(), DstIP: net.Src().String(),
}, SrcPort: transport.Dst().String(),
parent: stream, DstPort: transport.Src().String(),
isClient: false, },
isOutgoing: props.isOutgoing, time.Time{},
extension: extension, stream,
emitter: factory.Emitter, false,
counterPair: counterPair, props.isOutgoing,
reqResMatcher: reqResMatcher, extension,
}) factory.emitter,
counterPair,
reqResMatcher,
),
)
factory.streamsMap.Store(stream.id, &tcpStreamWrapper{ factory.streamsMap.Store(stream.(*tcpStream).getId(), stream)
stream: stream,
reqResMatcher: reqResMatcher,
createdAt: time.Now(),
})
factory.wg.Add(2) factory.wg.Add(2)
// Start reading from channel stream.reader.bytes go _stream.getClient(i).(*tcpReader).run(filteringOptions, &factory.wg)
go stream.clients[i].run(&factory.wg) go _stream.getServer(i).(*tcpReader).run(filteringOptions, &factory.wg)
go stream.servers[i].run(&factory.wg)
} }
} }
return stream return reassemblyStream
} }
func (factory *tcpStreamFactory) WaitGoRoutines() { func (factory *tcpStreamFactory) WaitGoRoutines() {

View File

@ -1,14 +1,12 @@
package tap package tap
import ( import (
"os"
"runtime" "runtime"
_debug "runtime/debug"
"strconv"
"sync" "sync"
"time" "time"
"github.com/up9inc/mizu/logger" "github.com/up9inc/mizu/logger"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose" "github.com/up9inc/mizu/tap/diagnose"
) )
@ -17,12 +15,16 @@ type tcpStreamMap struct {
streamId int64 streamId int64
} }
func NewTcpStreamMap() *tcpStreamMap { func NewTcpStreamMap() api.TcpStreamMap {
return &tcpStreamMap{ return &tcpStreamMap{
streams: &sync.Map{}, streams: &sync.Map{},
} }
} }
func (streamMap *tcpStreamMap) Range(f func(key, value interface{}) bool) {
streamMap.streams.Range(f)
}
func (streamMap *tcpStreamMap) Store(key, value interface{}) { func (streamMap *tcpStreamMap) Store(key, value interface{}) {
streamMap.streams.Store(key, value) streamMap.streams.Store(key, value)
} }
@ -31,66 +33,28 @@ func (streamMap *tcpStreamMap) Delete(key interface{}) {
streamMap.streams.Delete(key) streamMap.streams.Delete(key)
} }
func (streamMap *tcpStreamMap) nextId() int64 { func (streamMap *tcpStreamMap) NextId() int64 {
streamMap.streamId++ streamMap.streamId++
return streamMap.streamId return streamMap.streamId
} }
func (streamMap *tcpStreamMap) getCloseTimedoutTcpChannelsInterval() time.Duration { func (streamMap *tcpStreamMap) CloseTimedoutTcpStreamChannels() {
defaultDuration := 1000 * time.Millisecond tcpStreamChannelTimeoutMs := GetTcpChannelTimeoutMs()
rangeMin := 10 closeTimedoutTcpChannelsIntervalMs := GetCloseTimedoutTcpChannelsInterval()
rangeMax := 10000
closeTimedoutTcpChannelsIntervalMsStr := os.Getenv(CloseTimedoutTcpChannelsIntervalMsEnvVar)
if closeTimedoutTcpChannelsIntervalMsStr == "" {
return defaultDuration
} else {
closeTimedoutTcpChannelsIntervalMs, err := strconv.Atoi(closeTimedoutTcpChannelsIntervalMsStr)
if err != nil {
logger.Log.Warningf("Error parsing environment variable %s: %v\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, err)
return defaultDuration
} else {
if closeTimedoutTcpChannelsIntervalMs < rangeMin || closeTimedoutTcpChannelsIntervalMs > rangeMax {
logger.Log.Warningf("The value of environment variable %s is not in acceptable range: %d - %d\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, rangeMin, rangeMax)
return defaultDuration
} else {
return time.Duration(closeTimedoutTcpChannelsIntervalMs) * time.Millisecond
}
}
}
}
func (streamMap *tcpStreamMap) closeTimedoutTcpStreamChannels() {
tcpStreamChannelTimeout := GetTcpChannelTimeoutMs()
closeTimedoutTcpChannelsIntervalMs := streamMap.getCloseTimedoutTcpChannelsInterval()
logger.Log.Infof("Using %d ms as the close timedout TCP stream channels interval", closeTimedoutTcpChannelsIntervalMs/time.Millisecond) logger.Log.Infof("Using %d ms as the close timedout TCP stream channels interval", closeTimedoutTcpChannelsIntervalMs/time.Millisecond)
ticker := time.NewTicker(closeTimedoutTcpChannelsIntervalMs)
for { for {
time.Sleep(closeTimedoutTcpChannelsIntervalMs) <-ticker.C
_debug.FreeOSMemory()
streamMap.streams.Range(func(key interface{}, value interface{}) bool { streamMap.streams.Range(func(key interface{}, value interface{}) bool {
streamWrapper := value.(*tcpStreamWrapper) stream := value.(*tcpStream)
stream := streamWrapper.stream if stream.protoIdentifier.Protocol == nil {
if stream.superIdentifier.Protocol == nil { if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeoutMs)) {
if !stream.isClosed && time.Now().After(streamWrapper.createdAt.Add(tcpStreamChannelTimeout)) { stream.close()
stream.Close()
diagnose.AppStats.IncDroppedTcpStreams() diagnose.AppStats.IncDroppedTcpStreams()
logger.Log.Debugf("Dropped an unidentified TCP stream because of timeout. Total dropped: %d Total Goroutines: %d Timeout (ms): %d", logger.Log.Debugf("Dropped an unidentified TCP stream because of timeout. Total dropped: %d Total Goroutines: %d Timeout (ms): %d",
diagnose.AppStats.DroppedTcpStreams, runtime.NumGoroutine(), tcpStreamChannelTimeout/time.Millisecond) diagnose.AppStats.DroppedTcpStreams, runtime.NumGoroutine(), tcpStreamChannelTimeoutMs/time.Millisecond)
}
} else {
if !stream.superIdentifier.IsClosedOthers {
for i := range stream.clients {
reader := &stream.clients[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
for i := range stream.servers {
reader := &stream.servers[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
stream.superIdentifier.IsClosedOthers = true
} }
} }
return true return true

View File

@ -16,14 +16,14 @@ const FLAGS_IS_READ_BIT uint32 = (1 << 1)
// Be careful when editing, alignment and padding should be exactly the same in go/c. // Be careful when editing, alignment and padding should be exactly the same in go/c.
// //
type tlsChunk struct { type tlsChunk struct {
Pid uint32 // process id Pid uint32 // process id
Tgid uint32 // thread id inside the process Tgid uint32 // thread id inside the process
Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[]) Len uint32 // the size of the native buffer used to read/write the tls data (may be bigger than tlsChunk.Data[])
Start uint32 // the start offset withing the native buffer Start uint32 // the start offset withing the native buffer
Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[] Recorded uint32 // number of bytes copied from the native buffer to tlsChunk.Data[]
Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor) Fd uint32 // the file descriptor used to read/write the tls data (probably socket file descriptor)
Flags uint32 // bitwise flags Flags uint32 // bitwise flags
Address [16]byte // ipv4 address and port Address [16]byte // ipv4 address and port
Data [4096]byte // actual tls data Data [4096]byte // actual tls data
} }

View File

@ -21,25 +21,34 @@ import (
) )
type tlsPoller struct { type tlsPoller struct {
tls *TlsTapper tls *TlsTapper
readers map[string]*tlsReader readers map[string]api.TcpReader
closedReaders chan string closedReaders chan string
reqResMatcher api.RequestResponseMatcher reqResMatcher api.RequestResponseMatcher
chunksReader *perf.Reader chunksReader *perf.Reader
extension *api.Extension extension *api.Extension
procfs string procfs string
pidToNamespace sync.Map pidToNamespace sync.Map
isClosed bool
protoIdentifier *api.ProtoIdentifier
isTapTarget bool
origin api.Capture
createdAt time.Time
} }
func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller { func newTlsPoller(tls *TlsTapper, extension *api.Extension, procfs string) *tlsPoller {
return &tlsPoller{ return &tlsPoller{
tls: tls, tls: tls,
readers: make(map[string]*tlsReader), readers: make(map[string]api.TcpReader),
closedReaders: make(chan string, 100), closedReaders: make(chan string, 100),
reqResMatcher: extension.Dissector.NewResponseRequestMatcher(), reqResMatcher: extension.Dissector.NewResponseRequestMatcher(),
extension: extension, extension: extension,
chunksReader: nil, chunksReader: nil,
procfs: procfs, procfs: procfs,
protoIdentifier: &api.ProtoIdentifier{},
isTapTarget: true,
origin: api.Ebpf,
createdAt: time.Now(),
} }
} }
@ -126,13 +135,24 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
key := buildTlsKey(chunk, ip, port) key := buildTlsKey(chunk, ip, port)
reader, exists := p.readers[key] reader, exists := p.readers[key]
newReader := NewTlsReader(
key,
func(r *tlsReader) {
p.closeReader(key, r)
},
chunk.isRequest(),
p,
)
if !exists { if !exists {
reader = p.startNewTlsReader(chunk, ip, port, key, extension, emitter, options) reader = p.startNewTlsReader(chunk, ip, port, key, extension, newReader, options)
p.readers[key] = reader p.readers[key] = reader
} }
reader.timer.CaptureTime = time.Now() tlsReader := reader.(*tlsReader)
reader.chunks <- chunk
tlsReader.setCaptureTime(time.Now())
tlsReader.sendChunk(chunk)
if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" { if os.Getenv("MIZU_VERBOSE_TLS_TAPPER") == "true" {
p.logTls(chunk, ip, port) p.logTls(chunk, ip, port)
@ -142,40 +162,30 @@ func (p *tlsPoller) handleTlsChunk(chunk *tlsChunk, extension *api.Extension,
} }
func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension, func (p *tlsPoller) startNewTlsReader(chunk *tlsChunk, ip net.IP, port uint16, key string, extension *api.Extension,
emitter api.Emitter, options *api.TrafficFilteringOptions) *tlsReader { reader api.TcpReader, options *api.TrafficFilteringOptions) api.TcpReader {
reader := &tlsReader{
key: key,
chunks: make(chan *tlsChunk, 1),
doneHandler: func(r *tlsReader) {
p.closeReader(key, r)
},
progress: &api.ReadProgress{},
timer: api.SuperTimer{
CaptureTime: time.Now(),
},
}
tcpid := p.buildTcpId(chunk, ip, port) tcpid := p.buildTcpId(chunk, ip, port)
tlsEmitter := &tlsEmitter{ tlsReader := reader.(*tlsReader)
delegate: emitter, tlsReader.setTcpID(&tcpid)
namespace: p.getNamespace(chunk.Pid),
}
go dissect(extension, reader, chunk.isRequest(), &tcpid, tlsEmitter, options, p.reqResMatcher) tlsReader.setEmitter(&tlsEmitter{
delegate: reader.GetEmitter(),
namespace: p.getNamespace(chunk.Pid),
})
go dissect(extension, reader, options)
return reader return reader
} }
func dissect(extension *api.Extension, reader *tlsReader, isRequest bool, tcpid *api.TcpID, func dissect(extension *api.Extension, reader api.TcpReader,
tlsEmitter *tlsEmitter, options *api.TrafficFilteringOptions, reqResMatcher api.RequestResponseMatcher) { options *api.TrafficFilteringOptions) {
b := bufio.NewReader(reader) b := bufio.NewReader(reader)
err := extension.Dissector.Dissect(b, reader.progress, api.Ebpf, isRequest, tcpid, &api.CounterPair{}, err := extension.Dissector.Dissect(b, reader, options)
&reader.timer, &api.SuperIdentifier{}, tlsEmitter, options, reqResMatcher)
if err != nil { if err != nil {
logger.Log.Warningf("Error dissecting TLS %v - %v", tcpid, err) logger.Log.Warningf("Error dissecting TLS %v - %v", reader.GetTcpID(), err)
} }
} }
@ -269,3 +279,27 @@ func (p *tlsPoller) logTls(chunk *tlsChunk, ip net.IP, port uint16) {
srcIp, srcPort, dstIp, dstPort, srcIp, srcPort, dstIp, dstPort,
chunk.Recorded, chunk.Len, chunk.Start, str, hex.EncodeToString(chunk.Data[0:chunk.Recorded])) chunk.Recorded, chunk.Len, chunk.Start, str, hex.EncodeToString(chunk.Data[0:chunk.Recorded]))
} }
func (p *tlsPoller) SetProtocol(protocol *api.Protocol) {
// TODO: Implement
}
func (p *tlsPoller) GetOrigin() api.Capture {
return p.origin
}
func (p *tlsPoller) GetProtoIdentifier() *api.ProtoIdentifier {
return p.protoIdentifier
}
func (p *tlsPoller) GetReqResMatcher() api.RequestResponseMatcher {
return p.reqResMatcher
}
func (p *tlsPoller) GetIsTapTarget() bool {
return p.isTapTarget
}
func (p *tlsPoller) GetIsClosed() bool {
return p.isClosed
}

View File

@ -8,12 +8,45 @@ import (
) )
type tlsReader struct { type tlsReader struct {
key string key string
chunks chan *tlsChunk chunks chan *tlsChunk
data []byte data []byte
doneHandler func(r *tlsReader) doneHandler func(r *tlsReader)
progress *api.ReadProgress progress *api.ReadProgress
timer api.SuperTimer tcpID *api.TcpID
isClosed bool
isClient bool
captureTime time.Time
parent api.TcpStream
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
}
func NewTlsReader(key string, doneHandler func(r *tlsReader), isClient bool, stream api.TcpStream) api.TcpReader {
return &tlsReader{
key: key,
chunks: make(chan *tlsChunk, 1),
doneHandler: doneHandler,
parent: stream,
}
}
func (r *tlsReader) sendChunk(chunk *tlsChunk) {
r.chunks <- chunk
}
func (r *tlsReader) setTcpID(tcpID *api.TcpID) {
r.tcpID = tcpID
}
func (r *tlsReader) setCaptureTime(captureTime time.Time) {
r.captureTime = captureTime
}
func (r *tlsReader) setEmitter(emitter api.Emitter) {
r.emitter = emitter
} }
func (r *tlsReader) Read(p []byte) (int, error) { func (r *tlsReader) Read(p []byte) (int, error) {
@ -44,3 +77,43 @@ func (r *tlsReader) Read(p []byte) (int, error) {
return l, nil return l, nil
} }
func (r *tlsReader) GetReqResMatcher() api.RequestResponseMatcher {
return r.reqResMatcher
}
func (r *tlsReader) GetIsClient() bool {
return r.isClient
}
func (r *tlsReader) GetReadProgress() *api.ReadProgress {
return r.progress
}
func (r *tlsReader) GetParent() api.TcpStream {
return r.parent
}
func (r *tlsReader) GetTcpID() *api.TcpID {
return r.tcpID
}
func (r *tlsReader) GetCounterPair() *api.CounterPair {
return r.counterPair
}
func (r *tlsReader) GetCaptureTime() time.Time {
return r.captureTime
}
func (r *tlsReader) GetEmitter() api.Emitter {
return r.emitter
}
func (r *tlsReader) GetIsClosed() bool {
return r.isClosed
}
func (r *tlsReader) GetExtension() *api.Extension {
return r.extension
}