Refactor tap module and move some of the code to tap/api module

This commit is contained in:
M. Mert Yildiran
2022-04-20 11:32:48 +03:00
parent ba4bab3ed5
commit 110325598a
18 changed files with 1319 additions and 344 deletions

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/google/martian/har"
"github.com/up9inc/mizu/tap/api/diagnose"
)
const mizuTestEnvVar = "MIZU_TEST"
@@ -144,7 +145,7 @@ type RequestResponseMatcher interface {
}
type Emitting struct {
AppStats *AppStats
AppStats *diagnose.AppStats
OutputChannel chan *OutputChannelItem
}

View File

@@ -9,10 +9,9 @@ import (
"time"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
)
var AppStats = api.AppStats{}
var AppStatsInst = AppStats{}
func StartMemoryProfiler(envDumpPath string, envTimeInterval string) {
dumpPath := "/app/pprof"

View File

@@ -1,4 +1,4 @@
package api
package diagnose
import (
"sync/atomic"

View File

@@ -2,4 +2,12 @@ module github.com/up9inc/mizu/tap/api
go 1.17
require github.com/google/martian v2.1.0+incompatible
require (
github.com/google/gopacket v1.1.19
github.com/google/martian v2.1.0+incompatible
github.com/up9inc/mizu/shared v0.0.0
)
require github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 // indirect
replace github.com/up9inc/mizu/shared v0.0.0 => ../../shared

File diff suppressed because it is too large Load Diff

85
tap/api/tcp_reader.go Normal file
View File

@@ -0,0 +1,85 @@
package api
import (
"bufio"
"io"
"io/ioutil"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
)
type TcpReaderDataMsg struct {
bytes []byte
timestamp time.Time
}
/* TcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses.
* The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection.
* An TcpReader object is unidirectional: it parses either a client stream or a server stream.
* Implements io.Reader interface (Read)
*/
type TcpReader struct {
Ident string
TcpID *TcpID
isClosed bool
IsClient bool
IsOutgoing bool
MsgQueue chan TcpReaderDataMsg // Channel of captured reassembled tcp payload
data []byte
Progress *ReadProgress
SuperTimer *SuperTimer
Parent *TcpStream
packetsSeen uint
Extension *Extension
Emitter Emitter
CounterPair *CounterPair
ReqResMatcher RequestResponseMatcher
sync.Mutex
}
func (h *TcpReader) Read(p []byte) (int, error) {
var msg TcpReaderDataMsg
ok := true
for ok && len(h.data) == 0 {
msg, ok = <-h.MsgQueue
h.data = msg.bytes
h.SuperTimer.CaptureTime = msg.timestamp
if len(h.data) > 0 {
h.packetsSeen += 1
}
}
if !ok || len(h.data) == 0 {
return 0, io.EOF
}
l := copy(p, h.data)
h.data = h.data[l:]
h.Progress.Feed(l)
return l, nil
}
func (h *TcpReader) Close() {
h.Lock()
if !h.isClosed {
h.isClosed = true
close(h.MsgQueue)
}
h.Unlock()
}
func (h *TcpReader) Run(filteringOptions *TrafficFilteringOptions, wg *sync.WaitGroup) {
defer wg.Done()
b := bufio.NewReader(h)
err := h.Extension.Dissector.Dissect(b, h.Progress, h.Parent.Origin, h.IsClient, h.TcpID, h.CounterPair, h.SuperTimer, h.Parent.SuperIdentifier, h.Emitter, filteringOptions, h.ReqResMatcher)
if err != nil {
_, err = io.Copy(ioutil.Discard, b)
if err != nil {
logger.Log.Errorf("%v", err)
}
}
}

View File

@@ -1,4 +1,4 @@
package tap
package api
import (
"encoding/binary"
@@ -8,79 +8,57 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers" // pulls in all layers decoders
"github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
"github.com/up9inc/mizu/tap/api/diagnose"
)
/* It's a connection (bidirectional)
* Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete)
* ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete)
* In our implementation, we pass information from ReassembledSG to the tcpReader through a shared channel.
* In our implementation, we pass information from ReassembledSG to the TcpReader through a shared channel.
*/
type tcpStream struct {
id int64
type TcpStream struct {
Id int64
isClosed bool
superIdentifier *api.SuperIdentifier
tcpstate *reassembly.TCPSimpleFSM
SuperIdentifier *SuperIdentifier
TcpState *reassembly.TCPSimpleFSM
fsmerr bool
optchecker reassembly.TCPOptionCheck
net, transport gopacket.Flow
isDNS bool
isTapTarget bool
clients []tcpReader
servers []tcpReader
ident string
origin api.Capture
reqResMatcher api.RequestResponseMatcher
Optchecker reassembly.TCPOptionCheck
Net, Transport gopacket.Flow
IsDNS bool
IsTapTarget bool
Clients []TcpReader
Servers []TcpReader
Ident string
Origin Capture
ReqResMatcher RequestResponseMatcher
createdAt time.Time
StreamsMap *TcpStreamMap
sync.Mutex
streamsMap *tcpStreamMap
}
func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
func (t *TcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
// FSM
if !t.tcpstate.CheckState(tcp, dir) {
diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpstate.String())
if !t.TcpState.CheckState(tcp, dir) {
diagnose.TapErrors.SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.Ident, t.TcpState.String())
diagnose.InternalStats.RejectFsm++
if !t.fsmerr {
t.fsmerr = true
diagnose.InternalStats.RejectConnFsm++
}
if !*ignorefsmerr {
return false
}
}
// Options
err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start)
err := t.Optchecker.Accept(tcp, ci, dir, nextSeq, start)
if err != nil {
diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err)
diagnose.InternalStats.RejectOpt++
if !*nooptcheck {
return false
}
}
// Checksum
accept := true
if *checksum {
c, err := tcp.ComputeChecksum()
if err != nil {
diagnose.TapErrors.SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err)
accept = false
} else if c != 0x0 {
diagnose.TapErrors.SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c)
accept = false
}
}
if !accept {
diagnose.TapErrors.SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.Ident, err)
diagnose.InternalStats.RejectOpt++
}
*start = true
return accept
return true
}
func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
func (t *TcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
dir, _, _, skip := sg.Info()
length, saved := sg.Lengths()
// update stats
@@ -109,14 +87,12 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
diagnose.InternalStats.OverlapBytes += sgStats.OverlapBytes
diagnose.InternalStats.OverlapPackets += sgStats.OverlapPackets
if skip == -1 && *allowmissinginit {
// this is allowed
} else if skip != 0 {
if skip != -1 && skip != 0 {
// Missing bytes in stream: do not even try to parse it
return
}
data := sg.Fetch(length)
if t.isDNS {
if t.IsDNS {
dns := &layers.DNS{}
var decoded []gopacket.LayerType
if len(data) < 2 {
@@ -143,27 +119,27 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
if len(data) > 2+int(dnsSize) {
sg.KeepFrom(2 + int(dnsSize))
}
} else if t.isTapTarget {
} else if t.IsTapTarget {
if length > 0 {
// This is where we pass the reassembled information onwards
// This channel is read by an tcpReader object
diagnose.AppStats.IncReassembledTcpPayloadsCount()
diagnose.AppStatsInst.IncReassembledTcpPayloadsCount()
timestamp := ac.GetCaptureInfo().Timestamp
if dir == reassembly.TCPDirClientToServer {
for i := range t.clients {
reader := &t.clients[i]
for i := range t.Clients {
reader := &t.Clients[i]
reader.Lock()
if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
reader.MsgQueue <- TcpReaderDataMsg{data, timestamp}
}
reader.Unlock()
}
} else {
for i := range t.servers {
reader := &t.servers[i]
for i := range t.Servers {
reader := &t.Servers[i]
reader.Lock()
if !reader.isClosed {
reader.msgQueue <- tcpReaderDataMsg{data, timestamp}
reader.MsgQueue <- TcpReaderDataMsg{data, timestamp}
}
reader.Unlock()
}
@@ -172,15 +148,15 @@ func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
}
}
func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.isTapTarget && !t.isClosed {
func (t *TcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
if t.IsTapTarget && !t.isClosed {
t.Close()
}
// do not remove the connection to allow last ACK
return false
}
func (t *tcpStream) Close() {
func (t *TcpStream) Close() {
shouldReturn := false
t.Lock()
if t.isClosed {
@@ -192,14 +168,14 @@ func (t *tcpStream) Close() {
if shouldReturn {
return
}
t.streamsMap.Delete(t.id)
t.StreamsMap.Delete(t.Id)
for i := range t.clients {
reader := &t.clients[i]
for i := range t.Clients {
reader := &t.Clients[i]
reader.Close()
}
for i := range t.servers {
reader := &t.servers[i]
for i := range t.Servers {
reader := &t.Servers[i]
reader.Close()
}
}

View File

@@ -0,0 +1,29 @@
package api
import (
"sync"
)
type TcpStreamMap struct {
Streams *sync.Map
streamId int64
}
func NewTcpStreamMap() *TcpStreamMap {
return &TcpStreamMap{
Streams: &sync.Map{},
}
}
func (streamMap *TcpStreamMap) Store(key, value interface{}) {
streamMap.Streams.Store(key, value)
}
func (streamMap *TcpStreamMap) Delete(key interface{}) {
streamMap.Streams.Delete(key)
}
func (streamMap *TcpStreamMap) NextId() int64 {
streamMap.streamId++
return streamMap.streamId
}

View File

@@ -22,7 +22,7 @@ type Cleaner struct {
connectionTimeout time.Duration
stats CleanerStats
statsMutex sync.Mutex
streamsMap *tcpStreamMap
streamsMap *api.TcpStreamMap
}
func (cl *Cleaner) clean() {
@@ -33,8 +33,8 @@ func (cl *Cleaner) clean() {
flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout))
cl.assemblerMutex.Unlock()
cl.streamsMap.streams.Range(func(k, v interface{}) bool {
reqResMatcher := v.(*tcpStream).reqResMatcher
cl.streamsMap.Streams.Range(func(k, v interface{}) bool {
reqResMatcher := v.(*api.TcpStream).ReqResMatcher
if reqResMatcher == nil {
return true
}

View File

@@ -4,7 +4,7 @@ import (
"net"
"strings"
"github.com/up9inc/mizu/tap/diagnose"
"github.com/up9inc/mizu/tap/api/diagnose"
)
var privateIPBlocks []*net.IPNet

View File

@@ -19,7 +19,7 @@ import (
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
"github.com/up9inc/mizu/tap/api/diagnose"
"github.com/up9inc/mizu/tap/source"
"github.com/up9inc/mizu/tap/tlstapper"
v1 "k8s.io/api/core/v1"
@@ -31,11 +31,8 @@ var maxcount = flag.Int64("c", -1, "Only grab this many packets, then exit")
var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)")
var statsevery = flag.Int("stats", 60, "Output statistics every N seconds")
var lazy = flag.Bool("lazy", false, "If true, do lazy decoding")
var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag")
var checksum = flag.Bool("checksum", false, "Check TCP checksum") // global
var nooptcheck = flag.Bool("nooptcheck", true, "Do not check TCP options (useful to ignore MSS on captures with TSO)") // global
var ignorefsmerr = flag.Bool("ignorefsmerr", true, "Ignore TCP FSM errors") // global
var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global
var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag") // global
var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global
var verbose = flag.Bool("verbose", false, "Be verbose")
var debug = flag.Bool("debug", false, "Display debug information")
var quiet = flag.Bool("quiet", false, "Be quiet regarding errors")
@@ -128,7 +125,7 @@ func printPeriodicStats(cleaner *Cleaner) {
errorMapLen, errorsSummery := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("%v (errors: %v, errTypes:%v) - Errors Summary: %s",
time.Since(diagnose.AppStats.StartTime),
time.Since(diagnose.AppStatsInst.StartTime),
diagnose.TapErrors.ErrorsCount,
errorMapLen,
errorsSummery,
@@ -151,7 +148,7 @@ func printPeriodicStats(cleaner *Cleaner) {
cleanStats.closed,
cleanStats.deleted,
)
currentAppStats := diagnose.AppStats.DumpStats()
currentAppStats := diagnose.AppStatsInst.DumpStats()
appStatsJSON, _ := json.Marshal(currentAppStats)
logger.Log.Infof("app stats - %v", string(appStatsJSON))
}
@@ -181,8 +178,8 @@ func initializePacketSources() error {
return err
}
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*tcpStreamMap, *tcpAssembler) {
streamsMap := NewTcpStreamMap()
func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) (*api.TcpStreamMap, *tcpAssembler) {
streamsMap := api.NewTcpStreamMap()
diagnose.InitializeErrorsMap(*debug, *verbose, *quiet)
diagnose.InitializeTapperInternalStats()
@@ -198,10 +195,8 @@ func initializePassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelI
return streamsMap, assembler
}
func startPassiveTapper(streamsMap *tcpStreamMap, assembler *tcpAssembler) {
go streamsMap.closeTimedoutTcpStreamChannels()
diagnose.AppStats.SetStartTime(time.Now())
func startPassiveTapper(streamsMap *api.TcpStreamMap, assembler *tcpAssembler) {
diagnose.AppStatsInst.SetStartTime(time.Now())
staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds)
cleaner := Cleaner{
@@ -229,7 +224,7 @@ func startPassiveTapper(streamsMap *tcpStreamMap, assembler *tcpAssembler) {
diagnose.InternalStats.PrintStatsSummary()
diagnose.TapErrors.PrintSummary()
logger.Log.Infof("AppStats: %v", diagnose.AppStats)
logger.Log.Infof("AppStats: %v", diagnose.AppStatsInst)
}
func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper {
@@ -257,7 +252,7 @@ func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChanne
}
var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats,
AppStats: &diagnose.AppStatsInst,
OutputChannel: outputItems,
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/gopacket/pcap"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
"github.com/up9inc/mizu/tap/api/diagnose"
)
type tcpPacketSource struct {

View File

@@ -12,7 +12,7 @@ import (
"github.com/google/gopacket/reassembly"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
"github.com/up9inc/mizu/tap/diagnose"
"github.com/up9inc/mizu/tap/api/diagnose"
"github.com/up9inc/mizu/tap/source"
)
@@ -36,9 +36,9 @@ func (c *context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo
}
func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *tcpStreamMap, opts *TapOpts) *tcpAssembler {
func NewTcpAssembler(outputItems chan *api.OutputChannelItem, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpAssembler {
var emitter api.Emitter = &api.Emitting{
AppStats: &diagnose.AppStats,
AppStats: &diagnose.AppStatsInst,
OutputChannel: outputItems,
}
@@ -65,7 +65,7 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
signal.Notify(signalChan, os.Interrupt)
for packetInfo := range packets {
packetsCount := diagnose.AppStats.IncPacketsCount()
packetsCount := diagnose.AppStatsInst.IncPacketsCount()
if packetsCount%PACKETS_SEEN_LOG_THRESHOLD == 0 {
logger.Log.Debugf("Packets seen: #%d", packetsCount)
@@ -73,21 +73,15 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
packet := packetInfo.Packet
data := packet.Data()
diagnose.AppStats.UpdateProcessedBytes(uint64(len(data)))
diagnose.AppStatsInst.UpdateProcessedBytes(uint64(len(data)))
if dumpPacket {
logger.Log.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data))
}
tcp := packet.Layer(layers.LayerTypeTCP)
if tcp != nil {
diagnose.AppStats.IncTcpPacketsCount()
diagnose.AppStatsInst.IncTcpPacketsCount()
tcp := tcp.(*layers.TCP)
if *checksum {
err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer())
if err != nil {
logger.Log.Fatalf("Failed to set network layer for checksum: %s", err)
}
}
c := context{
CaptureInfo: packet.Metadata().CaptureInfo,
@@ -99,13 +93,13 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp
a.assemblerMutex.Unlock()
}
done := *maxcount > 0 && int64(diagnose.AppStats.PacketsCount) >= *maxcount
done := *maxcount > 0 && int64(diagnose.AppStatsInst.PacketsCount) >= *maxcount
if done {
errorMapLen, _ := diagnose.TapErrors.GetErrorsSummary()
logger.Log.Infof("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)",
diagnose.AppStats.PacketsCount,
diagnose.AppStats.ProcessedBytes,
time.Since(diagnose.AppStats.StartTime),
diagnose.AppStatsInst.PacketsCount,
diagnose.AppStatsInst.ProcessedBytes,
time.Since(diagnose.AppStatsInst.StartTime),
diagnose.TapErrors.ErrorsCount,
errorMapLen)
}

View File

@@ -1,94 +0,0 @@
package tap
import (
"bufio"
"io"
"io/ioutil"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/api"
)
type tcpReaderDataMsg struct {
bytes []byte
timestamp time.Time
}
type ConnectionInfo struct {
ClientIP string
ClientPort string
ServerIP string
ServerPort string
IsOutgoing bool
}
/* tcpReader gets reads from a channel of bytes of tcp payload, and parses it into requests and responses.
* The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection.
* An tcpReader object is unidirectional: it parses either a client stream or a server stream.
* Implements io.Reader interface (Read)
*/
type tcpReader struct {
ident string
tcpID *api.TcpID
isClosed bool
isClient bool
isOutgoing bool
msgQueue chan tcpReaderDataMsg // Channel of captured reassembled tcp payload
data []byte
progress *api.ReadProgress
superTimer *api.SuperTimer
parent *tcpStream
packetsSeen uint
extension *api.Extension
emitter api.Emitter
counterPair *api.CounterPair
reqResMatcher api.RequestResponseMatcher
sync.Mutex
}
func (h *tcpReader) Read(p []byte) (int, error) {
var msg tcpReaderDataMsg
ok := true
for ok && len(h.data) == 0 {
msg, ok = <-h.msgQueue
h.data = msg.bytes
h.superTimer.CaptureTime = msg.timestamp
if len(h.data) > 0 {
h.packetsSeen += 1
}
}
if !ok || len(h.data) == 0 {
return 0, io.EOF
}
l := copy(p, h.data)
h.data = h.data[l:]
h.progress.Feed(l)
return l, nil
}
func (h *tcpReader) Close() {
h.Lock()
if !h.isClosed {
h.isClosed = true
close(h.msgQueue)
}
h.Unlock()
}
func (h *tcpReader) run(wg *sync.WaitGroup) {
defer wg.Done()
b := bufio.NewReader(h)
err := h.extension.Dissector.Dissect(b, h.progress, h.parent.origin, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter, filteringOptions, h.reqResMatcher)
if err != nil {
_, err = io.Copy(ioutil.Discard, b)
if err != nil {
logger.Log.Errorf("%v", err)
}
}
}

View File

@@ -21,12 +21,12 @@ import (
type tcpStreamFactory struct {
wg sync.WaitGroup
Emitter api.Emitter
streamsMap *tcpStreamMap
streamsMap *api.TcpStreamMap
ownIps []string
opts *TapOpts
}
func NewTcpStreamFactory(emitter api.Emitter, streamsMap *tcpStreamMap, opts *TapOpts) *tcpStreamFactory {
func NewTcpStreamFactory(emitter api.Emitter, streamsMap *api.TcpStreamMap, opts *TapOpts) *tcpStreamFactory {
var ownIps []string
if localhostIPs, err := getLocalhostIPs(); err != nil {
@@ -57,71 +57,71 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T
props := factory.getStreamProps(srcIp, srcPort, dstIp, dstPort)
isTapTarget := props.isTapTarget
stream := &tcpStream{
net: net,
transport: transport,
isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
isTapTarget: isTapTarget,
tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions),
ident: fmt.Sprintf("%s:%s", net, transport),
optchecker: reassembly.NewTCPOptionCheck(),
superIdentifier: &api.SuperIdentifier{},
streamsMap: factory.streamsMap,
origin: getPacketOrigin(ac),
stream := &api.TcpStream{
Net: net,
Transport: transport,
IsDNS: tcp.SrcPort == 53 || tcp.DstPort == 53,
IsTapTarget: isTapTarget,
TcpState: reassembly.NewTCPSimpleFSM(fsmOptions),
Ident: fmt.Sprintf("%s:%s", net, transport),
Optchecker: reassembly.NewTCPOptionCheck(),
SuperIdentifier: &api.SuperIdentifier{},
StreamsMap: factory.streamsMap,
Origin: getPacketOrigin(ac),
}
if stream.isTapTarget {
stream.id = factory.streamsMap.nextId()
if stream.IsTapTarget {
stream.Id = factory.streamsMap.NextId()
for i, extension := range extensions {
reqResMatcher := extension.Dissector.NewResponseRequestMatcher()
counterPair := &api.CounterPair{
Request: 0,
Response: 0,
}
stream.clients = append(stream.clients, tcpReader{
msgQueue: make(chan tcpReaderDataMsg),
progress: &api.ReadProgress{},
superTimer: &api.SuperTimer{},
ident: fmt.Sprintf("%s %s", net, transport),
tcpID: &api.TcpID{
stream.Clients = append(stream.Clients, api.TcpReader{
MsgQueue: make(chan api.TcpReaderDataMsg),
Progress: &api.ReadProgress{},
SuperTimer: &api.SuperTimer{},
Ident: fmt.Sprintf("%s %s", net, transport),
TcpID: &api.TcpID{
SrcIP: srcIp,
DstIP: dstIp,
SrcPort: srcPort,
DstPort: dstPort,
},
parent: stream,
isClient: true,
isOutgoing: props.isOutgoing,
extension: extension,
emitter: factory.Emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
Parent: stream,
IsClient: true,
IsOutgoing: props.isOutgoing,
Extension: extension,
Emitter: factory.Emitter,
CounterPair: counterPair,
ReqResMatcher: reqResMatcher,
})
stream.servers = append(stream.servers, tcpReader{
msgQueue: make(chan tcpReaderDataMsg),
progress: &api.ReadProgress{},
superTimer: &api.SuperTimer{},
ident: fmt.Sprintf("%s %s", net, transport),
tcpID: &api.TcpID{
stream.Servers = append(stream.Servers, api.TcpReader{
MsgQueue: make(chan api.TcpReaderDataMsg),
Progress: &api.ReadProgress{},
SuperTimer: &api.SuperTimer{},
Ident: fmt.Sprintf("%s %s", net, transport),
TcpID: &api.TcpID{
SrcIP: net.Dst().String(),
DstIP: net.Src().String(),
SrcPort: transport.Dst().String(),
DstPort: transport.Src().String(),
},
parent: stream,
isClient: false,
isOutgoing: props.isOutgoing,
extension: extension,
emitter: factory.Emitter,
counterPair: counterPair,
reqResMatcher: reqResMatcher,
Parent: stream,
IsClient: false,
IsOutgoing: props.isOutgoing,
Extension: extension,
Emitter: factory.Emitter,
CounterPair: counterPair,
ReqResMatcher: reqResMatcher,
})
factory.streamsMap.Store(stream.id, stream)
factory.streamsMap.Store(stream.Id, stream)
factory.wg.Add(2)
// Start reading from channel stream.reader.bytes
go stream.clients[i].run(&factory.wg)
go stream.servers[i].run(&factory.wg)
go stream.Clients[i].Run(filteringOptions, &factory.wg)
go stream.Servers[i].Run(filteringOptions, &factory.wg)
}
}
return stream

View File

@@ -1,98 +0,0 @@
package tap
import (
"os"
"runtime"
_debug "runtime/debug"
"strconv"
"sync"
"time"
"github.com/up9inc/mizu/shared/logger"
"github.com/up9inc/mizu/tap/diagnose"
)
type tcpStreamMap struct {
streams *sync.Map
streamId int64
}
func NewTcpStreamMap() *tcpStreamMap {
return &tcpStreamMap{
streams: &sync.Map{},
}
}
func (streamMap *tcpStreamMap) Store(key, value interface{}) {
streamMap.streams.Store(key, value)
}
func (streamMap *tcpStreamMap) Delete(key interface{}) {
streamMap.streams.Delete(key)
}
func (streamMap *tcpStreamMap) nextId() int64 {
streamMap.streamId++
return streamMap.streamId
}
func (streamMap *tcpStreamMap) getCloseTimedoutTcpChannelsInterval() time.Duration {
defaultDuration := 1000 * time.Millisecond
rangeMin := 10
rangeMax := 10000
closeTimedoutTcpChannelsIntervalMsStr := os.Getenv(CloseTimedoutTcpChannelsIntervalMsEnvVar)
if closeTimedoutTcpChannelsIntervalMsStr == "" {
return defaultDuration
} else {
closeTimedoutTcpChannelsIntervalMs, err := strconv.Atoi(closeTimedoutTcpChannelsIntervalMsStr)
if err != nil {
logger.Log.Warningf("Error parsing environment variable %s: %v\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, err)
return defaultDuration
} else {
if closeTimedoutTcpChannelsIntervalMs < rangeMin || closeTimedoutTcpChannelsIntervalMs > rangeMax {
logger.Log.Warningf("The value of environment variable %s is not in acceptable range: %d - %d\n", CloseTimedoutTcpChannelsIntervalMsEnvVar, rangeMin, rangeMax)
return defaultDuration
} else {
return time.Duration(closeTimedoutTcpChannelsIntervalMs) * time.Millisecond
}
}
}
}
func (streamMap *tcpStreamMap) closeTimedoutTcpStreamChannels() {
tcpStreamChannelTimeout := GetTcpChannelTimeoutMs()
closeTimedoutTcpChannelsIntervalMs := streamMap.getCloseTimedoutTcpChannelsInterval()
logger.Log.Infof("Using %d ms as the close timedout TCP stream channels interval", closeTimedoutTcpChannelsIntervalMs/time.Millisecond)
for {
time.Sleep(closeTimedoutTcpChannelsIntervalMs)
_debug.FreeOSMemory()
streamMap.streams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*tcpStream)
if stream.superIdentifier.Protocol == nil {
if !stream.isClosed && time.Now().After(stream.createdAt.Add(tcpStreamChannelTimeout)) {
stream.Close()
diagnose.AppStats.IncDroppedTcpStreams()
logger.Log.Debugf("Dropped an unidentified TCP stream because of timeout. Total dropped: %d Total Goroutines: %d Timeout (ms): %d",
diagnose.AppStats.DroppedTcpStreams, runtime.NumGoroutine(), tcpStreamChannelTimeout/time.Millisecond)
}
} else {
if !stream.superIdentifier.IsClosedOthers {
for i := range stream.clients {
reader := &stream.clients[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
for i := range stream.servers {
reader := &stream.servers[i]
if reader.extension.Protocol != stream.superIdentifier.Protocol {
reader.Close()
}
}
stream.superIdentifier.IsClosedOthers = true
}
}
return true
})
}
}