From 7d8655feab926b243eeddb516e48c6484b12d76b Mon Sep 17 00:00:00 2001 From: nimrod-up9 <59927337+nimrod-up9@users.noreply.github.com> Date: Wed, 28 Apr 2021 14:21:39 +0300 Subject: [PATCH] Add tap as a separate executable (#10) * Added tap. * Ignore build directories. * Added tapper build to Makefile. --- .gitignore | 3 + Makefile | 9 +- tap/go.mod | 12 + tap/go.sum | 31 +++ tap/src/cleaner.go | 69 +++++ tap/src/grpc_assembler.go | 229 +++++++++++++++ tap/src/har_writer.go | 200 +++++++++++++ tap/src/http_matcher.go | 206 ++++++++++++++ tap/src/http_reader.go | 290 +++++++++++++++++++ tap/src/net_utils.go | 62 +++++ tap/src/passive_tapper.go | 509 ++++++++++++++++++++++++++++++++++ tap/src/stats_tracker.go | 35 +++ tap/src/tap_output.go | 241 ++++++++++++++++ tap/src/tcp_stream.go | 168 +++++++++++ tap/src/tcp_stream_factory.go | 112 ++++++++ 15 files changed, 2175 insertions(+), 1 deletion(-) create mode 100644 tap/go.mod create mode 100644 tap/go.sum create mode 100644 tap/src/cleaner.go create mode 100644 tap/src/grpc_assembler.go create mode 100644 tap/src/har_writer.go create mode 100644 tap/src/http_matcher.go create mode 100644 tap/src/http_reader.go create mode 100644 tap/src/net_utils.go create mode 100644 tap/src/passive_tapper.go create mode 100644 tap/src/stats_tracker.go create mode 100644 tap/src/tap_output.go create mode 100644 tap/src/tcp_stream.go create mode 100644 tap/src/tcp_stream_factory.go diff --git a/.gitignore b/.gitignore index f2ceede70..179f9bdd9 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ # vendor/ .idea/ *.db + +# Build directories +build diff --git a/Makefile b/Makefile index 9393cbeb3..bd71b972c 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ SHELL=/bin/bash # HELP # This will output the help for each task # thanks to https://marmelab.com/blog/2016/02/29/auto-documented-makefile.html -.PHONY: help ui api cli docker +.PHONY: help ui api cli tap docker help: ## This help. 
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) @@ -30,6 +30,10 @@ api: ## build API server @(cd api; go build -o build/apiserver main.go) @ls -l api/build +tap: ## build tap binary + @(cd tap; go build -o build/tap ./src) + @ls -l tap/build + docker: ## build Docker image @(echo "building docker image" ) @@ -49,6 +53,9 @@ clean-api: clean-cli: @(echo "CLI cleanup - NOT IMPLEMENTED YET " ) +clean-tap: + @(cd tap; rm -rf build ; echo "tap cleanup done") + clean-docker: @(echo "DOCKER cleanup - NOT IMPLEMENTED YET " ) diff --git a/tap/go.mod b/tap/go.mod new file mode 100644 index 000000000..d6ca4aff3 --- /dev/null +++ b/tap/go.mod @@ -0,0 +1,12 @@ +module passive-tapper + +go 1.13 + +require ( + github.com/google/gopacket v1.1.19 + github.com/google/martian v2.1.0+incompatible + github.com/gorilla/websocket v1.4.2 + github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231 + github.com/patrickmn/go-cache v2.1.0+incompatible + golang.org/x/net v0.0.0-20210421230115-4e50805a0758 +) diff --git a/tap/go.sum b/tap/go.sum new file mode 100644 index 000000000..1b9b47ecc --- /dev/null +++ b/tap/go.sum @@ -0,0 +1,31 @@ +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231 h1:fa50YL1pzKW+1SsBnJDOHppJN9stOEwS+CRWyUtyYGU= +github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI= 
+github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758 h1:aEpZnXcAmXkd6AvLb2OPt+EN1Zu/8Ne3pCqPjja5PXY= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe h1:WdX7u8s3yOigWAhHEaDl8r9G+4XwFQEQFtBMYyN+kXQ= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text 
v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/tap/src/cleaner.go b/tap/src/cleaner.go new file mode 100644 index 000000000..5ee45a961 --- /dev/null +++ b/tap/src/cleaner.go @@ -0,0 +1,69 @@ +package main + +import ( + "sync" + "time" + + "github.com/google/gopacket/reassembly" +) + +type CleanerStats struct { + flushed int + closed int + deleted int +} + +type Cleaner struct { + assembler *reassembly.Assembler + assemblerMutex *sync.Mutex + matcher *requestResponseMatcher + cleanPeriod time.Duration + connectionTimeout time.Duration + stats CleanerStats + statsMutex sync.Mutex +} + +func (cl *Cleaner) clean() { + startCleanTime := time.Now() + + cl.assemblerMutex.Lock() + flushed, closed := cl.assembler.FlushCloseOlderThan(startCleanTime.Add(-cl.connectionTimeout)) + cl.assemblerMutex.Unlock() + + deleted := cl.matcher.deleteOlderThan(startCleanTime.Add(-cl.connectionTimeout)) + + cl.statsMutex.Lock() + cl.stats.flushed += flushed + cl.stats.closed += closed + cl.stats.deleted += deleted + cl.statsMutex.Unlock() +} + +func (cl *Cleaner) start() { + go func() { + ticker := time.NewTicker(cl.cleanPeriod) + + for true { + <-ticker.C + cl.clean() + } + }() +} + +func (cl *Cleaner) dumpStats() CleanerStats { + cl.statsMutex.Lock() + + stats := CleanerStats{ + flushed: cl.stats.flushed, + closed : cl.stats.closed, + deleted: cl.stats.deleted, + } + + cl.stats.flushed = 0 + cl.stats.closed = 0 + cl.stats.deleted = 0 + + cl.statsMutex.Unlock() + + return stats +} diff --git a/tap/src/grpc_assembler.go b/tap/src/grpc_assembler.go new file mode 100644 index 000000000..04d246b82 --- /dev/null +++ 
b/tap/src/grpc_assembler.go @@ -0,0 +1,229 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/binary" + "errors" + "math" + "net/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +const frameHeaderLen = 9 +var clientPreface = []byte(http2.ClientPreface) +const initialHeaderTableSize = 4096 +const protoHTTP2 = "HTTP/2.0" +const protoMajorHTTP2 = 2 +const protoMinorHTTP2 = 0 + +var maxHTTP2DataLen int = maxHTTP2DataLenDefault // value initialized during init + +type messageFragment struct { + headers []hpack.HeaderField + data []byte +} + +type fragmentsByStream map[uint32]*messageFragment + +func (fbs *fragmentsByStream) appendFrame(streamID uint32, frame http2.Frame) { + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + if existingFragment, ok := (*fbs)[streamID]; ok { + existingFragment.headers = append(existingFragment.headers, frame.Fields...) + } else { + // new fragment + (*fbs)[streamID] = &messageFragment{headers: frame.Fields} + } + case *http2.DataFrame: + newDataLen := len(frame.Data()) + if existingFragment, ok := (*fbs)[streamID]; ok { + existingDataLen := len(existingFragment.data) + // Never save more than maxHTTP2DataLen bytes + numBytesToAppend := int(math.Min(float64(maxHTTP2DataLen - existingDataLen), float64(newDataLen))) + + existingFragment.data = append(existingFragment.data, frame.Data()[:numBytesToAppend]...) 
+ } else { + // new fragment + // In principle, should not happen with DATA frames, because they are always preceded by HEADERS + + // Never save more than maxHTTP2DataLen bytes + numBytesToAppend := int(math.Min(float64(maxHTTP2DataLen), float64(newDataLen))) + + (*fbs)[streamID] = &messageFragment{data: frame.Data()[:numBytesToAppend]} + } + } +} + +func (fbs *fragmentsByStream) pop(streamID uint32) ([]hpack.HeaderField, []byte) { + headers := (*fbs)[streamID].headers + data := (*fbs)[streamID].data + delete((*fbs), streamID) + return headers, data +} + +func createGrpcAssembler(b *bufio.Reader) GrpcAssembler { + var framerOutput bytes.Buffer + framer := http2.NewFramer(&framerOutput, b) + framer.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + return GrpcAssembler{ + fragmentsByStream: make(fragmentsByStream), + framer: framer, + } +} + +type GrpcAssembler struct { + fragmentsByStream fragmentsByStream + framer *http2.Framer +} + +func (ga *GrpcAssembler) readMessage() (uint32, interface{}, string, error) { + // Exactly one Framer is used for each half connection. + // (Instead of creating a new Framer for each ReadFrame operation) + // This is needed in order to decompress the headers, + // because the compression context is updated with each requests/response. + frame, err := ga.framer.ReadFrame() + if err != nil { + return 0, nil, "", err + } + + streamID := frame.Header().StreamID + + ga.fragmentsByStream.appendFrame(streamID, frame) + + if !(ga.isStreamEnd(frame)) { + return 0, nil, "", nil + } + + headers, data := ga.fragmentsByStream.pop(streamID) + + headersHTTP1 := make(http.Header) + for _, header := range headers { + headersHTTP1[header.Name] = []string{header.Value} + } + dataString := base64.StdEncoding.EncodeToString(data) + + // Use http1 types only because they are expected in http_matcher. 
+ // TODO: Create an interface that will be used by http_matcher:registerRequest and http_matcher:registerRequest + // to accept both HTTP/1.x and HTTP/2 requests and responses + var messageHTTP1 interface{} + if _, ok := headersHTTP1[":method"]; ok { + messageHTTP1 = http.Request{ + Header: headersHTTP1, + Proto: protoHTTP2, + ProtoMajor: protoMajorHTTP2, + ProtoMinor: protoMinorHTTP2, + } + } else if _, ok := headersHTTP1[":status"]; ok { + messageHTTP1 = http.Response{ + Header: headersHTTP1, + Proto: protoHTTP2, + ProtoMajor: protoMajorHTTP2, + ProtoMinor: protoMinorHTTP2, + } + } else { + return 0, nil, "", errors.New("Failed to assemble stream: neither a request nor a message") + } + + return streamID, messageHTTP1, dataString, nil +} + +func (ga *GrpcAssembler) isStreamEnd(frame http2.Frame) bool { + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + if frame.StreamEnded() { + return true + } + case *http2.DataFrame: + if frame.StreamEnded() { + return true + } + } + + return false +} + +/* Check if HTTP/2. 
/* Check if HTTP/2.
   Remove HTTP/2 client preface from start of buffer if present.
*/
func checkIsHTTP2Connection(b *bufio.Reader, isClient bool) (bool, error) {
	if isClient {
		return checkIsHTTP2ClientStream(b)
	}

	return checkIsHTTP2ServerStream(b)
}

// prepareHTTP2Connection consumes the client connection preface so that the
// framer starts at the first frame. Server streams need no preparation.
func prepareHTTP2Connection(b *bufio.Reader, isClient bool) error {
	if !isClient {
		return nil
	}

	return discardClientPreface(b)
}

// checkIsHTTP2ClientStream reports whether a client-to-server stream begins
// with the HTTP/2 client preface. The buffer is only peeked, not consumed.
func checkIsHTTP2ClientStream(b *bufio.Reader) (bool, error) {
	return checkClientPreface(b)
}

// checkIsHTTP2ServerStream reports whether a server-to-client stream looks
// like the start of an HTTP/2 connection, i.e. begins with a SETTINGS frame.
// The buffer is only peeked, not consumed.
func checkIsHTTP2ServerStream(b *bufio.Reader) (bool, error) {
	buf, err := b.Peek(frameHeaderLen)
	if err != nil {
		return false, err
	}

	// If response starts with this text, it is HTTP/1.x.
	if bytes.Compare(buf, []byte("HTTP/1.0 ")) == 0 || bytes.Compare(buf, []byte("HTTP/1.1 ")) == 0 {
		return false, nil
	}

	// Check server connection preface (a settings frame).
	// Manual decode of the 9-byte frame header: 24-bit length, 8-bit type,
	// 8-bit flags, then a 31-bit stream ID with the reserved bit masked off.
	frameHeader := http2.FrameHeader{
		Length:   (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])),
		Type:     http2.FrameType(buf[3]),
		Flags:    http2.Flags(buf[4]),
		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
	}

	if frameHeader.Type != http2.FrameSettings {
		// If HTTP/2, but not start of stream, will also fulfill this condition.
		return false, nil
	}

	return true, nil
}

