diff --git a/tap/cleaner.go b/tap/cleaner.go index 738b8b239..02a147eda 100644 --- a/tap/cleaner.go +++ b/tap/cleaner.go @@ -3,10 +3,6 @@ package tap import ( "sync" "time" - - "github.com/romana/rlog" - - "github.com/google/gopacket/reassembly" ) type CleanerStats struct { @@ -16,7 +12,6 @@ type CleanerStats struct { } type Cleaner struct { - assembler *reassembly.Assembler assemblerMutex *sync.Mutex cleanPeriod time.Duration connectionTimeout time.Duration @@ -25,18 +20,18 @@ type Cleaner struct { } func (cl *Cleaner) clean() { - startCleanTime := time.Now() + // startCleanTime := time.Now() - cl.assemblerMutex.Lock() - rlog.Debugf("Assembler Stats before cleaning %s", cl.assembler.Dump()) - flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) - cl.assemblerMutex.Unlock() + // cl.assemblerMutex.Lock() + // rlog.Debugf("Assembler Stats before cleaning %s", cl.assembler.Dump()) + // flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) + // cl.assemblerMutex.Unlock() - cl.statsMutex.Lock() - rlog.Debugf("Assembler Stats after cleaning %s", cl.assembler.Dump()) - cl.stats.flushed += flushed - cl.stats.closed += closed - cl.statsMutex.Unlock() + // cl.statsMutex.Lock() + // rlog.Debugf("Assembler Stats after cleaning %s", cl.assembler.Dump()) + // cl.stats.flushed += flushed + // cl.stats.closed += closed + // cl.statsMutex.Unlock() } func (cl *Cleaner) start() { diff --git a/tap/extensions/http/main.go b/tap/extensions/http/main.go index 198fe51e0..efea9f96b 100644 --- a/tap/extensions/http/main.go +++ b/tap/extensions/http/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "io" "log" "net/http" @@ -12,6 +13,8 @@ func init() { log.Println("Initializing HTTP extension.") } +var discardBuffer = make([]byte, 4096) + type dissecting string func (g dissecting) Register(extension *api.Extension) { @@ -23,11 +26,40 @@ func (g dissecting) Ping() { log.Printf("pong HTTP\n") } +func DiscardBytesToFirstError(r io.Reader) (discarded int, err error) { + for { + n, e := r.Read(discardBuffer) + discarded += n + if e != nil { + return discarded, e + } + } +} + +func DiscardBytesToEOF(r io.Reader) (discarded int) { + for { + n, e := DiscardBytesToFirstError(r) + discarded += n + if e == io.EOF { + return + } + } +} + func (g dissecting) Dissect(b *bufio.Reader) interface{} { - log.Printf("called Dissect!") - req, _ := http.ReadRequest(b) - log.Printf("HTTP Request: %+v\n", req) - return nil + for { + req, err := http.ReadRequest(b) + if err == io.EOF { + // We must read until we see an EOF... very important! + return nil + } else if err != nil { + log.Println("Error reading stream:", err) + } else { + bodyBytes := DiscardBytesToEOF(req.Body) + req.Body.Close() + log.Println("Received request from stream:", req, "with", bodyBytes, "bytes in request body") + } + } } // exported as symbol named "Greeter" diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index a66f8327b..5ef4ea361 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -9,13 +9,11 @@ package tap import ( - "encoding/hex" "flag" "fmt" "io/ioutil" "log" "os" - "os/signal" "path" "path/filepath" "plugin" @@ -30,10 +28,9 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/examples/util" - "github.com/google/gopacket/ip4defrag" "github.com/google/gopacket/layers" // pulls in all layers decoders "github.com/google/gopacket/pcap" - "github.com/google/gopacket/reassembly" + "github.com/google/gopacket/tcpassembly" "github.com/up9inc/mizu/tap/api" ) @@ -94,6 +91,8 @@ var dumpToHar = flag.Bool("hardump", false, "Dump traffic to har files") var HarOutputDir = flag.String("hardir", "", "Directory in which to store output har files") var harEntriesPerFile = flag.Int("harentriesperfile", 200, "Number of max number of har entries to store in each file") +var filter = flag.String("f", "tcp and dst port 80", "BPF filter for pcap") + var statsTracker = StatsTracker{} // global @@ -275,309 +274,55 @@ func startPassiveTapper(outboundLinkWriter *OutboundLinkWriter) { log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile) defer util.Run()() - if *debug { - outputLevel = 2 - } else if *verbose { - outputLevel = 1 - } else if *quiet { - outputLevel = -1 - } - errorsMap = make(map[string]uint) - - if localhostIPs, err := getLocalhostIPs(); err != nil { - // TODO: think this over - rlog.Info("Failed to get self IP addresses") - rlog.Errorf("Getting-Self-Address", "Error getting self ip address: %s (%v,%+v)", err, err, err) - ownIps = make([]string, 0) - } else { - ownIps = localhostIPs - } - - appPortsStr := os.Getenv(AppPortsEnvVar) - var appPorts []int - if appPortsStr == "" { - rlog.Info("Received empty/no APP_PORTS env var! only listening to http on port 80!") - appPorts = make([]int, 0) - } else { - appPorts = parseAppPorts(appPortsStr) - } - SetFilterPorts(appPorts) - // envVal := os.Getenv(maxHTTP2DataLenEnvVar) - // if envVal == "" { - // rlog.Infof("Received empty/no HTTP2_DATA_SIZE_LIMIT env var! falling back to %v", maxHTTP2DataLenDefault) - // maxHTTP2DataLen = maxHTTP2DataLenDefault - // } else { - // if convertedInt, err := strconv.Atoi(envVal); err != nil { - // rlog.Infof("Received invalid HTTP2_DATA_SIZE_LIMIT env var! falling back to %v", maxHTTP2DataLenDefault) - // maxHTTP2DataLen = maxHTTP2DataLenDefault - // } else { - // rlog.Infof("Received HTTP2_DATA_SIZE_LIMIT env var: %v", maxHTTP2DataLenDefault) - // maxHTTP2DataLen = convertedInt - // } - // } - - log.Printf("App Ports: %v", gSettings.filterPorts) - var handle *pcap.Handle var err error + + // Set up pcap packet capture if *fname != "" { - if handle, err = pcap.OpenOffline(*fname); err != nil { - log.Fatalf("PCAP OpenOffline error: %v", err) - } + log.Printf("Reading from pcap dump %q", *fname) + handle, err = pcap.OpenOffline(*fname) } else { - // This is a little complicated because we want to allow all possible options - // for creating the packet capture handle... instead of all this you can - // just call pcap.OpenLive if you want a simple handle. - inactive, err := pcap.NewInactiveHandle(*iface) - if err != nil { - log.Fatalf("could not create: %v", err) - } - defer inactive.CleanUp() - if err = inactive.SetSnapLen(*snaplen); err != nil { - log.Fatalf("could not set snap length: %v", err) - } else if err = inactive.SetPromisc(*promisc); err != nil { - log.Fatalf("could not set promisc mode: %v", err) - } else if err = inactive.SetTimeout(time.Second); err != nil { - log.Fatalf("could not set timeout: %v", err) - } - if *tstype != "" { - if t, err := pcap.TimestampSourceFromString(*tstype); err != nil { - log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) - } else if err := inactive.SetTimestampSource(t); err != nil { - log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) + log.Printf("Starting capture on interface %q", *iface) + handle, err = pcap.OpenLive(*iface, int32(*snaplen), true, pcap.BlockForever) + } + if err != nil { + log.Fatal(err) + } + + if err := handle.SetBPFFilter(*filter); err != nil { + log.Fatal(err) + } + + // Set up assembly + streamFactory := &tcpStreamFactory{} + streamPool := tcpassembly.NewStreamPool(streamFactory) + assembler := tcpassembly.NewAssembler(streamPool) + + log.Println("reading in packets") + // Read in packets, pass to assembler. + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packets := packetSource.Packets() + ticker := time.Tick(time.Minute) + for { + select { + case packet := <-packets: + // A nil packet indicates the end of a pcap file. + if packet == nil { + return } - } - if handle, err = inactive.Activate(); err != nil { - log.Fatalf("PCAP Activate error: %v", err) - } - defer handle.Close() - } - if len(flag.Args()) > 0 { - bpffilter := strings.Join(flag.Args(), " ") - rlog.Infof("Using BPF filter %q", bpffilter) - if err = handle.SetBPFFilter(bpffilter); err != nil { - log.Fatalf("BPF filter error: %v", err) - } - } - - // if *dumpToHar { - // harWriter.Start() - // defer harWriter.Stop() - // } - defer outboundLinkWriter.Stop() - - var dec gopacket.Decoder - var ok bool - decoderName := *decoder - if decoderName == "" { - decoderName = fmt.Sprintf("%s", handle.LinkType()) - } - if dec, ok = gopacket.DecodersByLayerName[decoderName]; !ok { - log.Fatalln("No decoder named", decoderName) - } - source := gopacket.NewPacketSource(handle, dec) - source.Lazy = *lazy - source.NoCopy = true - rlog.Info("Starting to read packets") - statsTracker.setStartTime(time.Now()) - defragger := ip4defrag.NewIPv4Defragmenter() - - streamFactory := &tcpStreamFactory{ - doHTTP: !*nohttp, - // harWriter: harWriter, - outbountLinkWriter: outboundLinkWriter, - } - streamPool := reassembly.NewStreamPool(streamFactory) - assembler := reassembly.NewAssembler(streamPool) - - maxBufferedPagesTotal := GetMaxBufferedPagesPerConnection() - maxBufferedPagesPerConnection := GetMaxBufferedPagesTotal() - rlog.Infof("Assembler options: maxBufferedPagesTotal=%d, maxBufferedPagesPerConnection=%d", maxBufferedPagesTotal, maxBufferedPagesPerConnection) - assembler.AssemblerOptions.MaxBufferedPagesTotal = maxBufferedPagesTotal - assembler.AssemblerOptions.MaxBufferedPagesPerConnection = maxBufferedPagesPerConnection - - var assemblerMutex sync.Mutex - - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt) - - staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds) - cleaner := Cleaner{ - assembler: assembler, - assemblerMutex: &assemblerMutex, - cleanPeriod: cleanPeriod, - connectionTimeout: staleConnectionTimeout, - } - cleaner.start() - - go func() { - statsPeriod := time.Second * time.Duration(*statsevery) - ticker := time.NewTicker(statsPeriod) - - for true { - <-ticker.C - - // Since the start - errorsMapMutex.Lock() - errorMapLen := len(errorsMap) - errorsSummery := fmt.Sprintf("%v", errorsMap) - errorsMapMutex.Unlock() - log.Printf("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v) - Errors Summary: %s", - statsTracker.appStats.TotalPacketsCount, - statsTracker.appStats.TotalProcessedBytes, - time.Since(statsTracker.appStats.StartTime), - nErrors, - errorMapLen, - errorsSummery, - ) - - // At this moment - memStats := runtime.MemStats{} - runtime.ReadMemStats(&memStats) - log.Printf( - "mem: %d, goroutines: %d", - memStats.HeapAlloc, - runtime.NumGoroutine(), - ) - - // Since the last print - cleanStats := cleaner.dumpStats() - matchedMessages := statsTracker.dumpStats() - log.Printf( - "flushed connections %d, closed connections: %d, deleted messages: %d, matched messages: %d", - cleanStats.flushed, - cleanStats.closed, - cleanStats.deleted, - matchedMessages, - ) - } - }() - - if GetMemoryProfilingEnabled() { - startMemoryProfiler() - } - - for packet := range source.Packets() { - packetsCount := statsTracker.incPacketsCount() - rlog.Debugf("PACKET #%d", packetsCount) - data := packet.Data() - statsTracker.updateProcessedSize(int64(len(data))) - if *hexdumppkt { - rlog.Debugf("Packet content (%d/0x%x) - %s", len(data), len(data), hex.Dump(data)) - } - - // defrag the IPv4 packet if required - if !*nodefrag { - ip4Layer := packet.Layer(layers.LayerTypeIPv4) - if ip4Layer == nil { + if *verbose { + log.Println(packet) + } + if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP { + log.Println("Unusable packet") continue } - ip4 := ip4Layer.(*layers.IPv4) - l := ip4.Length - newip4, err := defragger.DefragIPv4(ip4) - if err != nil { - log.Fatalln("Error while de-fragmenting", err) - } else if newip4 == nil { - rlog.Debugf("Fragment...") - continue // packet fragment, we don't have whole packet yet. - } - if newip4.Length != l { - stats.ipdefrag++ - rlog.Debugf("Decoding re-assembled packet: %s", newip4.NextLayerType()) - pb, ok := packet.(gopacket.PacketBuilder) - if !ok { - log.Panic("Not a PacketBuilder") - } - nextDecoder := newip4.NextLayerType() - _ = nextDecoder.Decode(newip4.Payload, pb) - } - } + tcp := packet.TransportLayer().(*layers.TCP) + assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, packet.Metadata().Timestamp) - tcp := packet.Layer(layers.LayerTypeTCP) - if tcp != nil { - tcp := tcp.(*layers.TCP) - if *checksum { - err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer()) - if err != nil { - log.Fatalf("Failed to set network layer for checksum: %s\n", err) - } - } - c := Context{ - CaptureInfo: packet.Metadata().CaptureInfo, - } - stats.totalsz += len(tcp.Payload) - rlog.Debugf("%s : %v -> %s : %v", packet.NetworkLayer().NetworkFlow().Src(), tcp.SrcPort, packet.NetworkLayer().NetworkFlow().Dst(), tcp.DstPort) - assemblerMutex.Lock() - assembler.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c) - assemblerMutex.Unlock() - } - - done := *maxcount > 0 && statsTracker.appStats.TotalPacketsCount >= *maxcount - if done { - errorsMapMutex.Lock() - errorMapLen := len(errorsMap) - errorsMapMutex.Unlock() - log.Printf("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)", - statsTracker.appStats.TotalPacketsCount, - statsTracker.appStats.TotalProcessedBytes, - time.Since(statsTracker.appStats.StartTime), - nErrors, - errorMapLen) - } - select { - case <-signalChan: - log.Printf("Caught SIGINT: aborting") - done = true - default: - // NOP: continue - } - if done { - break + case <-ticker: + // Every minute, flush connections that haven't seen activity in the past 2 minutes. + assembler.FlushOlderThan(time.Now().Add(time.Minute * -2)) } } - - assemblerMutex.Lock() - closed := assembler.FlushAll() - assemblerMutex.Unlock() - rlog.Debugf("Final flush: %d closed", closed) - if outputLevel >= 2 { - streamPool.Dump() - } - - if *memprofile != "" { - f, err := os.Create(*memprofile) - if err != nil { - log.Fatal(err) - } - _ = pprof.WriteHeapProfile(f) - _ = f.Close() - } - - streamFactory.WaitGoRoutines() - assemblerMutex.Lock() - rlog.Debugf("%s", assembler.Dump()) - assemblerMutex.Unlock() - if !*nodefrag { - log.Printf("IPdefrag:\t\t%d", stats.ipdefrag) - } - log.Printf("TCP stats:") - log.Printf(" missed bytes:\t\t%d", stats.missedBytes) - log.Printf(" total packets:\t\t%d", stats.pkt) - log.Printf(" rejected FSM:\t\t%d", stats.rejectFsm) - log.Printf(" rejected Options:\t%d", stats.rejectOpt) - log.Printf(" reassembled bytes:\t%d", stats.sz) - log.Printf(" total TCP bytes:\t%d", stats.totalsz) - log.Printf(" conn rejected FSM:\t%d", stats.rejectConnFsm) - log.Printf(" reassembled chunks:\t%d", stats.reassembled) - log.Printf(" out-of-order packets:\t%d", stats.outOfOrderPackets) - log.Printf(" out-of-order bytes:\t%d", stats.outOfOrderBytes) - log.Printf(" biggest-chunk packets:\t%d", stats.biggestChunkPackets) - log.Printf(" biggest-chunk bytes:\t%d", stats.biggestChunkBytes) - log.Printf(" overlap packets:\t%d", stats.overlapPackets) - log.Printf(" overlap bytes:\t\t%d", stats.overlapBytes) - log.Printf("Errors: %d", nErrors) - for e := range errorsMap { - log.Printf(" %s:\t\t%d", e, errorsMap[e]) - } - log.Printf("AppStats: %v", GetStats()) } diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index b424bde98..4cbfcf9db 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -1,15 +1,10 @@ package tap import ( - "encoding/binary" - "encoding/hex" - "fmt" - "sync" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" // pulls in all layers decoders - "github.com/google/gopacket/reassembly" + "github.com/google/gopacket" // pulls in all layers decoders + "github.com/google/gopacket/tcpassembly/tcpreader" ) type tcpID struct { @@ -24,161 +19,7 @@ type tcpReaderDataMsg struct { timestamp time.Time } -/* It's a connection (bidirectional) - * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) - * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) - * In our implementation, we pass information from ReassembledSG to the httpReader through a shared channel. - */ type tcpStream struct { - tcpstate *reassembly.TCPSimpleFSM - fsmerr bool - optchecker reassembly.TCPOptionCheck net, transport gopacket.Flow - isDNS bool - isHTTP bool - reversed bool - urls []string - ident string - data []byte - msgQueue chan tcpReaderDataMsg - captureTime time.Time - packetsSeen uint - tcpID tcpID - sync.Mutex -} - -func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { - // FSM - if !t.tcpstate.CheckState(tcp, dir) { - SilentError("FSM-rejection", "%s: Packet rejected by FSM (state:%s)", t.ident, t.tcpstate.String()) - stats.rejectFsm++ - if !t.fsmerr { - t.fsmerr = true - stats.rejectConnFsm++ - } - if !*ignorefsmerr { - return false - } - } - // Options - err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) - if err != nil { - SilentError("OptionChecker-rejection", "%s: Packet rejected by OptionChecker: %s", t.ident, err) - stats.rejectOpt++ - if !*nooptcheck { - return false - } - } - // Checksum - accept := true - if *checksum { - c, err := tcp.ComputeChecksum() - if err != nil { - SilentError("ChecksumCompute", "%s: Got error computing checksum: %s", t.ident, err) - accept = false - } else if c != 0x0 { - SilentError("Checksum", "%s: Invalid checksum: 0x%x", t.ident, c) - accept = false - } - } - if !accept { - stats.rejectOpt++ - } - return accept -} - -func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { - dir, start, end, skip := sg.Info() - length, saved := sg.Lengths() - // update stats - sgStats := sg.Stats() - if skip > 0 { - stats.missedBytes += skip - } - stats.sz += length - saved - stats.pkt += sgStats.Packets - if sgStats.Chunks > 1 { - stats.reassembled++ - } - stats.outOfOrderPackets += sgStats.QueuedPackets - stats.outOfOrderBytes += sgStats.QueuedBytes - if length > stats.biggestChunkBytes { - stats.biggestChunkBytes = length - } - if sgStats.Packets > stats.biggestChunkPackets { - stats.biggestChunkPackets = sgStats.Packets - } - if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 { - // In the original example this was handled with panic(). - // I don't know what this error means or how to handle it properly. - SilentError("Invalid-Overlap", "bytes:%d, pkts:%d", sgStats.OverlapBytes, sgStats.OverlapPackets) - } - stats.overlapBytes += sgStats.OverlapBytes - stats.overlapPackets += sgStats.OverlapPackets - - var ident string - if dir == reassembly.TCPDirClientToServer { - ident = fmt.Sprintf("%v %v(%s): ", t.net, t.transport, dir) - } else { - ident = fmt.Sprintf("%v %v(%s): ", t.net.Reverse(), t.transport.Reverse(), dir) - } - Trace("%s: SG reassembled packet with %d bytes (start:%v,end:%v,skip:%d,saved:%d,nb:%d,%d,overlap:%d,%d)", ident, length, start, end, skip, saved, sgStats.Packets, sgStats.Chunks, sgStats.OverlapBytes, sgStats.OverlapPackets) - if skip == -1 && *allowmissinginit { - // this is allowed - } else if skip != 0 { - // Missing bytes in stream: do not even try to parse it - return - } - data := sg.Fetch(length) - if t.isDNS { - dns := &layers.DNS{} - var decoded []gopacket.LayerType - if len(data) < 2 { - if len(data) > 0 { - sg.KeepFrom(0) - } - return - } - dnsSize := binary.BigEndian.Uint16(data[:2]) - missing := int(dnsSize) - len(data[2:]) - Trace("dnsSize: %d, missing: %d", dnsSize, missing) - if missing > 0 { - Debug("Missing some bytes: %d", missing) - sg.KeepFrom(0) - return - } - p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns) - err := p.DecodeLayers(data[2:], &decoded) - if err != nil { - SilentError("DNS-parser", "Failed to decode DNS: %v", err) - } else { - Trace("DNS: %s", gopacket.LayerDump(dns)) - } - if len(data) > 2+int(dnsSize) { - sg.KeepFrom(2 + int(dnsSize)) - } - } else if t.isHTTP { - if length > 0 { - if *hexdump { - Trace("Feeding http with:%s", hex.Dump(data)) - } - // This is where we pass the reassembled information onwards - // This channel is read by an httpReader object - // if dir == reassembly.TCPDirClientToServer && !t.reversed { - // t.client.msgQueue <- tcpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp} - // } else { - // t.server.msgQueue <- tcpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp} - // } - } - } -} - -func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { - Trace("%s: Connection closed", t.ident) - // if t.isHTTP { - // close(t.client.msgQueue) - // close(t.server.msgQueue) - // } - // do not remove the connection to allow last ACK - return false + r tcpreader.ReaderStream } diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 370f6434a..775c02de5 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -3,22 +3,15 @@ package tap import ( "bufio" "fmt" - "io" "sync" "github.com/romana/rlog" - "github.com/bradleyfalzon/tlsx" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" // pulls in all layers decoders - "github.com/google/gopacket/reassembly" + "github.com/google/gopacket" // pulls in all layers decoders + "github.com/google/gopacket/tcpassembly" + "github.com/google/gopacket/tcpassembly/tcpreader" ) -/* - * The TCP factory: returns a new Stream - * Implements gopacket.reassembly.StreamFactory interface (New) - * Generates a new tcp stream for each new tcp connection. Closes the stream when the connection closes. - */ type tcpStreamFactory struct { wg sync.WaitGroup doHTTP bool @@ -36,79 +29,27 @@ func containsPort(ports []string, port string) bool { return false } -func (h *tcpStream) run(wg *sync.WaitGroup) { - defer wg.Done() +func (h *tcpStream) run() { + b := bufio.NewReader(&h.r) for _, extension := range extensions { if containsPort(extension.Ports, h.transport.Dst().String()) { - b := bufio.NewReader(h) extension.Dissector.Ping() extension.Dissector.Dissect(b) } } - // b := bufio.NewReader(h) - // fmt.Printf("b: %v\n", b) } -func (h *tcpStream) Read(p []byte) (int, error) { - var msg tcpReaderDataMsg - - ok := true - for ok && len(h.data) == 0 { - msg, ok = <-h.msgQueue - h.data = msg.bytes - - h.captureTime = msg.timestamp - if len(h.data) > 0 { - h.packetsSeen += 1 - } - if h.packetsSeen < checkTLSPacketAmount && len(msg.bytes) > 5 { // packets with less than 5 bytes cause tlsx to panic - clientHello := tlsx.ClientHello{} - err := clientHello.Unmarshall(msg.bytes) - if err == nil { - // statsTracker.incTlsConnectionsCount() - fmt.Printf("Detected TLS client hello with SNI %s\n", clientHello.SNI) - // numericPort, _ := strconv.Atoi(h.tcpID.dstPort) - // h.outboundLinkWriter.WriteOutboundLink(h.tcpID.srcIP, h.tcpID.dstIP, numericPort, clientHello.SNI, TLSProtocol) - } - } - } - if !ok || len(h.data) == 0 { - return 0, io.EOF - } - - l := copy(p, h.data) - h.data = h.data[l:] - return l, nil -} - -func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { - rlog.Debugf("* NEW: %s %s", net, transport) - fsmOptions := reassembly.TCPSimpleFSMOptions{ - SupportMissingEstablishment: *allowmissinginit, - } - rlog.Debugf("Current App Ports: %v", gSettings.filterPorts) - srcIp := net.Src().String() - dstIp := net.Dst().String() - dstPort := int(tcp.DstPort) - - // if factory.shouldNotifyOnOutboundLink(dstIp, dstPort) { - // factory.outbountLinkWriter.WriteOutboundLink(net.Src().String(), dstIp, dstPort, "", "") - // } - props := factory.getStreamProps(srcIp, dstIp, dstPort) - isHTTP := props.isTapTarget +func (h *tcpStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { + fmt.Printf("* NEW: %s %s\n", net, transport) stream := &tcpStream{ - net: net, - transport: transport, - isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, - isHTTP: isHTTP && factory.doHTTP, - reversed: tcp.SrcPort == 80, - tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions), - ident: fmt.Sprintf("%s:%s", net, transport), - optchecker: reassembly.NewTCPOptionCheck(), + net: net, + transport: transport, + r: tcpreader.NewReaderStream(), } - factory.wg.Add(1) - go stream.run(&factory.wg) - return stream + if transport.Dst().String() == "80" { + go stream.run() + } + return &stream.r } func (factory *tcpStreamFactory) WaitGoRoutines() {