diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index 6dea96f92..24f1cbddf 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -66,6 +66,7 @@ var filteringOptions *api.TrafficFilteringOptions // global var tapTargets []v1.Pod // global var packetSourceManager *source.PacketSourceManager // global var mainPacketInputChan chan source.TcpPacketInfo // global +var tlsTapperInstance *tlstapper.TlsTapper // global func inArrayInt(arr []int, valueToCheck int) bool { for _, value := range arr { @@ -92,7 +93,7 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, if *tls { for _, e := range extensions { if e.Protocol.Name == "http" { - startTlsTapper(e, outputItems, options) + tlsTapperInstance = startTlsTapper(e, outputItems, options) break } } @@ -106,20 +107,36 @@ func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, } func UpdateTapTargets(newTapTargets []v1.Pod) { + success := true + tapTargets = newTapTargets if err := initializePacketSources(); err != nil { logger.Log.Fatal(err) + success = false } - printNewTapTargets() + + if tlsTapperInstance != nil { + if err := tlstapper.UpdateTapTargets(tlsTapperInstance, &tapTargets, *procfs); err != nil { + tlstapper.LogError(err) + success = false + } + } + + printNewTapTargets(success) } -func printNewTapTargets() { +func printNewTapTargets(success bool) { printStr := "" for _, tapTarget := range tapTargets { printStr += fmt.Sprintf("%s (%s), ", tapTarget.Status.PodIP, tapTarget.Name) } printStr = strings.TrimRight(printStr, ", ") - logger.Log.Infof("Now tapping: %s", printStr) + + if success { + logger.Log.Infof("Now tapping: %s", printStr) + } else { + logger.Log.Errorf("Failed to start tapping: %s", printStr) + } } func printPeriodicStats(cleaner *Cleaner) { @@ -236,13 +253,18 @@ func startPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) logger.Log.Infof("AppStats: %v", diagnose.AppStats) } -func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) { +func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChannelItem, options *api.TrafficFilteringOptions) *tlstapper.TlsTapper { tls := tlstapper.TlsTapper{} tlsPerfBufferSize := os.Getpagesize() * 100 if err := tls.Init(tlsPerfBufferSize, *procfs, extension); err != nil { tlstapper.LogError(err) - return + return nil + } + + if err := tlstapper.UpdateTapTargets(&tls, &tapTargets, *procfs); err != nil { + tlstapper.LogError(err) + return nil } // A quick way to instrument libssl.so without PID filtering - used for debuging and troubleshooting @@ -250,19 +272,16 @@ func startTlsTapper(extension *api.Extension, outputItems chan *api.OutputChanne if os.Getenv("MIZU_GLOBAL_SSL_LIBRARY") != "" { if err := tls.GlobalTap(os.Getenv("MIZU_GLOBAL_SSL_LIBRARY")); err != nil { tlstapper.LogError(err) - return + return nil } } - if err := tlstapper.UpdateTapTargets(&tls, &tapTargets, *procfs); err != nil { - tlstapper.LogError(err) - return - } - var emitter api.Emitter = &api.Emitting{ AppStats: &diagnose.AppStats, OutputChannel: outputItems, } go tls.Poll(emitter, options) + + return &tls } diff --git a/tap/tlstapper/tls_process_discoverer.go b/tap/tlstapper/tls_process_discoverer.go index b3c217385..eabcedc92 100644 --- a/tap/tlstapper/tls_process_discoverer.go +++ b/tap/tlstapper/tls_process_discoverer.go @@ -24,6 +24,8 @@ func UpdateTapTargets(tls *TlsTapper, pods *[]v1.Pod, procfs string) error { if err != nil { return err } + + tls.ClearPids() for _, pid := range containerPids { if err := tls.AddPid(procfs, pid); err != nil { diff --git a/tap/tlstapper/tls_tapper.go b/tap/tlstapper/tls_tapper.go index f9d32672d..6886e147c 100644 --- a/tap/tlstapper/tls_tapper.go +++ b/tap/tlstapper/tls_tapper.go @@ -5,6 +5,7 @@ import ( "github.com/go-errors/errors" "github.com/up9inc/mizu/shared/logger" "github.com/up9inc/mizu/tap/api" + "sync" ) //go:generate go run github.com/cilium/ebpf/cmd/bpf2go tlsTapper bpf/tls_tapper.c -- -O2 -g -D__TARGET_ARCH_x86 @@ -14,6 +15,7 @@ type TlsTapper struct { syscallHooks syscallHooks sslHooksStructs []sslHooks poller *tlsPoller + registeredPids sync.Map } func (t *TlsTapper) Init(bufferSize int, procfs string, extension *api.Extension) error { @@ -70,6 +72,16 @@ func (t *TlsTapper) RemovePid(pid uint32) error { return nil } +func (t *TlsTapper) ClearPids() { + t.registeredPids.Range(func(key, v interface{}) bool { + if err := t.RemovePid(key.(uint32)); err != nil { + LogError(err) + } + t.registeredPids.Delete(key) + return true + }) +} + func (t *TlsTapper) Close() []error { errors := make([]error, 0) @@ -116,6 +128,8 @@ func (t *TlsTapper) tapPid(pid uint32, sslLibrary string) error { if err := pids.Put(pid, uint32(1)); err != nil { return errors.Wrap(err, 0) } + + t.registeredPids.Store(pid, true) return nil }