// checkClientPreface peeks at the stream and reports whether it begins with
// the HTTP/2 client preface string.
func checkClientPreface(b *bufio.Reader) (bool, error) {
	bytesStart, err := b.Peek(len(clientPreface))
	if err != nil {
		return false, err
	} else if len(bytesStart) != len(clientPreface) {
		return false, errors.New("checkClientPreface: not enough bytes read")
	}

	if !bytes.Equal(bytesStart, clientPreface) {
		return false, nil
	}

	return true, nil
}

// discardClientPreface verifies the client preface is present and then
// removes it from the buffered stream, leaving the reader positioned at the
// first HTTP/2 frame.
func discardClientPreface(b *bufio.Reader) error {
	if isClientPrefacePresent, err := checkClientPreface(b); err != nil {
		return err
	} else if !isClientPrefacePresent {
		return errors.New("discardClientPreface: does not begin with client preface")
	}

	// Remove client preface string from the buffer.
	n, err := b.Discard(len(clientPreface))
	if err != nil {
		return err
	} else if n != len(clientPreface) {
		return errors.New("discardClientPreface: failed to discard client preface")
	}

	return nil
}
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/google/martian/har"
)

// File mode for created HAR files (owner read/write, group/other read).
const readPermission = 0644

// Prefix for in-progress HAR files created in the system temp directory.
const tempFilenamePrefix = "har_writer"

// PairChanItem carries one matched request/response pair through the
// HarWriter channel.
type PairChanItem struct {
	Request      *http.Request
	RequestTime  time.Time
	Response     *http.Response
	ResponseTime time.Time
}

// openNewHarFile creates (or appends to) filename and writes the HAR JSON
// header. Panics on failure, since the tapper cannot proceed without its
// output file.
func openNewHarFile(filename string) *HarFile {
	file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, readPermission)
	if err != nil {
		panic(fmt.Sprintf("Failed to open output file: %s (%v,%+v)", err, err, err))
	}

	harFile := HarFile{file: file, entryCount: 0}
	harFile.writeHeader()

	return &harFile
}

// HarFile is a single HAR output file being written incrementally.
type HarFile struct {
	file       *os.File
	entryCount int // entries written so far; drives comma separators and rotation
}

// WriteEntry converts one request/response pair into a HAR entry and appends
// it to the file. Conversion/marshal failures are logged and the entry is
// skipped; file write failures panic.
func (f *HarFile) WriteEntry(request *http.Request, requestTime time.Time, response *http.Response, responseTime time.Time) {
	harRequest, err := har.NewRequest(request, true)
	if err != nil {
		SilentError("convert-request-to-har", "Failed converting request to HAR %s (%v,%+v)\n", err, err, err)
		return
	}

	// Martian copies http.Request.URL.String() to har.Request.URL.
	// According to the spec, the URL field needs to be the absolute URL.
	harRequest.URL = fmt.Sprintf("http://%s%s", request.Host, request.URL)

	harResponse, err := har.NewResponse(response, true)
	if err != nil {
		SilentError("convert-response-to-har", "Failed converting response to HAR %s (%v,%+v)\n", err, err, err)
		return
	}

	// HAR times must be positive; clamp sub-millisecond roundtrips to 1ms.
	totalTime := responseTime.Sub(requestTime).Round(time.Millisecond).Milliseconds()
	if totalTime < 1 {
		totalTime = 1
	}

	harEntry := har.Entry{
		StartedDateTime: time.Now().UTC(),
		Time:            totalTime,
		Request:         harRequest,
		Response:        harResponse,
		Cache:           &har.Cache{},
		Timings: &har.Timings{
			Send:    -1,
			Wait:    -1,
			Receive: totalTime,
		},
	}

	harEntryJson, err := json.Marshal(harEntry)
	if err != nil {
		SilentError("har-entry-marshal", "Failed converting har entry object to JSON%s (%v,%+v)\n", err, err, err)
		return
	}

	// Entries after the first need a comma separator inside the JSON array
	// opened by writeHeader.
	var separator string
	if f.GetEntryCount() > 0 {
		separator = ","
	} else {
		separator = ""
	}

	harEntryString := append([]byte(separator), harEntryJson...)

	if _, err := f.file.Write(harEntryString); err != nil {
		panic(fmt.Sprintf("Failed to write to output file: %s (%v,%+v)", err, err, err))
	}

	f.entryCount++
}

// GetEntryCount returns the number of entries written to this file.
func (f *HarFile) GetEntryCount() int {
	return f.entryCount
}

// Close terminates the JSON document and closes the underlying file.
func (f *HarFile) Close() {
	f.writeTrailer()

	err := f.file.Close()
	if err != nil {
		panic(fmt.Sprintf("Failed to close output file: %s (%v,%+v)", err, err, err))
	}
}

// writeHeader emits the opening of the HAR document, up to and including the
// start of the "entries" array.
func (f *HarFile) writeHeader() {
	header := []byte(`{"log": {"version": "1.2", "creator": {"name": "Mizu", "version": "0.0.1"}, "entries": [`)
	if _, err := f.file.Write(header); err != nil {
		panic(fmt.Sprintf("Failed to write header to output file: %s (%v,%+v)", err, err, err))
	}
}

// writeTrailer closes the "entries" array and the enclosing JSON objects.
func (f *HarFile) writeTrailer() {
	trailer := []byte("]}}")
	if _, err := f.file.Write(trailer); err != nil {
		panic(fmt.Sprintf("Failed to write trailer to output file: %s (%v,%+v)", err, err, err))
	}
}

// NewHarWriter returns a HarWriter that rotates to a new file after
// maxEntries entries and moves finished files into outputDir.
func NewHarWriter(outputDir string, maxEntries int) *HarWriter {
	return &HarWriter{
		OutputDirPath: outputDir,
		MaxEntries:    maxEntries,
		PairChan:      make(chan *PairChanItem),
		currentFile:   nil,
		done:          make(chan bool),
	}
}

