diff --git a/tap/source/netns_packet_source.go b/tap/source/netns_packet_source.go index 9cc8021c9..a8183ff0f 100644 --- a/tap/source/netns_packet_source.go +++ b/tap/source/netns_packet_source.go @@ -5,11 +5,12 @@ import ( "runtime" "github.com/up9inc/mizu/shared/logger" + "github.com/up9inc/mizu/tap/api" "github.com/vishvananda/netns" ) -func newNetnsPacketSource(procfs string, pid string, - interfaceName string, behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { +func newNetnsPacketSource(procfs string, pid string, interfaceName string, + behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) { nsh, err := netns.GetFromPath(fmt.Sprintf("%s/%s/ns/net", procfs, pid)) if err != nil { @@ -17,7 +18,7 @@ func newNetnsPacketSource(procfs string, pid string, return nil, err } - src, err := newPacketSourceFromNetnsHandle(pid, nsh, interfaceName, behaviour) + src, err := newPacketSourceFromNetnsHandle(pid, nsh, interfaceName, behaviour, origin) if err != nil { logger.Log.Errorf("Error starting netns packet source for %s - %w", pid, err) @@ -28,7 +29,7 @@ func newNetnsPacketSource(procfs string, pid string, } func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceName string, - behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { + behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) { done := make(chan *tcpPacketSource) errors := make(chan error) @@ -57,7 +58,7 @@ func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceNam } name := fmt.Sprintf("netns-%s-%s", pid, interfaceName) - src, err := newTcpPacketSource(name, "", interfaceName, behaviour) + src, err := newTcpPacketSource(name, "", interfaceName, behaviour, origin) if err != nil { logger.Log.Errorf("Error listening to PID %s - %w", pid, err) diff --git a/tap/source/packet_source_manager.go b/tap/source/packet_source_manager.go index bf527ef2e..bbedb5d39 100644 --- a/tap/source/packet_source_manager.go +++ b/tap/source/packet_source_manager.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/up9inc/mizu/shared/logger" + "github.com/up9inc/mizu/tap/api" v1 "k8s.io/api/core/v1" ) @@ -37,10 +38,10 @@ func NewPacketSourceManager(procfs string, filename string, interfaceName string } sourceManager.config = PacketSourceManagerConfig{ - mtls: mtls, - procfs: procfs, + mtls: mtls, + procfs: procfs, interfaceName: interfaceName, - behaviour: behaviour, + behaviour: behaviour, } go hostSource.readPackets(ipdefrag, packets) @@ -56,7 +57,7 @@ func newHostPacketSource(filename string, interfaceName string, name = fmt.Sprintf("file-%s", filename) } - source, err := newTcpPacketSource(name, filename, interfaceName, behaviour) + source, err := newTcpPacketSource(name, filename, interfaceName, behaviour, api.Pcap) if err != nil { return nil, err } @@ -85,9 +86,9 @@ func (m *PacketSourceManager) updateMtlsPods(procfs string, pods []v1.Pod, } } - for pid := range relevantPids { + for pid, origin := range relevantPids { if _, ok := m.sources[pid]; !ok { - source, err := newNetnsPacketSource(procfs, pid, interfaceName, behaviour) + source, err := newNetnsPacketSource(procfs, pid, interfaceName, behaviour, origin) if err == nil { go source.readPackets(ipdefrag, packets) @@ -97,15 +98,15 @@ func (m *PacketSourceManager) updateMtlsPods(procfs string, pods []v1.Pod, } } -func (m *PacketSourceManager) getRelevantPids(procfs string, pods []v1.Pod) map[string]bool { - relevantPids := make(map[string]bool) - relevantPids[hostSourcePid] = true +func (m *PacketSourceManager) getRelevantPids(procfs string, pods []v1.Pod) map[string]api.Capture { + relevantPids := make(map[string]api.Capture) + relevantPids[hostSourcePid] = api.Pcap if envoyPids, err := discoverRelevantEnvoyPids(procfs, pods); err != nil { logger.Log.Warningf("Unable to discover envoy pids - %w", err) } else { for _, pid := range envoyPids { - relevantPids[pid] = true + relevantPids[pid] = api.Envoy } } @@ -113,7 +114,7 @@ func (m *PacketSourceManager) getRelevantPids(procfs string, pods []v1.Pod) map[ logger.Log.Warningf("Unable to discover linkerd pids - %w", err) } else { for _, pid := range linkerdPids { - relevantPids[pid] = true + relevantPids[pid] = api.Linkerd } } diff --git a/tap/source/tcp_packet_source.go b/tap/source/tcp_packet_source.go index a2b94dadf..f71ee3a62 100644 --- a/tap/source/tcp_packet_source.go +++ b/tap/source/tcp_packet_source.go @@ -10,6 +10,7 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" "github.com/up9inc/mizu/shared/logger" + "github.com/up9inc/mizu/tap/api" "github.com/up9inc/mizu/tap/diagnose" ) @@ -19,6 +20,7 @@ type tcpPacketSource struct { defragger *ip4defrag.IPv4Defragmenter Behaviour *TcpPacketSourceBehaviour name string + Origin api.Capture } type TcpPacketSourceBehaviour struct { @@ -36,13 +38,14 @@ type TcpPacketInfo struct { } func newTcpPacketSource(name, filename string, interfaceName string, - behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) { + behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) { var err error result := &tcpPacketSource{ name: name, defragger: ip4defrag.NewIPv4Defragmenter(), Behaviour: &behaviour, + Origin: origin, } if filename != "" { diff --git a/tap/tcp_assembler.go b/tap/tcp_assembler.go index 8bba8f81d..0751ac78d 100644 --- a/tap/tcp_assembler.go +++ b/tap/tcp_assembler.go @@ -29,6 +29,7 @@ type tcpAssembler struct { // The assembler context type context struct { CaptureInfo gopacket.CaptureInfo + Origin api.Capture } func (c *context) GetCaptureInfo() gopacket.CaptureInfo { @@ -87,8 +88,10 @@ func (a *tcpAssembler) processPackets(dumpPacket bool, packets <-chan source.Tcp logger.Log.Fatalf("Failed to set network layer for checksum: %s", err) } } + c := context{ CaptureInfo: packet.Metadata().CaptureInfo, + Origin: packetInfo.Source.Origin, } diagnose.InternalStats.Totalsz += len(tcp.Payload) a.assemblerMutex.Lock() diff --git a/tap/tcp_reader.go b/tap/tcp_reader.go index 49af7a42c..61b853492 100644 --- a/tap/tcp_reader.go +++ b/tap/tcp_reader.go @@ -98,8 +98,7 @@ func (h *tcpReader) Close() { func (h *tcpReader) run(wg *sync.WaitGroup) { defer wg.Done() b := bufio.NewReader(h) - // TODO: Add api.Pcap, api.Envoy and api.Linkerd distinction by refactoring NewPacketSourceManager method - err := h.extension.Dissector.Dissect(b, h.progress, api.Pcap, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter, filteringOptions, h.reqResMatcher) + err := h.extension.Dissector.Dissect(b, h.progress, h.parent.origin, h.isClient, h.tcpID, h.counterPair, h.superTimer, h.parent.superIdentifier, h.emitter, filteringOptions, h.reqResMatcher) if err != nil { _, err = io.Copy(ioutil.Discard, b) if err != nil { diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index 47790bfcd..ea82861fc 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -29,6 +29,7 @@ type tcpStream struct { clients []tcpReader servers []tcpReader ident string + origin api.Capture sync.Mutex streamsMap *tcpStreamMap } diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 4367d2a6f..527b6e44d 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -78,6 +78,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T optchecker: reassembly.NewTCPOptionCheck(), superIdentifier: &api.SuperIdentifier{}, streamsMap: factory.streamsMap, + origin: getPacketOrigin(ac), } if stream.isTapTarget { stream.id = factory.streamsMap.nextId() @@ -182,6 +183,17 @@ func (factory *tcpStreamFactory) shouldNotifyOnOutboundLink(dstIP string, dstPor return true } +func getPacketOrigin(ac reassembly.AssemblerContext) api.Capture { + c, ok := ac.(*context) + + if !ok { + // If ac is not our context, fallback to Pcap + return api.Pcap + } + + return c.Origin +} + type streamProps struct { isTapTarget bool isOutgoing bool