diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index 9a5202685..b8112e47a 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -52,7 +52,6 @@ var snaplen = flag.Int("s", 65536, "Snap length (number of bytes max to read per var tstype = flag.String("timestamp_type", "", "Type of timestamps to use") var promisc = flag.Bool("promisc", true, "Set promiscuous mode") var staleTimeoutSeconds = flag.Int("staletimout", 120, "Max time in seconds to keep connections which don't transmit data") -var pids = flag.String("pids", "", "A comma separated list of PIDs to capture their network namespaces") var servicemesh = flag.Bool("servicemesh", false, "Record decrypted traffic if the cluster is configured with a service mesh and with mtls") var tls = flag.Bool("tls", false, "Enable TLS tapper") @@ -190,7 +189,7 @@ func initializePacketSources() error { } var err error - if packetSourceManager, err = source.NewPacketSourceManager(*procfs, *pids, *fname, *iface, *servicemesh, tapTargets, behaviour); err != nil { + if packetSourceManager, err = source.NewPacketSourceManager(*procfs, *fname, *iface, *servicemesh, tapTargets, behaviour); err != nil { return err } else { packetSourceManager.ReadPackets(!*nodefrag, mainPacketInputChan) diff --git a/tap/source/netns_packet_source.go b/tap/source/netns_packet_source.go new file mode 100644 index 000000000..9cc8021c9 --- /dev/null +++ b/tap/source/netns_packet_source.go @@ -0,0 +1,83 @@ +package source + +import ( + "fmt" + "runtime" + + "github.com/up9inc/mizu/shared/logger" + "github.com/vishvananda/netns" +) + +func newNetnsPacketSource(procfs string, pid string, + interfaceName string, behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { + nsh, err := netns.GetFromPath(fmt.Sprintf("%s/%s/ns/net", procfs, pid)) + + if err != nil { + logger.Log.Errorf("Unable to get netns of pid %s - %w", pid, err) + return nil, err + } + + src, err := newPacketSourceFromNetnsHandle(pid, nsh, interfaceName, behaviour) + + if err != nil { + logger.Log.Errorf("Error starting netns packet source for %s - %w", pid, err) + return nil, err + } + + return src, nil +} + +func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceName string, + behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { + + done := make(chan *tcpPacketSource) + errors := make(chan error) + + go func(done chan<- *tcpPacketSource) { + // Setting a netns should be done from a dedicated OS thread. + // + // goroutines are not really OS threads, we try to mimic the issue by + // locking the OS thread to this goroutine + // + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + oldnetns, err := netns.Get() + + if err != nil { + logger.Log.Errorf("Unable to get netns of current thread %w", err) + errors <- err + return + } + + if err := netns.Set(nsh); err != nil { + logger.Log.Errorf("Unable to set netns of pid %s - %w", pid, err) + errors <- err + return + } + + name := fmt.Sprintf("netns-%s-%s", pid, interfaceName) + src, err := newTcpPacketSource(name, "", interfaceName, behaviour) + + if err != nil { + logger.Log.Errorf("Error listening to PID %s - %w", pid, err) + errors <- err + return + } + + if err := netns.Set(oldnetns); err != nil { + logger.Log.Errorf("Unable to set back netns of current thread %w", err) + errors <- err + return + } + + done <- src + }(done) + + select { + case err := <-errors: + return nil, err + case source := <-done: + return source, nil + } +} diff --git a/tap/source/packet_source_manager.go b/tap/source/packet_source_manager.go index 912d03c9a..9dc538377 100644 --- a/tap/source/packet_source_manager.go +++ b/tap/source/packet_source_manager.go @@ -2,109 +2,46 @@ package source import ( "fmt" - "runtime" - "strconv" "strings" "github.com/up9inc/mizu/shared/logger" - "github.com/vishvananda/netns" v1 "k8s.io/api/core/v1" ) +const bpfFilterMaxPods = 150 +const hostSourcePid = "0" + type PacketSourceManager struct { - sources []*tcpPacketSource + sources map[string]*tcpPacketSource } -func NewPacketSourceManager(procfs string, pids string, filename string, interfaceName string, +func NewPacketSourceManager(procfs string, filename string, interfaceName string, mtls bool, pods []v1.Pod, behaviour TcpPacketSourceBehaviour) (*PacketSourceManager, error) { - sources := make([]*tcpPacketSource, 0) - sources, err := createHostSource(sources, filename, interfaceName, behaviour) - + hostSource, err := newHostPacketSource(filename, interfaceName, behaviour) if err != nil { return nil, err } - sources = createSourcesFromPids(sources, procfs, pids, interfaceName, behaviour) - sources = createSourcesFromEnvoy(sources, mtls, procfs, pods, interfaceName, behaviour) - sources = createSourcesFromLinkerd(sources, mtls, procfs, pods, interfaceName, behaviour) - - return &PacketSourceManager{ - sources: sources, - }, nil -} - -func createHostSource(sources []*tcpPacketSource, filename string, interfaceName string, - behaviour TcpPacketSourceBehaviour) ([]*tcpPacketSource, error) { - hostSource, err := newHostPacketSource(filename, interfaceName, behaviour) - - if err != nil { - return sources, err + sourceManager := &PacketSourceManager{ + sources: map[string]*tcpPacketSource{ + hostSourcePid: hostSource, + }, } - return append(sources, hostSource), nil -} - -func createSourcesFromPids(sources []*tcpPacketSource, procfs string, pids string, - interfaceName string, behaviour TcpPacketSourceBehaviour) []*tcpPacketSource { - if pids == "" { - return sources - } - - netnsSources := newNetnsPacketSources(procfs, strings.Split(pids, ","), interfaceName, behaviour) - sources = append(sources, netnsSources...) - return sources -} - -func createSourcesFromEnvoy(sources []*tcpPacketSource, mtls bool, procfs string, pods []v1.Pod, - interfaceName string, behaviour TcpPacketSourceBehaviour) []*tcpPacketSource { - if !mtls { - return sources - } - - envoyPids, err := discoverRelevantEnvoyPids(procfs, pods) - - if err != nil { - logger.Log.Warningf("Unable to discover envoy pids - %v", err) - return sources - } - - netnsSources := newNetnsPacketSources(procfs, envoyPids, interfaceName, behaviour) - sources = append(sources, netnsSources...) - - return sources -} - -func createSourcesFromLinkerd(sources []*tcpPacketSource, mtls bool, procfs string, pods []v1.Pod, - interfaceName string, behaviour TcpPacketSourceBehaviour) []*tcpPacketSource { - if !mtls { - return sources - } - - linkerdPids, err := discoverRelevantLinkerdPids(procfs, pods) - - if err != nil { - logger.Log.Warningf("Unable to discover linkerd pids - %v", err) - return sources - } - - netnsSources := newNetnsPacketSources(procfs, linkerdPids, interfaceName, behaviour) - sources = append(sources, netnsSources...) - - return sources + sourceManager.UpdatePods(mtls, procfs, pods, interfaceName, behaviour) + return sourceManager, nil } func newHostPacketSource(filename string, interfaceName string, behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { var name string - if filename == "" { - name = fmt.Sprintf("host-%v", interfaceName) + name = fmt.Sprintf("host-%s", interfaceName) } else { - name = fmt.Sprintf("file-%v", filename) + name = fmt.Sprintf("file-%s", filename) } source, err := newTcpPacketSource(name, filename, interfaceName, behaviour) - if err != nil { return nil, err } @@ -112,90 +49,93 @@ func newHostPacketSource(filename string, interfaceName string, return source, nil } -func newNetnsPacketSources(procfs string, pids []string, interfaceName string, - behaviour TcpPacketSourceBehaviour) []*tcpPacketSource { - result := make([]*tcpPacketSource, 0) - - for _, pidstr := range pids { - pid, err := strconv.Atoi(pidstr) - - if err != nil { - logger.Log.Errorf("Invalid PID: %v - %v", pid, err) - continue - } - - nsh, err := netns.GetFromPath(fmt.Sprintf("%v/%v/ns/net", procfs, pid)) - - if err != nil { - logger.Log.Errorf("Unable to get netns of pid %v - %v", pid, err) - continue - } - - src, err := newNetnsPacketSource(pid, nsh, interfaceName, behaviour) - - if err != nil { - logger.Log.Errorf("Error starting netns packet source for %v - %v", pid, err) - continue - } - - result = append(result, src) +func (m *PacketSourceManager) UpdatePods(mtls bool, procfs string, pods []v1.Pod, + interfaceName string, behaviour TcpPacketSourceBehaviour) { + if mtls { + m.updateMtlsPods(procfs, pods, interfaceName, behaviour) } - return result + m.setBPFFilter(pods) } -func newNetnsPacketSource(pid int, nsh netns.NsHandle, interfaceName string, - behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { +func (m *PacketSourceManager) updateMtlsPods(procfs string, pods []v1.Pod, + interfaceName string, behaviour TcpPacketSourceBehaviour) { - done := make(chan *tcpPacketSource) - errors := make(chan error) + relevantPids := m.getRelevantPids(procfs, pods) + logger.Log.Infof("Updating mtls pods (new: %v) (current: %v)", relevantPids, m.sources) - go func(done chan<- *tcpPacketSource) { - // Setting a netns should be done from a dedicated OS thread. - // - // goroutines are not really OS threads, we try to mimic the issue by - // locking the OS thread to this goroutine - // - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - oldnetns, err := netns.Get() - - if err != nil { - logger.Log.Errorf("Unable to get netns of current thread %v", err) - errors <- err - return + for pid, src := range m.sources { + if _, ok := relevantPids[pid]; !ok { + src.close() + delete(m.sources, pid) } + } - if err := netns.Set(nsh); err != nil { - logger.Log.Errorf("Unable to set netns of pid %v - %v", pid, err) - errors <- err - return + for pid := range relevantPids { + if _, ok := m.sources[pid]; !ok { + source, err := newNetnsPacketSource(procfs, pid, interfaceName, behaviour) + + if err == nil { + m.sources[pid] = source + } } + } +} - name := fmt.Sprintf("netns-%v-%v", pid, interfaceName) - src, err := newTcpPacketSource(name, "", interfaceName, behaviour) +func (m *PacketSourceManager) getRelevantPids(procfs string, pods []v1.Pod) map[string]bool { + relevantPids := make(map[string]bool) + relevantPids[hostSourcePid] = true - if err != nil { - logger.Log.Errorf("Error listening to PID %v - %v", pid, err) - errors <- err - return + if envoyPids, err := discoverRelevantEnvoyPids(procfs, pods); err != nil { + logger.Log.Warningf("Unable to discover envoy pids - %w", err) + } else { + for _, pid := range envoyPids { + relevantPids[pid] = true } + } - if err := netns.Set(oldnetns); err != nil { - logger.Log.Errorf("Unable to set back netns of current thread %v", err) - errors <- err - return + if linkerdPids, err := discoverRelevantLinkerdPids(procfs, pods); err != nil { + logger.Log.Warningf("Unable to discover linkerd pids - %w", err) + } else { + for _, pid := range linkerdPids { + relevantPids[pid] = true } + } - done <- src - }(done) + return relevantPids +} - select { - case err := <-errors: - return nil, err - case source := <-done: - return source, nil +func buildBPFExpr(pods []v1.Pod) string { + hostsFilter := make([]string, 0) + + for _, pod := range pods { + hostsFilter = append(hostsFilter, fmt.Sprintf("host %s", pod.Status.PodIP)) + } + + return fmt.Sprintf("%s and port not 443", strings.Join(hostsFilter, " or ")) +} + +func (m *PacketSourceManager) setBPFFilter(pods []v1.Pod) { + if len(pods) == 0 { + logger.Log.Info("No pods provided, skipping pcap bpf filter") + return + } + + var expr string + + if len(pods) > bpfFilterMaxPods { + logger.Log.Info("Too many pods for setting ebpf filter %d, setting just not 443", len(pods)) + expr = "port not 443" + } else { + expr = buildBPFExpr(pods) + } + + logger.Log.Infof("Setting pcap bpf filter %s", expr) + + for pid, src := range m.sources { + if err := src.setBPFFilter(expr); err != nil { + logger.Log.Warningf("Error setting bpf filter for %s %v - %w", pid, src, err) + } } } diff --git a/tap/source/tcp_packet_source.go b/tap/source/tcp_packet_source.go index a7c0258ad..a2b94dadf 100644 --- a/tap/source/tcp_packet_source.go +++ b/tap/source/tcp_packet_source.go @@ -98,6 +98,14 @@ func newTcpPacketSource(name, filename string, interfaceName string, return result, nil } +func (source *tcpPacketSource) String() string { + return source.name +} + +func (source *tcpPacketSource) setBPFFilter(expr string) (err error) { + return source.handle.SetBPFFilter(expr) +} + func (source *tcpPacketSource) close() { if source.handle != nil { source.handle.Close()