// HarWriter receives request/response pairs over PairChan and writes them to
// rotating HAR files under OutputDirPath.
type HarWriter struct {
	OutputDirPath string
	MaxEntries    int
	PairChan      chan *PairChanItem
	currentFile   *HarFile
	done          chan bool // signals the writer goroutine has drained and exited
}

// WritePair queues one pair for writing. Blocks until the writer goroutine
// (started by Start) receives it.
func (hw *HarWriter) WritePair(request *http.Request, requestTime time.Time, response *http.Response, responseTime time.Time) {
	hw.PairChan <- &PairChanItem{
		Request:      request,
		RequestTime:  requestTime,
		Response:     response,
		ResponseTime: responseTime,
	}
}

// Start creates the output directory and launches the goroutine that drains
// PairChan until it is closed by Stop, flushing the last file on exit.
func (hw *HarWriter) Start() {
	if err := os.MkdirAll(hw.OutputDirPath, os.ModePerm); err != nil {
		panic(fmt.Sprintf("Failed to create output directory: %s (%v,%+v)", err, err, err))
	}

	go func() {
		for pair := range hw.PairChan {
			if hw.currentFile == nil {
				hw.openNewFile()
			}

			hw.currentFile.WriteEntry(pair.Request, pair.RequestTime, pair.Response, pair.ResponseTime)

			// Rotate once the file reaches MaxEntries entries.
			if hw.currentFile.GetEntryCount() >= hw.MaxEntries {
				hw.closeFile()
			}
		}

		if hw.currentFile != nil {
			hw.closeFile()
		}
		hw.done <- true
	}()
}

// Stop closes the input channel and waits for the writer goroutine to finish
// flushing the current file.
func (hw *HarWriter) Stop() {
	close(hw.PairChan)
	<-hw.done
}

// openNewFile starts a new HAR file in the system temp directory; it is
// moved into OutputDirPath when complete (see closeFile).
func (hw *HarWriter) openNewFile() {
	filename := filepath.Join(os.TempDir(), fmt.Sprintf("%s_%d", tempFilenamePrefix, time.Now().UnixNano()))
	hw.currentFile = openNewHarFile(filename)
}

// closeFile finalizes the current file and moves it into OutputDirPath under
// a timestamped name.
// NOTE(review): the os.Rename error is ignored; if the temp dir is on a
// different filesystem the finished file would be silently lost — confirm.
func (hw *HarWriter) closeFile() {
	hw.currentFile.Close()
	tmpFilename := hw.currentFile.file.Name()
	hw.currentFile = nil

	filename := buildFilename(hw.OutputDirPath, time.Now())
	os.Rename(tmpFilename, filename)
}

// buildFilename returns the destination path for a finished HAR file:
// (epoch time in nanoseconds)__(YYYY_Month_DD__hh-mm-ss).har
func buildFilename(dir string, t time.Time) string {
	filename := fmt.Sprintf("%d__%s.har", t.UnixNano(), t.Format("2006_Jan_02__15-04-05"))
	return filepath.Join(dir, filename)
}
package main

import (
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/orcaman/concurrent-map"
)

// requestResponsePair couples one captured request with its matched response.
type requestResponsePair struct {
	Request  httpMessage `json:"request"`
	Response httpMessage `json:"response"`
}

// envoyMessageWrapper is the Envoy-style JSON envelope emitted for each
// matched pair.
type envoyMessageWrapper struct {
	HttpBufferedTrace requestResponsePair `json:"http_buffered_trace"`
}

// headerKeyVal is one serialized header.
type headerKeyVal struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// messageBody holds the (base64-encoded) body of a message.
type messageBody struct {
	Truncated bool   `json:"truncated"`
	AsBytes   string `json:"as_bytes"`
}

// httpMessage is the serialized form of either an HTTP request or response.
type httpMessage struct {
	IsRequest   bool
	Headers     []headerKeyVal `json:"headers"`
	HTTPVersion string         `json:"httpVersion"`
	Body        messageBody    `json:"body"`
	captureTime time.Time   // when the message was captured off the wire
	orig        interface{} // the original *http.Request or *http.Response
}

// requestResponseMatcher pairs requests with their responses. Whichever half
// arrives first is parked in openMessagesMap until its counterpart shows up.
// Key is {client_addr}:{client_port}->{dest_addr}:{dest_port}
type requestResponseMatcher struct {
	openMessagesMap cmap.ConcurrentMap
}

// createResponseRequestMatcher returns an empty matcher.
func createResponseRequestMatcher() requestResponseMatcher {
	newMatcher := &requestResponseMatcher{openMessagesMap: cmap.New()}
	return *newMatcher
}

// registerRequest records a captured request. If the matching response is
// already parked it returns the completed pair; otherwise it parks the
// request and returns nil.
func (matcher *requestResponseMatcher) registerRequest(ident string, request *http.Request, captureTime time.Time, body string, isHTTP2 bool) *envoyMessageWrapper {
	split := splitIdent(ident)
	key := genKey(split)

	messageExtraHeaders := []headerKeyVal{
		{Key: "x-up9-source", Value: split[0]},
		{Key: "x-up9-destination", Value: split[1] + ":" + split[3]},
	}

	requestHTTPMessage := requestToMessage(request, captureTime, body, &messageExtraHeaders, isHTTP2)

	if response, found := matcher.openMessagesMap.Pop(key); found {
		// Type assertion always succeeds because all of the map's values are of httpMessage type
		responseHTTPMessage := response.(*httpMessage)
		if responseHTTPMessage.IsRequest {
			SilentError("Request-Duplicate", "Got duplicate request with same identifier\n")
			return nil
		}
		Debug("Matched open Response for %s\n", key)
		return matcher.preparePair(&requestHTTPMessage, responseHTTPMessage)
	}

	matcher.openMessagesMap.Set(key, &requestHTTPMessage)
	Debug("Registered open Request for %s\n", key)
	return nil
}

// registerResponse records a captured response. If the matching request is
// already parked it returns the completed pair; otherwise it parks the
// response and returns nil.
func (matcher *requestResponseMatcher) registerResponse(ident string, response *http.Response, captureTime time.Time, body string, isHTTP2 bool) *envoyMessageWrapper {
	split := splitIdent(ident)
	key := genKey(split)

	responseHTTPMessage := responseToMessage(response, captureTime, body, isHTTP2)

	if request, found := matcher.openMessagesMap.Pop(key); found {
		// Type assertion always succeeds because all of the map's values are of httpMessage type
		requestHTTPMessage := request.(*httpMessage)
		if !requestHTTPMessage.IsRequest {
			SilentError("Response-Duplicate", "Got duplicate response with same identifier\n")
			return nil
		}
		Debug("Matched open Request for %s\n", key)
		return matcher.preparePair(requestHTTPMessage, &responseHTTPMessage)
	}

	matcher.openMessagesMap.Set(key, &responseHTTPMessage)
	Debug("Registered open Response for %s\n", key)
	return nil
}

// preparePair stamps the response with the request/response latency and
// wraps both halves in the output envelope.
func (matcher *requestResponseMatcher) preparePair(requestHTTPMessage *httpMessage, responseHTTPMessage *httpMessage) *envoyMessageWrapper {
	matcher.addDuration(requestHTTPMessage, responseHTTPMessage)

	return &envoyMessageWrapper{
		HttpBufferedTrace: requestResponsePair{
			Request:  *requestHTTPMessage,
			Response: *responseHTTPMessage,
		},
	}
}

// requestToMessage converts an http.Request into the serialized httpMessage
// form. For HTTP/1.x it synthesizes the HTTP/2-style pseudo-headers
// (:method, :path, :authority, :scheme); for HTTP/2 those are already
// present in request.Header.
func requestToMessage(request *http.Request, captureTime time.Time, body string, messageExtraHeaders *[]headerKeyVal, isHTTP2 bool) httpMessage {
	messageHeaders := make([]headerKeyVal, 0)

	// NOTE(review): only the first value of each multi-valued header is kept.
	for key, value := range request.Header {
		messageHeaders = append(messageHeaders, headerKeyVal{Key: key, Value: value[0]})
	}

	if !isHTTP2 {
		messageHeaders = append(messageHeaders, headerKeyVal{Key: ":method", Value: request.Method})
		messageHeaders = append(messageHeaders, headerKeyVal{Key: ":path", Value: request.RequestURI})
		messageHeaders = append(messageHeaders, headerKeyVal{Key: ":authority", Value: request.Host})
		messageHeaders = append(messageHeaders, headerKeyVal{Key: ":scheme", Value: "http"})
	}

	// Capture time as fractional epoch seconds, Envoy-style.
	messageHeaders = append(messageHeaders, headerKeyVal{Key: "x-request-start", Value: fmt.Sprintf("%.3f", float64(captureTime.UnixNano())/float64(1000000000))})

	messageHeaders = append(messageHeaders, *messageExtraHeaders...)

	httpVersion := request.Proto

	requestBody := messageBody{Truncated: false, AsBytes: body}

	return httpMessage{
		IsRequest:   true,
		Headers:     messageHeaders,
		HTTPVersion: httpVersion,
		Body:        requestBody,
		captureTime: captureTime,
		orig:        request,
	}
}

