diff --git a/agent/go.sum b/agent/go.sum index 3b7f5fc30..c49e1f152 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -42,6 +42,8 @@ github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb0 github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/bradleyfalzon/tlsx v0.0.0-20170624122154-28fd0e59bac4 h1:NJOOlc6ZJjix0A1rAU+nxruZtR8KboG1848yqpIUo4M= +github.com/bradleyfalzon/tlsx v0.0.0-20170624122154-28fd0e59bac4/go.mod h1:DQPxZS994Ld1Y8uwnJT+dRL04XPD0cElP/pHH/zEBHM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= diff --git a/tap/go.mod b/tap/go.mod index c0b84a753..87ce42cf3 100644 --- a/tap/go.mod +++ b/tap/go.mod @@ -3,14 +3,15 @@ module github.com/up9inc/mizu/tap go 1.16 require ( + github.com/bradleyfalzon/tlsx v0.0.0-20170624122154-28fd0e59bac4 // indirect github.com/google/gopacket v1.1.19 github.com/romana/rlog v0.0.0-20171115192701-f018bc92e7d7 github.com/up9inc/mizu/tap/api v0.0.0 - golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 - golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073 - golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d - golang.org/x/text v0.3.5 - golang.org/x/tools v0.0.0-20210106214847-113979e3529a + golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 + golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073 + golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d + golang.org/x/text v0.3.5 + golang.org/x/tools v0.0.0-20210106214847-113979e3529a ) replace 
github.com/up9inc/mizu/tap/api v0.0.0 => ./api diff --git a/tap/go.sum b/tap/go.sum index a7ee32da5..83a0c8f74 100644 --- a/tap/go.sum +++ b/tap/go.sum @@ -1,3 +1,5 @@ +github.com/bradleyfalzon/tlsx v0.0.0-20170624122154-28fd0e59bac4 h1:NJOOlc6ZJjix0A1rAU+nxruZtR8KboG1848yqpIUo4M= +github.com/bradleyfalzon/tlsx v0.0.0-20170624122154-28fd0e59bac4/go.mod h1:DQPxZS994Ld1Y8uwnJT+dRL04XPD0cElP/pHH/zEBHM= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/romana/rlog v0.0.0-20171115192701-f018bc92e7d7 h1:jkvpcEatpwuMF5O5LVxTnehj6YZ/aEZN4NWD/Xml4pI= diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index bcb0f2860..c95147dec 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -9,11 +9,13 @@ package tap import ( + "encoding/hex" "flag" "fmt" "io/ioutil" "log" "os" + "os/signal" "path" "path/filepath" "plugin" @@ -28,9 +30,10 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/examples/util" + "github.com/google/gopacket/ip4defrag" "github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/pcap" - "github.com/google/gopacket/tcpassembly" + "github.com/google/gopacket/reassembly" "github.com/up9inc/mizu/tap/api" ) @@ -283,8 +286,14 @@ func startPassiveTapper(outputItems chan *api.OutputChannelItem) { log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile) defer util.Run()() - var handle *pcap.Handle - var err error + if *debug { + outputLevel = 2 + } else if *verbose { + outputLevel = 1 + } else if *quiet { + outputLevel = -1 + } + errorsMap = make(map[string]uint) if localhostIPs, err := getLocalhostIPs(); err != nil { // TODO: think this over @@ -295,58 +304,293 @@ func startPassiveTapper(outputItems chan *api.OutputChannelItem) { ownIps = localhostIPs } - // Set up pcap packet capture - if *fname != "" { - log.Printf("Reading from pcap dump %q", *fname) - handle, 
err = pcap.OpenOffline(*fname) + appPortsStr := os.Getenv(AppPortsEnvVar) + var appPorts []int + if appPortsStr == "" { + rlog.Info("Received empty/no APP_PORTS env var! only listening to http on port 80!") + appPorts = make([]int, 0) } else { - log.Printf("Starting capture on interface %q", *iface) - handle, err = pcap.OpenLive(*iface, int32(*snaplen), true, pcap.BlockForever) + appPorts = parseAppPorts(appPortsStr) } - if err != nil { - log.Fatal(err) + SetFilterPorts(appPorts) + // envVal := os.Getenv(maxHTTP2DataLenEnvVar) + // if envVal == "" { + // rlog.Infof("Received empty/no HTTP2_DATA_SIZE_LIMIT env var! falling back to %v", maxHTTP2DataLenDefault) + // maxHTTP2DataLen = maxHTTP2DataLenDefault + // } else { + // if convertedInt, err := strconv.Atoi(envVal); err != nil { + // rlog.Infof("Received invalid HTTP2_DATA_SIZE_LIMIT env var! falling back to %v", maxHTTP2DataLenDefault) + // maxHTTP2DataLen = maxHTTP2DataLenDefault + // } else { + // rlog.Infof("Received HTTP2_DATA_SIZE_LIMIT env var: %v", maxHTTP2DataLenDefault) + // maxHTTP2DataLen = convertedInt + // } + // } + + log.Printf("App Ports: %v", gSettings.filterPorts) + + var handle *pcap.Handle + var err error + if *fname != "" { + if handle, err = pcap.OpenOffline(*fname); err != nil { + log.Fatalf("PCAP OpenOffline error: %v", err) + } + } else { + // This is a little complicated because we want to allow all possible options + // for creating the packet capture handle... instead of all this you can + // just call pcap.OpenLive if you want a simple handle. 
+ inactive, err := pcap.NewInactiveHandle(*iface) + if err != nil { + log.Fatalf("could not create: %v", err) + } + defer inactive.CleanUp() + if err = inactive.SetSnapLen(*snaplen); err != nil { + log.Fatalf("could not set snap length: %v", err) + } else if err = inactive.SetPromisc(*promisc); err != nil { + log.Fatalf("could not set promisc mode: %v", err) + } else if err = inactive.SetTimeout(time.Second); err != nil { + log.Fatalf("could not set timeout: %v", err) + } + if *tstype != "" { + if t, err := pcap.TimestampSourceFromString(*tstype); err != nil { + log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) + } else if err := inactive.SetTimestampSource(t); err != nil { + log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) + } + } + if handle, err = inactive.Activate(); err != nil { + log.Fatalf("PCAP Activate error: %v", err) + } + defer handle.Close() + } + if len(flag.Args()) > 0 { + bpffilter := strings.Join(flag.Args(), " ") + rlog.Infof("Using BPF filter %q", bpffilter) + if err = handle.SetBPFFilter(bpffilter); err != nil { + log.Fatalf("BPF filter error: %v", err) + } } - if err := handle.SetBPFFilter(*filter); err != nil { - log.Fatal(err) + // if *dumpToHar { + // harWriter.Start() + // defer harWriter.Stop() + // } + // defer outboundLinkWriter.Stop() + + var dec gopacket.Decoder + var ok bool + decoderName := *decoder + if decoderName == "" { + decoderName = fmt.Sprintf("%s", handle.LinkType()) } + if dec, ok = gopacket.DecodersByLayerName[decoderName]; !ok { + log.Fatalln("No decoder named", decoderName) + } + source := gopacket.NewPacketSource(handle, dec) + source.Lazy = *lazy + source.NoCopy = true + rlog.Info("Starting to read packets") + statsTracker.setStartTime(time.Now()) + defragger := ip4defrag.NewIPv4Defragmenter() var emitter api.Emitter = &api.Emitting{ OutputChannel: outputItems, } - // Set up assembly streamFactory := &tcpStreamFactory{ + doHTTP: !*nohttp, Emitter: emitter, } - 
streamPool := tcpassembly.NewStreamPool(streamFactory)
-	assembler := tcpassembly.NewAssembler(streamPool)
+	streamPool := reassembly.NewStreamPool(streamFactory)
+	assembler := reassembly.NewAssembler(streamPool)
 
-	log.Println("reading in packets")
-	// Read in packets, pass to assembler.
-	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
-	packets := packetSource.Packets()
-	ticker := time.Tick(time.Minute)
-	for {
-		select {
-		case packet := <-packets:
-			// A nil packet indicates the end of a pcap file.
-			if packet == nil {
-				return
-			}
-			if *verbose {
-				log.Println(packet)
-			}
-			if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
-				log.Println("Unusable packet")
+	// Each assembler limit is read from the getter of the same name; the two
+	// getters were previously crossed, which swapped the total and the
+	// per-connection page limits.
+	maxBufferedPagesTotal := GetMaxBufferedPagesTotal()
+	maxBufferedPagesPerConnection := GetMaxBufferedPagesPerConnection()
+	rlog.Infof("Assembler options: maxBufferedPagesTotal=%d, maxBufferedPagesPerConnection=%d", maxBufferedPagesTotal, maxBufferedPagesPerConnection)
+	assembler.AssemblerOptions.MaxBufferedPagesTotal = maxBufferedPagesTotal
+	assembler.AssemblerOptions.MaxBufferedPagesPerConnection = maxBufferedPagesPerConnection
+
+	var assemblerMutex sync.Mutex
+
+	signalChan := make(chan os.Signal, 1)
+	signal.Notify(signalChan, os.Interrupt)
+
+	staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds)
+	cleaner := Cleaner{
+		assemblerMutex:    &assemblerMutex,
+		cleanPeriod:       cleanPeriod,
+		connectionTimeout: staleConnectionTimeout,
+	}
+	cleaner.start()
+
+	go func() {
+		statsPeriod := time.Second * time.Duration(*statsevery)
+		ticker := time.NewTicker(statsPeriod)
+
+		for true {
+			<-ticker.C
+
+			// Since the start
+			errorsMapMutex.Lock()
+			errorMapLen := len(errorsMap)
+			errorsSummery := fmt.Sprintf("%v", errorsMap)
+			errorsMapMutex.Unlock()
+			log.Printf("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v) - Errors Summary: %s",
+				statsTracker.appStats.TotalPacketsCount,
+ statsTracker.appStats.TotalProcessedBytes, + time.Since(statsTracker.appStats.StartTime), + nErrors, + errorMapLen, + errorsSummery, + ) + + // At this moment + memStats := runtime.MemStats{} + runtime.ReadMemStats(&memStats) + log.Printf( + "mem: %d, goroutines: %d, unmatched messages:", + memStats.HeapAlloc, + runtime.NumGoroutine(), + ) + + // Since the last print + cleanStats := cleaner.dumpStats() + matchedMessages := statsTracker.dumpStats() + log.Printf( + "flushed connections %d, closed connections: %d, deleted messages: %d, matched messages: %d", + cleanStats.flushed, + cleanStats.closed, + cleanStats.deleted, + matchedMessages, + ) + } + }() + + if GetMemoryProfilingEnabled() { + startMemoryProfiler() + } + + for packet := range source.Packets() { + packetsCount := statsTracker.incPacketsCount() + rlog.Debugf("PACKET #%d", packetsCount) + data := packet.Data() + statsTracker.updateProcessedSize(int64(len(data))) + if *hexdumppkt { + rlog.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data)) + } + + // defrag the IPv4 packet if required + if !*nodefrag { + ip4Layer := packet.Layer(layers.LayerTypeIPv4) + if ip4Layer == nil { continue } - tcp := packet.TransportLayer().(*layers.TCP) - assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, packet.Metadata().Timestamp) + ip4 := ip4Layer.(*layers.IPv4) + l := ip4.Length + newip4, err := defragger.DefragIPv4(ip4) + if err != nil { + log.Fatalln("Error while de-fragmenting", err) + } else if newip4 == nil { + rlog.Debugf("Fragment...") + continue // packet fragment, we don't have whole packet yet. 
+ } + if newip4.Length != l { + stats.ipdefrag++ + rlog.Debugf("Decoding re-assembled packet: %s", newip4.NextLayerType()) + pb, ok := packet.(gopacket.PacketBuilder) + if !ok { + log.Panic("Not a PacketBuilder") + } + nextDecoder := newip4.NextLayerType() + _ = nextDecoder.Decode(newip4.Payload, pb) + } + } - case <-ticker: - // Every minute, flush connections that haven't seen activity in the past 2 minutes. - assembler.FlushOlderThan(time.Now().Add(time.Minute * -2)) + tcp := packet.Layer(layers.LayerTypeTCP) + if tcp != nil { + tcp := tcp.(*layers.TCP) + if *checksum { + err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer()) + if err != nil { + log.Fatalf("Failed to set network layer for checksum: %s\n", err) + } + } + c := Context{ + CaptureInfo: packet.Metadata().CaptureInfo, + } + stats.totalsz += len(tcp.Payload) + rlog.Debugf("%s : %v -> %s : %v", packet.NetworkLayer().NetworkFlow().Src(), tcp.SrcPort, packet.NetworkLayer().NetworkFlow().Dst(), tcp.DstPort) + assemblerMutex.Lock() + assembler.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c) + assemblerMutex.Unlock() + } + + done := *maxcount > 0 && statsTracker.appStats.TotalPacketsCount >= *maxcount + if done { + errorsMapMutex.Lock() + errorMapLen := len(errorsMap) + errorsMapMutex.Unlock() + log.Printf("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)", + statsTracker.appStats.TotalPacketsCount, + statsTracker.appStats.TotalProcessedBytes, + time.Since(statsTracker.appStats.StartTime), + nErrors, + errorMapLen) + } + select { + case <-signalChan: + log.Printf("Caught SIGINT: aborting") + done = true + default: + // NOP: continue + } + if done { + break } } + + assemblerMutex.Lock() + closed := assembler.FlushAll() + assemblerMutex.Unlock() + rlog.Debugf("Final flush: %d closed", closed) + if outputLevel >= 2 { + streamPool.Dump() + } + + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal(err) + } + _ = pprof.WriteHeapProfile(f) + 
_ = f.Close() + } + + streamFactory.WaitGoRoutines() + assemblerMutex.Lock() + rlog.Debugf("%s", assembler.Dump()) + assemblerMutex.Unlock() + if !*nodefrag { + log.Printf("IPdefrag:\t\t%d", stats.ipdefrag) + } + log.Printf("TCP stats:") + log.Printf(" missed bytes:\t\t%d", stats.missedBytes) + log.Printf(" total packets:\t\t%d", stats.pkt) + log.Printf(" rejected FSM:\t\t%d", stats.rejectFsm) + log.Printf(" rejected Options:\t%d", stats.rejectOpt) + log.Printf(" reassembled bytes:\t%d", stats.sz) + log.Printf(" total TCP bytes:\t%d", stats.totalsz) + log.Printf(" conn rejected FSM:\t%d", stats.rejectConnFsm) + log.Printf(" reassembled chunks:\t%d", stats.reassembled) + log.Printf(" out-of-order packets:\t%d", stats.outOfOrderPackets) + log.Printf(" out-of-order bytes:\t%d", stats.outOfOrderBytes) + log.Printf(" biggest-chunk packets:\t%d", stats.biggestChunkPackets) + log.Printf(" biggest-chunk bytes:\t%d", stats.biggestChunkBytes) + log.Printf(" overlap packets:\t%d", stats.overlapPackets) + log.Printf(" overlap bytes:\t\t%d", stats.overlapBytes) + log.Printf("Errors: %d", nErrors) + for e := range errorsMap { + log.Printf(" %s:\t\t%d", e, errorsMap[e]) + } + log.Printf("AppStats: %v", GetStats()) } diff --git a/tap/tcp_reader.go b/tap/tcp_reader.go new file mode 100644 index 000000000..bde20aaa7 --- /dev/null +++ b/tap/tcp_reader.go @@ -0,0 +1,126 @@ +package tap + +import ( + "bufio" + "fmt" + "io" + "log" + "strconv" + "sync" + "time" + + "github.com/bradleyfalzon/tlsx" + "github.com/up9inc/mizu/tap/api" +) + +const checkTLSPacketAmount = 100 + +type httpReaderDataMsg struct { + bytes []byte + timestamp time.Time +} + +type tcpID struct { + srcIP string + dstIP string + srcPort string + dstPort string +} + +type ConnectionInfo struct { + ClientIP string + ClientPort string + ServerIP string + ServerPort string + IsOutgoing bool +} + +func (tid *tcpID) String() string { + return fmt.Sprintf("%s->%s %s->%s", tid.srcIP, tid.dstIP, tid.srcPort, tid.dstPort) +} + 
+/* tcpReader reads reassembled tcp payload from a channel and exposes it as a byte stream for a protocol dissector.
+ * The payload is written to the channel by a tcpStream object that is dedicated to one tcp connection.
+ * A tcpReader object is unidirectional: it carries either a client stream or a server stream.
+ * Implements io.Reader interface (Read)
+ */
+type tcpReader struct {
+	ident              string
+	tcpID              *api.TcpID
+	isClient           bool
+	isHTTP2            bool
+	isOutgoing         bool
+	msgQueue           chan httpReaderDataMsg // Channel of captured reassembled tcp payload
+	data               []byte
+	captureTime        time.Time
+	hexdump            bool
+	parent             *tcpStream
+	messageCount       uint
+	packetsSeen        uint
+	outboundLinkWriter *OutboundLinkWriter
+	Emitter            api.Emitter
+}
+
+// Read implements io.Reader. It hands out the bytes of the current chunk and,
+// once that chunk is used up, blocks on msgQueue for the next one; empty chunks
+// are skipped. It returns io.EOF once msgQueue has been closed.
+func (h *tcpReader) Read(p []byte) (int, error) {
+	var msg httpReaderDataMsg
+
+	ok := true
+	for ok && len(h.data) == 0 {
+		msg, ok = <-h.msgQueue
+		h.data = msg.bytes
+
+		h.captureTime = msg.timestamp
+		if len(h.data) > 0 {
+			h.packetsSeen++
+		}
+		// Only the first chunks of a stream are probed for a TLS handshake.
+		if h.packetsSeen < checkTLSPacketAmount && len(msg.bytes) > 5 { // packets with less than 5 bytes cause tlsx to panic
+			h.reportTLSClientHello(msg.bytes)
+		}
+	}
+	if !ok || len(h.data) == 0 {
+		return 0, io.EOF
+	}
+
+	l := copy(p, h.data)
+	h.data = h.data[l:]
+	return l, nil
+}
+
+// reportTLSClientHello tries to parse payload as a TLS ClientHello and, when it
+// parses, reports the connection and its SNI to the outbound link writer.
+func (h *tcpReader) reportTLSClientHello(payload []byte) {
+	clientHello := tlsx.ClientHello{}
+	if err := clientHello.Unmarshall(payload); err != nil {
+		return // not a ClientHello
+	}
+	fmt.Printf("Detected TLS client hello with SNI %s\n", clientHello.SNI)
+	if h.outboundLinkWriter == nil {
+		// No writer was configured for this reader, so there is nowhere to report to.
+		return
+	}
+	numericPort, err := strconv.Atoi(h.tcpID.DstPort)
+	if err != nil {
+		log.Printf("Skipping outbound TLS link: invalid destination port %q: %v", h.tcpID.DstPort, err)
+		return
+	}
+	h.outboundLinkWriter.WriteOutboundLink(h.tcpID.SrcIP, h.tcpID.DstIP, numericPort, clientHello.SNI, TLSProtocol)
+}
+
+// containsPort reports whether port is one of ports.
+func containsPort(ports []string, port string) bool {
+	for _, x := range ports {
+		if x == port {
+			return true
+		}
+	}
+	return false
+}
+
+// run feeds the reader's byte stream to the dissector and marks wg done once
+// the stream has been consumed to EOF.
+func (h *tcpReader) run(wg *sync.WaitGroup) {
+	defer wg.Done()
+	log.Printf("Called run h.isClient: %v\n", h.isClient)
+	b := bufio.NewReader(h)
+	// NOTE(review): the dissector is picked by a fixed index; the commented-out
+	// loop below looks like the intended port-based selection - confirm.
+	extensions[1].Dissector.Dissect(b, h.isClient, h.tcpID, h.Emitter)
+	// If the dissector returned before EOF, keep draining so that whoever writes
+	// to msgQueue is not left blocked on a queue nobody reads any more.
+	// Read only ever fails with io.EOF, which io.Copy reports as a nil error.
+	_, _ = io.Copy(io.Discard, b)
+	// for _, extension := range extensions {
+	// 	var subjectPorts []string
+	// 	if h.isClient {
+	// 		subjectPorts = extension.OutboundPorts
+	// 	} else {
+	// 		subjectPorts = extension.InboundPorts
+	// 	}
+	// 	if containsPort(subjectPorts, "80") {
+	// 		extension.Dissector.Ping()
+	// 		fmt.Printf("h.isClient: %v\n", h.isClient)
+	// 		extension.Dissector.Dissect(b, h.isClient, h.tcpID, h.Emitter)
+	// 	}
+	// }
+}
diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go
index 9453423f5..3cd0972a4 100644
--- a/tap/tcp_stream.go
+++ b/tap/tcp_stream.go
@@ -1,18 +1,168 @@
 package tap
 
 import (
-	"time"
+	"encoding/binary"
+	"encoding/hex"
+	"fmt"
+	"sync"
 
-	"github.com/google/gopacket" // pulls in all layers decoders
-	"github.com/google/gopacket/tcpassembly/tcpreader"
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers" // pulls in all layers decoders
+	"github.com/google/gopacket/reassembly"
 )
 