// responseToMessage converts an http.Response into the serialized
// httpMessage form, synthesizing the :status pseudo-header for HTTP/1.x.
func responseToMessage(response *http.Response, captureTime time.Time, body string, isHTTP2 bool) httpMessage {
	messageHeaders := make([]headerKeyVal, 0)

	// NOTE(review): only the first value of each multi-valued header is kept.
	for key, value := range response.Header {
		messageHeaders = append(messageHeaders, headerKeyVal{Key: key, Value: value[0]})
	}

	if !isHTTP2 {
		messageHeaders = append(messageHeaders, headerKeyVal{Key: ":status", Value: strconv.Itoa(response.StatusCode)})
	}

	httpVersion := response.Proto

	requestBody := messageBody{Truncated: false, AsBytes: body}

	return httpMessage{
		IsRequest:   false,
		Headers:     messageHeaders,
		HTTPVersion: httpVersion,
		Body:        requestBody,
		captureTime: captureTime,
		orig:        response,
	}
}

// addDuration appends the request->response latency (in ms, minimum 1) to
// the response's headers as x-up9-duration-ms.
func (matcher *requestResponseMatcher) addDuration(requestHTTPMessage *httpMessage, responseHTTPMessage *httpMessage) {
	durationMs := float64(responseHTTPMessage.captureTime.UnixNano()/1000000) - float64(requestHTTPMessage.captureTime.UnixNano()/1000000)
	if durationMs < 1 {
		durationMs = 1
	}

	responseHTTPMessage.Headers = append(responseHTTPMessage.Headers, headerKeyVal{Key: "x-up9-duration-ms", Value: fmt.Sprintf("%.0f", durationMs)})
}

// splitIdent breaks an ident of the form
// "src->dst srcPort->dstPort counter" into its five components.
func splitIdent(ident string) []string {
	ident = strings.Replace(ident, "->", " ", -1)
	return strings.Split(ident, " ")
}

// genKey builds the matcher map key from a split ident:
// {src}:{srcPort}->{dst}:{dstPort},{counter}
func genKey(split []string) string {
	key := fmt.Sprintf("%s:%s->%s:%s,%s", split[0], split[2], split[1], split[3], split[4])
	return key
}

// deleteOlderThan drops every parked half-pair captured before t and returns
// the number of deleted entries.
func (matcher *requestResponseMatcher) deleteOlderThan(t time.Time) int {
	keysToPop := make([]string, 0)
	for item := range matcher.openMessagesMap.IterBuffered() {
		// Map only contains values of type httpMessage
		message, _ := item.Val.(*httpMessage)

		if message.captureTime.Before(t) {
			keysToPop = append(keysToPop, item.Key)
		}
	}

	numDeleted := len(keysToPop)

	for _, key := range keysToPop {
		_, _ = matcher.openMessagesMap.Pop(key)
	}

	return numDeleted
}
package main

import (
	"bufio"
	"bytes"
	"compress/gzip"
	b64 "encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"sync"
	"time"
)

// httpReaderDataMsg is one chunk of reassembled TCP payload plus its capture
// time.
type httpReaderDataMsg struct {
	bytes     []byte
	timestamp time.Time
}

// tcpID identifies one direction of a TCP connection.
type tcpID struct {
	srcIP   string
	dstIP   string
	srcPort string
	dstPort string
}

// String formats the connection as "srcIP->dstIP srcPort->dstPort".
func (tid *tcpID) String() string {
	return fmt.Sprintf("%s->%s %s->%s", tid.srcIP, tid.dstIP, tid.srcPort, tid.dstPort)
}
+ * Implemets io.Reader interface (Read) + */ +type httpReader struct { + ident string + tcpID tcpID + isClient bool + isHTTP2 bool + msgQueue chan httpReaderDataMsg // Channel of captured reassembled tcp payload + data []byte + captureTime time.Time + hexdump bool + parent *tcpStream + grpcAssembler GrpcAssembler + messageCount uint + harWriter *HarWriter +} + +func (h *httpReader) Read(p []byte) (int, error) { + var msg httpReaderDataMsg + ok := true + for ok && len(h.data) == 0 { + msg, ok = <-h.msgQueue + h.data = msg.bytes + h.captureTime = msg.timestamp + } + if !ok || len(h.data) == 0 { + return 0, io.EOF + } + + l := copy(p, h.data) + h.data = h.data[l:] + return l, nil +} + +func (h *httpReader) run(wg *sync.WaitGroup) { + defer wg.Done() + b := bufio.NewReader(h) + + if isHTTP2, err := checkIsHTTP2Connection(b, h.isClient); err != nil { + SilentError("HTTP/2-Prepare-Connection", "stream %s Failed to check if client is HTTP/2: %s (%v,%+v)\n", h.ident, err, err, err) + // Do something? 
+ } else { + h.isHTTP2 = isHTTP2 + } + + if h.isHTTP2 { + prepareHTTP2Connection(b, h.isClient) + h.grpcAssembler = createGrpcAssembler(b) + } + + for true { + if h.isHTTP2 { + err := h.handleHTTP2Stream(b) + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } else if err != nil { + SilentError("HTTP/2", "stream %s error: %s (%v,%+v)\n", h.ident, err, err, err) + continue + } + } else if h.isClient { + err := h.handleHTTP1ClientStream(b) + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } else if err != nil { + SilentError("HTTP-request", "stream %s Request error: %s (%v,%+v)\n", h.ident, err, err, err) + continue + } + } else { + err := h.handleHTTP1ServerStream(b) + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } else if err != nil { + SilentError("HTTP-response", "stream %s Response error: %s (%v,%+v)\n", h.ident, err, err, err) + continue + } + } + } +} + +func (h *httpReader) handleHTTP2Stream(b *bufio.Reader) error { + streamID, messageHTTP1, body, error := h.grpcAssembler.readMessage() + h.messageCount++ + if error != nil { + return error + } + + var reqResPair *envoyMessageWrapper + + switch messageHTTP1 := messageHTTP1.(type) { + case http.Request: + ident := fmt.Sprintf("%s->%s %s->%s %d", h.tcpID.srcIP, h.tcpID.dstIP, h.tcpID.srcPort, h.tcpID.dstPort, streamID) + reqResPair = reqResMatcher.registerRequest(ident, &messageHTTP1, h.captureTime, body, true) + case http.Response: + ident := fmt.Sprintf("%s->%s %s->%s %d", h.tcpID.dstIP, h.tcpID.srcIP, h.tcpID.dstPort, h.tcpID.srcPort, streamID) + reqResPair = reqResMatcher.registerResponse(ident, &messageHTTP1, h.captureTime, body, true) + } + + if reqResPair != nil { + if h.harWriter != nil { + h.harWriter.WritePair( + reqResPair.HttpBufferedTrace.Request.orig.(*http.Request), + reqResPair.HttpBufferedTrace.Request.captureTime, + reqResPair.HttpBufferedTrace.Response.orig.(*http.Response), + reqResPair.HttpBufferedTrace.Response.captureTime, + ) + } else { + jsonStr, err := 
json.Marshal(reqResPair) + + broadcastReqResPair(jsonStr) + + if err != nil { + return err + } + } + } + + return nil +} + +func (h *httpReader) handleHTTP1ClientStream(b *bufio.Reader) error { + req, err := http.ReadRequest(b) + h.messageCount++ + if err != nil { + return err + } + body, err := ioutil.ReadAll(req.Body) + req.Body = io.NopCloser(bytes.NewBuffer(body)) // rewind + s := len(body) + if err != nil { + SilentError("HTTP-request-body", "stream %s Got body err: %s\n", h.ident, err) + } else if h.hexdump { + Info("Body(%d/0x%x)\n%s\n", len(body), len(body), hex.Dump(body)) + } + req.Body.Close() + encoding := req.Header["Content-Encoding"] + bodyStr, err := readBody(body, encoding) + if err != nil { + SilentError("HTTP-request-body-decode", "stream %s Failed to decode body: %s\n", h.ident, err) + } + Info("HTTP/%s Request: %s %s (Body:%d)\n", h.ident, req.Method, req.URL, s) + + ident := fmt.Sprintf("%s->%s %s->%s %d", h.tcpID.srcIP, h.tcpID.dstIP, h.tcpID.srcPort, h.tcpID.dstPort, h.messageCount) + reqResPair := reqResMatcher.registerRequest(ident, req, h.captureTime, bodyStr, false) + if reqResPair != nil { + if h.harWriter != nil { + h.harWriter.WritePair( + reqResPair.HttpBufferedTrace.Request.orig.(*http.Request), + reqResPair.HttpBufferedTrace.Request.captureTime, + reqResPair.HttpBufferedTrace.Response.orig.(*http.Response), + reqResPair.HttpBufferedTrace.Response.captureTime, + ) + } else { + jsonStr, err := json.Marshal(reqResPair) + + broadcastReqResPair(jsonStr) + + if err != nil { + SilentError("HTTP-marshal", "stream %s Error convert request response to json: %s\n", h.ident, err) + } + } + } + + h.parent.Lock() + h.parent.urls = append(h.parent.urls, req.URL.String()) + h.parent.Unlock() + + return nil +} + +func (h *httpReader) handleHTTP1ServerStream(b *bufio.Reader) error { + res, err := http.ReadResponse(b, nil) + h.messageCount++ + var req string + h.parent.Lock() + if len(h.parent.urls) == 0 { + req = fmt.Sprintf("") + } else { + req, 
h.parent.urls = h.parent.urls[0], h.parent.urls[1:] + } + h.parent.Unlock() + if err != nil { + return err + } + body, err := ioutil.ReadAll(res.Body) + res.Body = io.NopCloser(bytes.NewBuffer(body)) // rewind + s := len(body) + if err != nil { + SilentError("HTTP-response-body", "HTTP/%s: failed to get body(parsed len:%d): %s\n", h.ident, s, err) + } + if h.hexdump { + Info("Body(%d/0x%x)\n%s\n", len(body), len(body), hex.Dump(body)) + } + res.Body.Close() + sym := "," + if res.ContentLength > 0 && res.ContentLength != int64(s) { + sym = "!=" + } + contentType, ok := res.Header["Content-Type"] + if !ok { + contentType = []string{http.DetectContentType(body)} + } + encoding := res.Header["Content-Encoding"] + Info("HTTP/%s Response: %s URL:%s (%d%s%d%s) -> %s\n", h.ident, res.Status, req, res.ContentLength, sym, s, contentType, encoding) + bodyStr, err := readBody(body, encoding) + if err != nil { + SilentError("HTTP-response-body-decode", "stream %s Failed to decode body: %s\n", h.ident, err) + } + + ident := fmt.Sprintf("%s->%s %s->%s %d", h.tcpID.dstIP, h.tcpID.srcIP, h.tcpID.dstPort, h.tcpID.srcPort, h.messageCount) + reqResPair := reqResMatcher.registerResponse(ident, res, h.captureTime, bodyStr, false) + if reqResPair != nil { + if h.harWriter != nil { + h.harWriter.WritePair( + reqResPair.HttpBufferedTrace.Request.orig.(*http.Request), + reqResPair.HttpBufferedTrace.Request.captureTime, + reqResPair.HttpBufferedTrace.Response.orig.(*http.Response), + reqResPair.HttpBufferedTrace.Response.captureTime, + ) + } else { + jsonStr, err := json.Marshal(reqResPair) + + broadcastReqResPair(jsonStr) + + if err != nil { + SilentError("HTTP-marshal", "stream %s Error convert request response to json: %s\n", h.ident, err) + } + } + } + + return nil +} + +func readBody(bodyBytes []byte, encoding []string) (string, error) { + var bodyBuffer io.Reader + bodyBuffer = bytes.NewBuffer(bodyBytes) + var err error + if len(encoding) > 0 && (encoding[0] == "gzip" || encoding[0] == 
"deflate") { + bodyBuffer, err = gzip.NewReader(bodyBuffer) + if err != nil { + SilentError("HTTP-gunzip", "Failed to gzip decode: %s\n", err) + return "", err + } + } + if _, ok := bodyBuffer.(*gzip.Reader); ok { + err = bodyBuffer.(*gzip.Reader).Close() + if err != nil { + return "", err + } + } + + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(bodyBuffer) + return b64.StdEncoding.EncodeToString(buf.Bytes()), err +} diff --git a/tap/src/net_utils.go b/tap/src/net_utils.go new file mode 100644 index 000000000..7219fe90b --- /dev/null +++ b/tap/src/net_utils.go @@ -0,0 +1,62 @@ +package main + +import ( + "net" + "strings" +) + +var privateIPBlocks []*net.IPNet + +func init() { + initPrivateIPBlocks() +} + +// Get this host ipv4 and ipv6 addresses on all interfaces +func getLocalhostIPs() ([]string, error) { + addrMasks, err := net.InterfaceAddrs() + if err != nil { + // TODO: return error, log error + return nil, err + } + + myIPs := make([]string, len(addrMasks)) + for ii, addr := range addrMasks { + myIPs[ii] = strings.Split(addr.String(), "/")[0] + } + + return myIPs, nil +} + +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + for _, block := range privateIPBlocks { + if block.Contains(ip) { + return true + } + } + return false +} + +func initPrivateIPBlocks() { + for _, cidr := range []string{ + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC1918 + "172.16.0.0/12", // RFC1918 + "192.168.0.0/16", // RFC1918 + "169.254.0.0/16", // RFC3927 link-local + "::1/128", // IPv6 loopback + "fe80::/10", // IPv6 link-local + "fc00::/7", // IPv6 unique local addr + } { + _, block, err := net.ParseCIDR(cidr) + if err != nil { + Error("Private-IP-Block-Parse", "parse error on %q: %v", cidr, err) + } else { + privateIPBlocks = append(privateIPBlocks, block) + } + } +} diff --git a/tap/src/passive_tapper.go b/tap/src/passive_tapper.go new file mode 
100644 index 000000000..e29903e92 --- /dev/null +++ b/tap/src/passive_tapper.go @@ -0,0 +1,509 @@ +// Copyright 2012 Google, Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style license +// that can be found in the LICENSE file in the root of the source +// tree. + +// The pcapdump binary implements a tcpdump-like command line tool with gopacket +// using pcap as a backend data collection mechanism. +package main + +import ( + "encoding/hex" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/signal" + "runtime" + "runtime/pprof" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/examples/util" + "github.com/google/gopacket/ip4defrag" + "github.com/google/gopacket/layers" // pulls in all layers decoders + "github.com/google/gopacket/pcap" + "github.com/google/gopacket/reassembly" +) + +const AppPortsEnvVar = "APP_PORTS" +const TapOutPortEnvVar = "WEB_SOCKET_PORT" +const maxHTTP2DataLenEnvVar = "HTTP2_DATA_SIZE_LIMIT" +const hostModeEnvVar = "HOST_MODE" +// default is 1MB, more than the max size accepted by collector and traffic-dumper +const maxHTTP2DataLenDefault = 1 * 1024 * 1024 +const cleanPeriod = time.Second * 10 +const outboundThrottleCacheExpiryPeriod = time.Minute * 15 +var remoteOnlyOutboundPorts = []int { 80, 443 } + +func parseAppPorts(appPortsList string) []int { + ports := make([]int, 0) + for _, portStr := range strings.Split(appPortsList, ",") { + parsedInt, parseError := strconv.Atoi(portStr) + if parseError != nil { + fmt.Println("Provided app port ", portStr, " is not a valid number!") + } else { + ports = append(ports, parsedInt) + } + } + return ports +} + +var maxcount = flag.Int("c", -1, "Only grab this many packets, then exit") +var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)") +var statsevery = flag.Int("stats", 60, "Output statistics every N seconds") +var lazy = flag.Bool("lazy", false, "If true, do 
lazy decoding") +var nodefrag = flag.Bool("nodefrag", false, "If true, do not do IPv4 defrag") +var checksum = flag.Bool("checksum", false, "Check TCP checksum") // global +var nooptcheck = flag.Bool("nooptcheck", true, "Do not check TCP options (useful to ignore MSS on captures with TSO)") // global +var ignorefsmerr = flag.Bool("ignorefsmerr", true, "Ignore TCP FSM errors") // global +var allowmissinginit = flag.Bool("allowmissinginit", true, "Support streams without SYN/SYN+ACK/ACK sequence") // global +var verbose = flag.Bool("verbose", false, "Be verbose") +var debug = flag.Bool("debug", false, "Display debug information") +var quiet = flag.Bool("quiet", false, "Be quiet regarding errors") + +// http +var nohttp = flag.Bool("nohttp", false, "Disable HTTP parsing") +var output = flag.String("output", "", "Path to create file for HTTP 200 OK responses") +var writeincomplete = flag.Bool("writeincomplete", false, "Write incomplete response") + +var hexdump = flag.Bool("dump", false, "Dump HTTP request/response as hex") // global +var hexdumppkt = flag.Bool("dumppkt", false, "Dump packet as hex") + +// capture +var iface = flag.String("i", "en0", "Interface to read packets from") +var fname = flag.String("r", "", "Filename to read from, overrides -i") +var snaplen = flag.Int("s", 65536, "Snap length (number of bytes max to read per packet") +var tstype = flag.String("timestamp_type", "", "Type of timestamps to use") +var promisc = flag.Bool("promisc", true, "Set promiscuous mode") +var anydirection = flag.Bool("anydirection", false, "Capture http requests to other hosts") +var staleTimeoutSeconds = flag.Int("staletimout", 120, "Max time in seconds to keep connections which don't transmit data") + +var memprofile = flag.String("memprofile", "", "Write memory profile") + +// output +var dumpToHar = flag.Bool("hardump", false, "Dump traffic to har files") +var harOutputDir = flag.String("hardir", "output", "Directory in which to store output har files") +var 
harEntriesPerFile = flag.Int("harentriesperfile", 200, "Number of max number of har entries to store in each file") + +var reqResMatcher = createResponseRequestMatcher() // global +var statsTracker = StatsTracker{} + +// global +var stats struct { + ipdefrag int + missedBytes int + pkt int + sz int + totalsz int + rejectFsm int + rejectOpt int + rejectConnFsm int + reassembled int + outOfOrderBytes int + outOfOrderPackets int + biggestChunkBytes int + biggestChunkPackets int + overlapBytes int + overlapPackets int +} + +type CollectorMessage struct { + MessageType string + Ports *[]int `json:"ports,omitempty"` + Addresses *[]string `json:"addresses,omitempty"` +} + +var outputLevel int +var errorsMap map[string]uint +var errorsMapMutex sync.Mutex +var nErrors uint +var appPorts []int // global +var ownIps []string //global +var hostMode bool //global +var hostAppAddresses []string //global + +/* minOutputLevel: Error will be printed only if outputLevel is above this value + * t: key for errorsMap (counting errors) + * s, a: arguments fmt.Printf + * Note: Too bad for perf that a... is evaluated + */ +func logError(minOutputLevel int, t string, s string, a ...interface{}) { + errorsMapMutex.Lock() + nErrors++ + nb, _ := errorsMap[t] + errorsMap[t] = nb + 1 + errorsMapMutex.Unlock() + if outputLevel >= minOutputLevel { + formatStr := fmt.Sprintf("%s: %s", t, s) + fmt.Printf(formatStr, a...) + } +} +func Error(t string, s string, a ...interface{}) { + logError(0, t, s, a...) +} +func SilentError(t string, s string, a ...interface{}) { + logError(2, t, s, a...) +} +func Info(s string, a ...interface{}) { + if outputLevel >= 1 { + fmt.Printf(s, a...) + } +} +func Debug(s string, a ...interface{}) { + if outputLevel >= 2 { + fmt.Printf(s, a...) 
+ } +} + +func inArrayInt(arr []int, valueToCheck int) bool { + for _, value := range arr { + if value == valueToCheck { + return true + } + } + return false +} + +func inArrayString(arr []string, valueToCheck string) bool { + for _, value := range arr { + if value == valueToCheck { + return true + } + } + return false +} + +/* + * The assembler context + */ +type Context struct { + CaptureInfo gopacket.CaptureInfo +} + +func (c *Context) GetCaptureInfo() gopacket.CaptureInfo { + return c.CaptureInfo +} + +func main() { + defer util.Run()() + if *debug { + outputLevel = 2 + } else if *verbose { + outputLevel = 1 + } else if *quiet { + outputLevel = -1 + } + errorsMap = make(map[string]uint) + + if localhostIPs, err := getLocalhostIPs(); err != nil { + // TODO: think this over + fmt.Println("Failed to get self IP addresses") + Error("Getting-Self-Address", "Error getting self ip address: %s (%v,%+v)\n", err, err, err) + ownIps = make([]string, 0) + } else { + ownIps = localhostIPs + } + + appPortsStr := os.Getenv(AppPortsEnvVar) + if appPortsStr == "" { + fmt.Println("Received empty/no APP_PORTS env var! only listening to http on port 80!") + appPorts = make([]int, 0) + } else { + appPorts = parseAppPorts(appPortsStr) + } + tapOutputPort := os.Getenv(TapOutPortEnvVar) + if tapOutputPort == "" { + fmt.Println("Received empty/no WEB_SOCKET_PORT env var! falling back to port 8080") + tapOutputPort = "8080" + } + envVal := os.Getenv(maxHTTP2DataLenEnvVar) + if envVal == "" { + fmt.Println("Received empty/no HTTP2_DATA_SIZE_LIMIT env var! falling back to", maxHTTP2DataLenDefault) + maxHTTP2DataLen = maxHTTP2DataLenDefault + } else { + if convertedInt, err := strconv.Atoi(envVal); err != nil { + fmt.Println("Received invalid HTTP2_DATA_SIZE_LIMIT env var! 
falling back to", maxHTTP2DataLenDefault) + maxHTTP2DataLen = maxHTTP2DataLenDefault + } else { + fmt.Println("Received HTTP2_DATA_SIZE_LIMIT env var:", maxHTTP2DataLenDefault) + maxHTTP2DataLen = convertedInt + } + } + hostMode = os.Getenv(hostModeEnvVar) == "1" + + fmt.Printf("App Ports: %v\n", appPorts) + fmt.Printf("Tap output websocket port: %s\n", tapOutputPort) + + var onCollectorMessage = func(message []byte) { + var parsedMessage CollectorMessage + err := json.Unmarshal(message, &parsedMessage) + if err == nil { + + if parsedMessage.MessageType == "setPorts" { + Debug("Got message from collector. Type: %s, Ports: %v\n", parsedMessage.MessageType, parsedMessage.Ports) + appPorts = *parsedMessage.Ports + } else if parsedMessage.MessageType == "setAddresses" { + Debug("Got message from collector. Type: %s, IPs: %v\n", parsedMessage.MessageType, parsedMessage.Addresses) + hostAppAddresses = *parsedMessage.Addresses + } + } else { + Error("Collector-Message-Parsing", "Error parsing message from collector: %s (%v,%+v)\n", err, err, err) + } + } + + go startOutputServer(tapOutputPort, onCollectorMessage) + + var handle *pcap.Handle + var err error + if *fname != "" { + if handle, err = pcap.OpenOffline(*fname); err != nil { + log.Fatal("PCAP OpenOffline error:", err) + } + } else { + // This is a little complicated because we want to allow all possible options + // for creating the packet capture handle... instead of all this you can + // just call pcap.OpenLive if you want a simple handle. 
+ inactive, err := pcap.NewInactiveHandle(*iface) + if err != nil { + log.Fatal("could not create: %v", err) + } + defer inactive.CleanUp() + if err = inactive.SetSnapLen(*snaplen); err != nil { + log.Fatal("could not set snap length: %v", err) + } else if err = inactive.SetPromisc(*promisc); err != nil { + log.Fatal("could not set promisc mode: %v", err) + } else if err = inactive.SetTimeout(time.Second); err != nil { + log.Fatal("could not set timeout: %v", err) + } + if *tstype != "" { + if t, err := pcap.TimestampSourceFromString(*tstype); err != nil { + log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) + } else if err := inactive.SetTimestampSource(t); err != nil { + log.Fatalf("Supported timestamp types: %v", inactive.SupportedTimestamps()) + } + } + if handle, err = inactive.Activate(); err != nil { + log.Fatal("PCAP Activate error:", err) + } + defer handle.Close() + } + if len(flag.Args()) > 0 { + bpffilter := strings.Join(flag.Args(), " ") + Info("Using BPF filter %q\n", bpffilter) + if err = handle.SetBPFFilter(bpffilter); err != nil { + log.Fatal("BPF filter error:", err) + } + } + + var harWriter *HarWriter + if *dumpToHar { + harWriter = NewHarWriter(*harOutputDir, *harEntriesPerFile) + harWriter.Start() + defer harWriter.Stop() + } + + + var dec gopacket.Decoder + var ok bool + decoder_name := *decoder + if decoder_name == "" { + decoder_name = fmt.Sprintf("%s", handle.LinkType()) + } + if dec, ok = gopacket.DecodersByLayerName[decoder_name]; !ok { + log.Fatalln("No decoder named", decoder_name) + } + source := gopacket.NewPacketSource(handle, dec) + source.Lazy = *lazy + source.NoCopy = true + Info("Starting to read packets\n") + count := 0 + bytes := int64(0) + start := time.Now() + defragger := ip4defrag.NewIPv4Defragmenter() + + streamFactory := &tcpStreamFactory{doHTTP: !*nohttp, harWriter: harWriter} + streamPool := reassembly.NewStreamPool(streamFactory) + assembler := reassembly.NewAssembler(streamPool) + var 
assemblerMutex sync.Mutex + + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + + staleConnectionTimeout := time.Second * time.Duration(*staleTimeoutSeconds) + cleaner := Cleaner{ + assembler: assembler, + assemblerMutex: &assemblerMutex, + matcher: &reqResMatcher, + cleanPeriod: cleanPeriod, + connectionTimeout: staleConnectionTimeout, + } + cleaner.start() + + go func() { + statsPeriod := time.Second * time.Duration(*statsevery) + ticker := time.NewTicker(statsPeriod) + + for true { + <-ticker.C + + // Since the start + errorsMapMutex.Lock() + errorMapLen := len(errorsMap) + errorsSummery := fmt.Sprintf("%v", errorsMap) + errorsMapMutex.Unlock() + fmt.Printf("Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)\nErrors Summary: %s\n", + count, + bytes, + time.Since(start), + nErrors, + errorMapLen, + errorsSummery, + ) + + // At this moment + memStats := runtime.MemStats{} + runtime.ReadMemStats(&memStats) + fmt.Printf( + "mem: %d, goroutines: %d, unmatched messages: %d\n", + memStats.HeapAlloc, + runtime.NumGoroutine(), + reqResMatcher.openMessagesMap.Count(), + ) + + // Since the last print + cleanStats := cleaner.dumpStats() + appStats := statsTracker.dumpStats() + fmt.Printf( + "flushed connections %d, closed connections: %d, deleted messages: %d, matched messages: %d\n", + cleanStats.flushed, + cleanStats.closed, + cleanStats.deleted, + appStats.matchedMessages, + ) + } + }() + + for packet := range source.Packets() { + count++ + Debug("PACKET #%d\n", count) + data := packet.Data() + bytes += int64(len(data)) + if *hexdumppkt { + Debug("Packet content (%d/0x%x)\n%s\n", len(data), len(data), hex.Dump(data)) + } + + // defrag the IPv4 packet if required + if !*nodefrag { + ip4Layer := packet.Layer(layers.LayerTypeIPv4) + if ip4Layer == nil { + continue + } + ip4 := ip4Layer.(*layers.IPv4) + l := ip4.Length + newip4, err := defragger.DefragIPv4(ip4) + if err != nil { + log.Fatalln("Error while de-fragmenting", err) + } 
else if newip4 == nil { + Debug("Fragment...\n") + continue // packet fragment, we don't have whole packet yet. + } + if newip4.Length != l { + stats.ipdefrag++ + Debug("Decoding re-assembled packet: %s\n", newip4.NextLayerType()) + pb, ok := packet.(gopacket.PacketBuilder) + if !ok { + panic("Not a PacketBuilder") + } + nextDecoder := newip4.NextLayerType() + nextDecoder.Decode(newip4.Payload, pb) + } + } + + tcp := packet.Layer(layers.LayerTypeTCP) + if tcp != nil { + tcp := tcp.(*layers.TCP) + if *checksum { + err := tcp.SetNetworkLayerForChecksum(packet.NetworkLayer()) + if err != nil { + log.Fatalf("Failed to set network layer for checksum: %s\n", err) + } + } + c := Context{ + CaptureInfo: packet.Metadata().CaptureInfo, + } + stats.totalsz += len(tcp.Payload) + //fmt.Println(packet.NetworkLayer().NetworkFlow().Src(), ":", tcp.SrcPort, " -> ", packet.NetworkLayer().NetworkFlow().Dst(), ":", tcp.DstPort) + assemblerMutex.Lock() + assembler.AssembleWithContext(packet.NetworkLayer().NetworkFlow(), tcp, &c) + assemblerMutex.Unlock() + } + + done := *maxcount > 0 && count >= *maxcount + if done { + errorsMapMutex.Lock() + errorMapLen := len(errorsMap) + errorsMapMutex.Unlock() + fmt.Fprintf(os.Stderr, "Processed %v packets (%v bytes) in %v (errors: %v, errTypes:%v)\n", count, bytes, time.Since(start), nErrors, errorMapLen) + } + select { + case <-signalChan: + fmt.Fprintf(os.Stderr, "\nCaught SIGINT: aborting\n") + done = true + default: + // NOP: continue + } + if done { + break + } + } + + assemblerMutex.Lock() + closed := assembler.FlushAll() + assemblerMutex.Unlock() + Debug("Final flush: %d closed", closed) + if outputLevel >= 2 { + streamPool.Dump() + } + + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal(err) + } + pprof.WriteHeapProfile(f) + f.Close() + } + + streamFactory.WaitGoRoutines() + assemblerMutex.Lock() + Debug("%s\n", assembler.Dump()) + assemblerMutex.Unlock() + if !*nodefrag { + 
fmt.Printf("IPdefrag:\t\t%d\n", stats.ipdefrag) + } + fmt.Printf("TCP stats:\n") + fmt.Printf(" missed bytes:\t\t%d\n", stats.missedBytes) + fmt.Printf(" total packets:\t\t%d\n", stats.pkt) + fmt.Printf(" rejected FSM:\t\t%d\n", stats.rejectFsm) + fmt.Printf(" rejected Options:\t%d\n", stats.rejectOpt) + fmt.Printf(" reassembled bytes:\t%d\n", stats.sz) + fmt.Printf(" total TCP bytes:\t%d\n", stats.totalsz) + fmt.Printf(" conn rejected FSM:\t%d\n", stats.rejectConnFsm) + fmt.Printf(" reassembled chunks:\t%d\n", stats.reassembled) + fmt.Printf(" out-of-order packets:\t%d\n", stats.outOfOrderPackets) + fmt.Printf(" out-of-order bytes:\t%d\n", stats.outOfOrderBytes) + fmt.Printf(" biggest-chunk packets:\t%d\n", stats.biggestChunkPackets) + fmt.Printf(" biggest-chunk bytes:\t%d\n", stats.biggestChunkBytes) + fmt.Printf(" overlap packets:\t%d\n", stats.overlapPackets) + fmt.Printf(" overlap bytes:\t\t%d\n", stats.overlapBytes) + fmt.Printf("Errors: %d\n", nErrors) + for e, _ := range errorsMap { + fmt.Printf(" %s:\t\t%d\n", e, errorsMap[e]) + } +} diff --git a/tap/src/stats_tracker.go b/tap/src/stats_tracker.go new file mode 100644 index 000000000..b362b4bb4 --- /dev/null +++ b/tap/src/stats_tracker.go @@ -0,0 +1,35 @@ +package main + +import ( + "sync" +) + +type AppStats struct { + matchedMessages int +} + +type StatsTracker struct { + stats AppStats + statsMutex sync.Mutex +} + +func (st *StatsTracker) incMatchedMessages() { + st.statsMutex.Lock() + st.stats.matchedMessages++ + st.statsMutex.Unlock() +} + +func (st *StatsTracker) dumpStats() AppStats { + st.statsMutex.Lock() + + stats := AppStats{ + matchedMessages: st.stats.matchedMessages, + } + + st.stats.matchedMessages = 0 + + st.statsMutex.Unlock() + + return stats +} + diff --git a/tap/src/tap_output.go b/tap/src/tap_output.go new file mode 100644 index 000000000..f96a44330 --- /dev/null +++ b/tap/src/tap_output.go @@ -0,0 +1,241 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "log" 
+ "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/patrickmn/go-cache" +) + + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +var ( + newline = []byte{'\n'} + space = []byte{' '} + hub *Hub + outboundSocketNotifyExpiringCache = cache.New(outboundThrottleCacheExpiryPeriod, outboundThrottleCacheExpiryPeriod) +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func (_ *http.Request) bool { return true }, +} + +// Client is a middleman between the websocket connection and the hub. +type Client struct { + hub *Hub + + // The websocket connection. + conn *websocket.Conn + + // Buffered channel of outbound messages. + send chan []byte +} + +type OutBoundLinkMessage struct { + SourceIP string `json:"sourceIP"` + IP string `json:"ip"` + Port int `json:"port"` + Type string `json:"type"` +} + + +// readPump pumps messages from the websocket connection to the hub. +// +// The application runs readPump in a per-connection goroutine. The application +// ensures that there is at most one reader on a connection by executing all +// reads from this goroutine. 
+func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + break + } + message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + c.hub.onMessageCallback(message) + } +} + +// writePump pumps messages from the hub to the websocket connection. +// +// A goroutine running writePump is started for each connection. The +// application ensures that there is at most one writer to a connection by +// executing all writes from this goroutine. +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +type Hub struct { + // Registered clients. + clients map[*Client]bool + + // Inbound messages from the clients. + broadcast chan []byte + + // Register requests from the clients. + register chan *Client + + // Unregister requests from clients. 
+ unregister chan *Client + + // Handle messages from client + onMessageCallback func([]byte) + + +} + +func newHub(onMessageCallback func([]byte)) *Hub { + return &Hub{ + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[*Client]bool), + onMessageCallback: onMessageCallback, + } +} + +func (h *Hub) run() { + for { + select { + case client := <-h.register: + h.clients[client] = true + case client := <-h.unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } + case message := <-h.broadcast: + // matched messages counter is incremented in this thread instead of in multiple http reader + // threads in order to reduce contention. + statsTracker.incMatchedMessages() + + for client := range h.clients { + select { + case client.send <- message: + default: + close(client.send) + delete(h.clients, client) + } + } + } + } +} + + +// serveWs handles websocket requests from the peer. +func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client.hub.register <- client + + // Allow collection of memory referenced by the caller by doing all work in + // new goroutines. 
+ go client.writePump() + go client.readPump() +} + +func startOutputServer(port string, messageCallback func([]byte)) { + flag.Parse() + hub = newHub(messageCallback) + go hub.run() + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + serveWs(hub, w, r) + }) + err := http.ListenAndServe("0.0.0.0:" + port, nil) + if err != nil { + log.Fatal("Output server error: ", err) + } +} + +func broadcastReqResPair(reqResJson []byte) { + hub.broadcast <- reqResJson +} + +func broadcastOutboundLink(srcIP string, dstIP string, dstPort int) { + cacheKey := fmt.Sprintf("%s -> %s:%d", srcIP, dstIP, dstPort) + _, isInCache := outboundSocketNotifyExpiringCache.Get(cacheKey) + if isInCache { + return + } else { + outboundSocketNotifyExpiringCache.SetDefault(cacheKey, true) + } + + socketMessage := OutBoundLinkMessage{ + SourceIP: srcIP, + IP: dstIP, + Port: dstPort, + Type: "outboundSocketDetected", + } + + jsonStr, err := json.Marshal(socketMessage) + if err != nil { + log.Printf("error marshalling outbound socket detection object: %v", err) + } else { + hub.broadcast <- jsonStr + } +} diff --git a/tap/src/tcp_stream.go b/tap/src/tcp_stream.go new file mode 100644 index 000000000..2150f2b9e --- /dev/null +++ b/tap/src/tcp_stream.go @@ -0,0 +1,168 @@ +package main + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "sync" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" // pulls in all layers decoders + "github.com/google/gopacket/reassembly" +) + +/* It's a connection (bidirectional) + * Implements gopacket.reassembly.Stream interface (Accept, ReassembledSG, ReassemblyComplete) + * ReassembledSG gets called when new reassembled data is ready (i.e. bytes in order, no duplicates, complete) + * In our implementation, we pass information from ReassembledSG to the httpReader through a shared channel. 
+ */ +type tcpStream struct { + tcpstate *reassembly.TCPSimpleFSM + fsmerr bool + optchecker reassembly.TCPOptionCheck + net, transport gopacket.Flow + isDNS bool + isHTTP bool + reversed bool + client httpReader + server httpReader + urls []string + ident string + sync.Mutex +} + +func (t *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { + // FSM + if !t.tcpstate.CheckState(tcp, dir) { + //SilentError("FSM", "%s: Packet rejected by FSM (state:%s)\n", t.ident, t.tcpstate.String()) + stats.rejectFsm++ + if !t.fsmerr { + t.fsmerr = true + stats.rejectConnFsm++ + } + if !*ignorefsmerr { + return false + } + } + // Options + err := t.optchecker.Accept(tcp, ci, dir, nextSeq, start) + if err != nil { + //SilentError("OptionChecker", "%s: Packet rejected by OptionChecker: %s\n", t.ident, err) + stats.rejectOpt++ + if !*nooptcheck { + return false + } + } + // Checksum + accept := true + if *checksum { + c, err := tcp.ComputeChecksum() + if err != nil { + SilentError("ChecksumCompute", "%s: Got error computing checksum: %s\n", t.ident, err) + accept = false + } else if c != 0x0 { + SilentError("Checksum", "%s: Invalid checksum: 0x%x\n", t.ident, c) + accept = false + } + } + if !accept { + stats.rejectOpt++ + } + return accept +} + +func (t *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { + dir, start, end, skip := sg.Info() + length, saved := sg.Lengths() + // update stats + sgStats := sg.Stats() + if skip > 0 { + stats.missedBytes += skip + } + stats.sz += length - saved + stats.pkt += sgStats.Packets + if sgStats.Chunks > 1 { + stats.reassembled++ + } + stats.outOfOrderPackets += sgStats.QueuedPackets + stats.outOfOrderBytes += sgStats.QueuedBytes + if length > stats.biggestChunkBytes { + stats.biggestChunkBytes = length + } + if sgStats.Packets > stats.biggestChunkPackets { + stats.biggestChunkPackets = 
sgStats.Packets + } + if sgStats.OverlapBytes != 0 && sgStats.OverlapPackets == 0 { + // In the original example this was handled with panic(). + // I don't know what this error means or how to handle it properly. + SilentError("Invalid-Overlap", "bytes:%d, pkts:%d\n", sgStats.OverlapBytes, sgStats.OverlapPackets) + } + stats.overlapBytes += sgStats.OverlapBytes + stats.overlapPackets += sgStats.OverlapPackets + + var ident string + if dir == reassembly.TCPDirClientToServer { + ident = fmt.Sprintf("%v %v(%s): ", t.net, t.transport, dir) + } else { + ident = fmt.Sprintf("%v %v(%s): ", t.net.Reverse(), t.transport.Reverse(), dir) + } + Debug("%s: SG reassembled packet with %d bytes (start:%v,end:%v,skip:%d,saved:%d,nb:%d,%d,overlap:%d,%d)\n", ident, length, start, end, skip, saved, sgStats.Packets, sgStats.Chunks, sgStats.OverlapBytes, sgStats.OverlapPackets) + if skip == -1 && *allowmissinginit { + // this is allowed + } else if skip != 0 { + // Missing bytes in stream: do not even try to parse it + return + } + data := sg.Fetch(length) + if t.isDNS { + dns := &layers.DNS{} + var decoded []gopacket.LayerType + if len(data) < 2 { + if len(data) > 0 { + sg.KeepFrom(0) + } + return + } + dnsSize := binary.BigEndian.Uint16(data[:2]) + missing := int(dnsSize) - len(data[2:]) + Debug("dnsSize: %d, missing: %d\n", dnsSize, missing) + if missing > 0 { + Info("Missing some bytes: %d\n", missing) + sg.KeepFrom(0) + return + } + p := gopacket.NewDecodingLayerParser(layers.LayerTypeDNS, dns) + err := p.DecodeLayers(data[2:], &decoded) + if err != nil { + SilentError("DNS-parser", "Failed to decode DNS: %v\n", err) + } else { + Debug("DNS: %s\n", gopacket.LayerDump(dns)) + } + if len(data) > 2+int(dnsSize) { + sg.KeepFrom(2 + int(dnsSize)) + } + } else if t.isHTTP { + if length > 0 { + if *hexdump { + Debug("Feeding http with:\n%s", hex.Dump(data)) + } + // This is where we pass the reassembled information onwards + // This channel is read by an httpReader object + if dir == 
reassembly.TCPDirClientToServer && !t.reversed { + t.client.msgQueue <- httpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp} + } else { + t.server.msgQueue <- httpReaderDataMsg{data, ac.GetCaptureInfo().Timestamp} + } + } + } +} + +func (t *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { + Debug("%s: Connection closed\n", t.ident) + if t.isHTTP { + close(t.client.msgQueue) + close(t.server.msgQueue) + } + // do not remove the connection to allow last ACK + return false +} diff --git a/tap/src/tcp_stream_factory.go b/tap/src/tcp_stream_factory.go new file mode 100644 index 000000000..bf69f2622 --- /dev/null +++ b/tap/src/tcp_stream_factory.go @@ -0,0 +1,112 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" // pulls in all layers decoders + "github.com/google/gopacket/reassembly" +) + +/* + * The TCP factory: returns a new Stream + * Implements gopacket.reassembly.StreamFactory interface (New) + * Generates a new tcp stream for each new tcp connection. Closes the stream when the connection closes. 
+ */ +type tcpStreamFactory struct { + wg sync.WaitGroup + doHTTP bool + harWriter *HarWriter +} + +func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { + Debug("* NEW: %s %s\n", net, transport) + fsmOptions := reassembly.TCPSimpleFSMOptions{ + SupportMissingEstablishment: *allowmissinginit, + } + Debug("Current App Ports: %v\n", appPorts) + dstIp := net.Dst().String() + dstPort := int(tcp.DstPort) + + if factory.shouldNotifyOnOutboundLink(dstIp, dstPort) { + broadcastOutboundLink(net.Src().String(), dstIp, dstPort) + } + isHTTP := factory.shouldTap(dstIp, dstPort) + stream := &tcpStream{ + net: net, + transport: transport, + isDNS: tcp.SrcPort == 53 || tcp.DstPort == 53, + isHTTP: isHTTP && factory.doHTTP, + reversed: tcp.SrcPort == 80, + tcpstate: reassembly.NewTCPSimpleFSM(fsmOptions), + ident: fmt.Sprintf("%s:%s", net, transport), + optchecker: reassembly.NewTCPOptionCheck(), + } + if stream.isHTTP { + stream.client = httpReader{ + msgQueue: make(chan httpReaderDataMsg), + ident: fmt.Sprintf("%s %s", net, transport), + tcpID: tcpID{ + srcIP: net.Src().String(), + dstIP: net.Dst().String(), + srcPort: transport.Src().String(), + dstPort: transport.Dst().String(), + }, + hexdump: *hexdump, + parent: stream, + isClient: true, + harWriter: factory.harWriter, + } + stream.server = httpReader{ + msgQueue: make(chan httpReaderDataMsg), + ident: fmt.Sprintf("%s %s", net.Reverse(), transport.Reverse()), + tcpID: tcpID{ + srcIP: net.Dst().String(), + dstIP: net.Src().String(), + srcPort: transport.Dst().String(), + dstPort: transport.Src().String(), + }, + hexdump: *hexdump, + parent: stream, + harWriter: factory.harWriter, + } + factory.wg.Add(2) + // Start reading from channels stream.client.bytes and stream.server.bytes + go stream.client.run(&factory.wg) + go stream.server.run(&factory.wg) + } + return stream +} + +func (factory *tcpStreamFactory) WaitGoRoutines() { + factory.wg.Wait() 
+} + +func (factory *tcpStreamFactory) shouldTap(dstIP string, dstPort int) bool { + if hostMode { + return inArrayString(hostAppAddresses, fmt.Sprintf("%s:%d", dstIP, dstPort)) + } else { + isTappedPort := dstPort == 80 || (appPorts != nil && (inArrayInt(appPorts, dstPort))) + if !isTappedPort { + return false + } + + if !*anydirection { + isDirectedHere := inArrayString(ownIps, dstIP) + if !isDirectedHere { + return false + } + } + + return true + } +} + +func (factory *tcpStreamFactory) shouldNotifyOnOutboundLink(dstIP string, dstPort int) bool { + if inArrayInt(remoteOnlyOutboundPorts, dstPort) { + isDirectedHere := inArrayString(ownIps, dstIP) + return !isDirectedHere && !isPrivateIP(dstIP) + } + return true +}