-type tcpReaderDataMsg struct {
-	bytes     []byte
-	timestamp time.Time
+/* It's a connection (bidirectional)
+ * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete)
+ * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete)
+ * In our implementation, we pass information from ReassembledSG to a tcpReader through a shared channel.
+ */
+type tcpStream struct {
+	tcpstate       *reassembly.TCPSimpleFSM  // TCP state machine consulted by Accept
+	fsmerr         bool                      // set on the first FSM rejection so rejectConnFsm counts each stream once
+	optchecker     reassembly.TCPOptionCheck // TCP option validator consulted by Accept
+	net, transport gopacket.Flow
+	isDNS          bool // ReassembledSG decodes the payload as DNS
+	isHTTP         bool // ReassembledSG forwards the payload to the client/server readers
+	reversed       bool // when set, ReassembledSG sends every chunk to the server reader
+	client         tcpReader
+	server         tcpReader
+	urls           []string // not referenced by the methods below
+	ident          string   // connection label used in log messages
+	sync.Mutex
 }
 
-type tcpStream struct {
-	net, transport gopacket.Flow
-	r              tcpreader.ReaderStream
+// Accept decides whether a packet is accepted for this stream.
+// It runs the TCP state machine, the TCP option check and, when the checksum
+// flag is set, checksum verification, counting every rejection in stats.
+// The ignorefsmerr and nooptcheck flags turn the first two checks into
+// count-only checks.
+func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
+	// FSM
+	if !t.tcpstate.CheckState(tcp, dir) {
+		SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpstate.String())
+		stats.rejectFsm++
+		if !t.fsmerr {
+			t.fsmerr = true
+			stats.rejectConnFsm++
+		}
+		if !*ignorefsmerr {
+			return false
+		}
+	}
+	// Options
+	err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start)
+	if err != nil {
+		SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err)
+		stats.rejectOpt++
+		if !*nooptcheck {
+			return false
+		}
+	}
+	// Checksum
+	accept := true
+	if *checksum {
+		c, err := tcp.ComputeChecksum()
+		if err != nil {
+			SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err)
+			accept = false
+		} else if c != 0x0 {
+			SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c)
+			accept = false
+		}
+	}
+	if !accept {
+		// NOTE(review): checksum failures are counted under rejectOpt - confirm that is intended.
+		stats.rejectOpt++
+	}
+	return accept
+}
+
+// ReassembledSG receives a chunk of reassembled payload for one direction.
+// It updates the global stats, gives up on chunks with missing bytes, and then
+// either decodes the data as length-prefixed DNS or forwards it to a reader queue.
+func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
+	dir, start, end, skip := sg.Info()
+	length, saved := sg.Lengths()
+	// update stats
+	sgStats := sg.Stats()
+	if skip > 0 {
+		stats.missedBytes += skip
+	}
+	stats.sz += length - saved
+	stats.pkt += sgStats.Packets
+	if sgStats.Chunks > 1 {
+		stats.reassembled++
+	}
+	stats.outOfOrderPackets += sgStats.QueuedPackets
+	stats.outOfOrderBytes += sgStats.QueuedBytes
+	if length > stats.biggestChunkBytes {
+		stats.biggestChunkBytes = length
+	}
+	if sgStats.Packets > stats.biggestChunkPackets {
+		stats.biggestChunkPackets = sgStats.Packets
+	}
+	if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 {
+		// In the original example this was handled with panic().
+		// I don't know what this error means or how to handle it properly.
+		SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets)
+	}
+	stats.overlapBytes += sgStats.OverlapBytes
+	stats.overlapPackets += sgStats.OverlapPackets
+
+	// The label is only used by the Trace call below.
+	var ident string
+	if dir == reassembly.TCPDirClientToServer {
+		ident = fmt.Sprintf("%v %v(%s): ", t.net, t.transport, dir)
+	} else {
+		ident = fmt.Sprintf("%v %v(%s): ", t.net.Reverse(), t.transport.Reverse(), dir)
+	}
+	Trace("%s: SG reassembled packet with %d bytes (start:%v,end:%v,skip:%d,saved:%d,nb:%d,%d,overlap:%d,%d)", ident, length, start, end, skip, saved, sgStats.Packets, sgStats.Chunks, sgStats.OverlapBytes, sgStats.OverlapPackets)
+	if skip == -1 && *allowmissinginit {
+		// this is allowed
+	} else if skip != 0 {
+		// Missing bytes in stream: do not even try to parse it
+		return
+	}
+	data := sg.Fetch(length)
+	if t.isDNS {
+		dns := &layers.DNS{}
+		var decoded []gopacket.LayerType
+		// The first two bytes hold the message length; wait until they are available.
+		if len(data) < 2 {
+			if len(data) > 0 {
+				sg.KeepFrom(0)
+			}
+			return
+		}
+		dnsSize := binary.BigEndian.Uint16(data[:2])
+		missing := int(dnsSize) - len(data[2:])
+		Trace("dnsSize: %d, missing: %d", dnsSize, missing)
+		if missing > 0 {
+			// Keep everything so the message can be retried once more bytes arrive.
+			Debug("Missing some bytes: %d", missing)
+			sg.KeepFrom(0)
+			return
+		}
+		p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns)
+		err := p.DecodeLayers(data[2:], &decoded)
+		if err != nil {
+			SilentError("DNS-parser", "Failed to decode DNS: %v", err)
+		} else {
+			Trace("DNS: %s", gopacket.LayerDump(dns))
+		}
+		// Bytes beyond this message are kept for the next call.
+		if len(data) > 2+int(dnsSize) {
+			sg.KeepFrom(2 + int(dnsSize))
+		}
+	} else if t.isHTTP {
+		if length > 0 {
+			if *hexdump {
+				Trace("Feeding http with:%s", hex.Dump(data))
+			}
+			// This is where we pass the reassembled information onwards
+			// This channel is read by a tcpReader object
+			if dir == reassembly.TCPDirClientToServer && !t.reversed {
+				t.client.msgQueue <- httpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp}
+			} else {
+				t.server.msgQueue <- httpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp}
+			}
+		}
+	}
+}
+
+// ReassemblyComplete closes both reader queues of an HTTP stream, which makes
+// the readers return io.EOF. It always returns false.
+func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
+	Trace("%s: Connection closed", t.ident)
+	if t.isHTTP {
+		close(t.client.msgQueue)
+		close(t.server.msgQueue)
+	}
+	// do not remove the connection to allow last ACK
+	return false
 }
diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go
index e214c787f..9c5dd50b9 100644
--- a/tap/tcp_stream_factory.go
+++ b/tap/tcp_stream_factory.go
@@ -1,80 +1,97 @@
 package tap
 
 import (
-	"bufio"
 	"fmt"
 	"log"
-	"strconv"
+	"sync"
 
 	"github.com/romana/rlog"
 	"github.com/up9inc/mizu/tap/api"
 
-	"github.com/google/gopacket" // pulls in all layers decoders
-	"github.com/google/gopacket/tcpassembly"
-	"github.com/google/gopacket/tcpassembly/tcpreader"
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers" // pulls in all layers decoders
+	"github.com/google/gopacket/reassembly"
 )
 
+/*
+ * The TCP factory: returns a new Stream
+ * Implements gopacket.reassembly.StreamFactory interface (New)
+ * Generates a new tcp stream for each new tcp connection. Closes the stream when the connection closes.
+ */
 type tcpStreamFactory struct {
+	wg                 sync.WaitGroup // tracks the reader goroutines started by New
+	doHTTP             bool           // when false, New never marks a stream as isHTTP
 	outbountLinkWriter *OutboundLinkWriter
 	Emitter            api.Emitter
 }
 
-func containsPort(ports []string, port string) bool {
-	for _, x := range ports {
-		if x == port {
-			return true
-		}
+// New builds the tcpStream for a connection. When the stream is a tap target
+// and doHTTP is set, it also creates one tcpReader per direction and starts a
+// goroutine for each of them.
+func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
+	log.Printf("* NEW: %s %s", net, transport)
+	fsmOptions := reassembly.TCPSimpleFSMOptions{
+		SupportMissingEstablishment: *allowmissinginit,
 	}
-	return false
-}
+	rlog.Debugf("Current App Ports: %v", gSettings.filterPorts)
+	srcIp := net.Src().String()
+	dstIp := net.Dst().String()
+	dstPort := int(tcp.DstPort)
 
-func (h *tcpStream) clientRun(tcpID *api.TcpID, emitter api.Emitter) {
-	b := bufio.NewReader(&h.r)
-	for _, extension := range extensions {
-		if containsPort(extension.OutboundPorts, h.transport.Dst().String()) {
-			extension.Dissector.Ping()
-			extension.Dissector.Dissect(b, true, tcpID, emitter)
-		}
-	}
-}
-
-func (h *tcpStream) serverRun(tcpID *api.TcpID, emitter api.Emitter) {
-	b := bufio.NewReader(&h.r)
-	for _, extension := range extensions {
-		if containsPort(extension.OutboundPorts, h.transport.Src().String()) {
-			extension.Dissector.Ping()
-			extension.Dissector.Dissect(b, false, tcpID, emitter)
-		}
-	}
-}
-
-func (h *tcpStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
-	log.Printf("* NEW: %s %s\n", net, transport)
+	// if factory.shouldNotifyOnOutboundLink(dstIp, dstPort) {
+	// 	factory.outbountLinkWriter.WriteOutboundLink(net.Src().String(), dstIp, dstPort, "", "")
+	// }
+	props := factory.getStreamProps(srcIp, dstIp, dstPort)
+	isHTTP := props.isTapTarget
 	stream := &tcpStream{
-		net:       net,
-		transport: transport,
-		r:         tcpreader.NewReaderStream(),
+		net:       net,
+		transport: transport,
+		isDNS:     tcp.SrcPort == 53 || tcp.DstPort == 53,
+		isHTTP:    isHTTP && factory.doHTTP,
+		// NOTE(review): port 80 is hard-coded here while the tapped ports come from gSettings.filterPorts - confirm.
+		reversed:   tcp.SrcPort == 80,
+		tcpstate:   reassembly.NewTCPSimpleFSM(fsmOptions),
+		ident:      fmt.Sprintf("%s:%s", net, transport),
+		optchecker: reassembly.NewTCPOptionCheck(),
 	}
-	tcpID := &api.TcpID{
-		SrcIP:   net.Src().String(),
-		DstIP:   net.Dst().String(),
-		SrcPort: transport.Src().String(),
-		DstPort: transport.Dst().String(),
-		Ident:   fmt.Sprintf("%s:%s", net, transport),
-	}
-	dstPort, _ := strconv.Atoi(transport.Dst().String())
-	streamProps := h.getStreamProps(net.Src().String(), net.Dst().String(), dstPort)
-	if streamProps.isTapTarget {
-		if containsPort(allOutboundPorts, transport.Dst().String()) {
-			go stream.clientRun(tcpID, h.Emitter)
-		} else if containsPort(allOutboundPorts, transport.Src().String()) {
-			go stream.serverRun(tcpID, h.Emitter)
+	if stream.isHTTP {
+		// The client reader is identified by the flow as it was first seen.
+		stream.client = tcpReader{
+			msgQueue: make(chan httpReaderDataMsg),
+			ident:    fmt.Sprintf("%s %s", net, transport),
+			tcpID: &api.TcpID{
+				SrcIP:   net.Src().String(),
+				DstIP:   net.Dst().String(),
+				SrcPort: transport.Src().String(),
+				DstPort: transport.Dst().String(),
+			},
+			hexdump:            *hexdump,
+			parent:             stream,
+			isClient:           true,
+			isOutgoing:         props.isOutgoing,
+			outboundLinkWriter: factory.outbountLinkWriter,
+			Emitter:            factory.Emitter,
 		}
+		// The server reader is identified by the reversed flow, so source and destination are swapped.
+		stream.server = tcpReader{
+			msgQueue: make(chan httpReaderDataMsg),
+			ident:    fmt.Sprintf("%s %s", net.Reverse(), transport.Reverse()),
+			tcpID: &api.TcpID{
+				SrcIP:   net.Dst().String(),
+				DstIP:   net.Src().String(),
+				SrcPort: transport.Dst().String(),
+				DstPort: transport.Src().String(),
+			},
+			hexdump:            *hexdump,
+			parent:             stream,
+			isOutgoing:         props.isOutgoing,
+			outboundLinkWriter: factory.outbountLinkWriter,
+			Emitter:            factory.Emitter,
+		}
+		// Register both readers before their goroutines start.
+		factory.wg.Add(2)
+		// Start reading from channels stream.client.msgQueue and stream.server.msgQueue
+		go stream.client.run(&factory.wg)
+		go stream.server.run(&factory.wg)
 	}
-	//if h.shouldNotifyOnOutboundLink(net.Dst().String(), dstPort) {
-	//	h.outbountLinkWriter.WriteOutboundLink(net.Src().String(), net.Dst().String(), dstPort, "", "")
-	//}
-	return &stream.r
+	return stream
+}
+
+// WaitGoRoutines blocks until every reader goroutine registered by New has finished.
+func (factory *tcpStreamFactory) WaitGoRoutines() {
+	factory.wg.Wait()
 }
 
 func (factory *tcpStreamFactory) getStreamProps(srcIP string, dstIP string, dstPort int) *streamProps {
@@ -91,6 +108,12 @@ func (factory *tcpStreamFactory) getStreamProps(srcIP string, dstIP string, dstP
 		}
 		return &streamProps{isTapTarget: false}
 	} else {
+		isTappedPort := dstPort == 80 || (gSettings.filterPorts != nil && (inArrayInt(gSettings.filterPorts, dstPort)))
+		if !isTappedPort {
+			rlog.Debugf("getStreamProps %s", fmt.Sprintf("- notHost1 %d", dstPort))
+			return &streamProps{isTapTarget: false, isOutgoing: false}
+		}
+
 		isOutgoing := !inArrayString(ownIps, dstIP)
 		if !*anydirection && isOutgoing {