diff --git a/tap/api/api.go b/tap/api/api.go index fd2a7ce52..586ff592b 100644 --- a/tap/api/api.go +++ b/tap/api/api.go @@ -118,6 +118,11 @@ func (p *ReadProgress) Current() (n int) { return p.lastCurrent } +func (p *ReadProgress) Reset() { + p.readBytes = 0 + p.lastCurrent = 0 +} + type Dissector interface { Register(*Extension) Ping() diff --git a/tap/tcp_reader.go b/tap/tcp_reader.go index 4f6205973..d7d6ff896 100644 --- a/tap/tcp_reader.go +++ b/tap/tcp_reader.go @@ -15,21 +15,23 @@ import ( * Implements io.Reader interface (Read) */ type tcpReader struct { - ident string - tcpID *api.TcpID - isClosed bool - isClient bool - isOutgoing bool - msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload - pastData []byte - data []byte - progress *api.ReadProgress - captureTime time.Time - parent *tcpStream - packetsSeen uint - emitter api.Emitter - counterPair *api.CounterPair - reqResMatcher api.RequestResponseMatcher + ident string + tcpID *api.TcpID + isClosed bool + isClient bool + isOutgoing bool + msgQueue chan api.TcpReaderDataMsg // Channel of captured reassembled tcp payload + msgBuffer []api.TcpReaderDataMsg + msgBufferMaster []api.TcpReaderDataMsg + exhaustBuffer bool + data []byte + progress *api.ReadProgress + captureTime time.Time + parent *tcpStream + packetsSeen uint + emitter api.Emitter + counterPair *api.CounterPair + reqResMatcher api.RequestResponseMatcher sync.Mutex } @@ -83,13 +85,40 @@ func (reader *tcpReader) isProtocolIdentified() bool { } func (reader *tcpReader) rewind() { - reader.data = make([]byte, len(reader.pastData)) - copy(reader.data, reader.pastData) + // Tell Read to exhaust the msgBuffer + reader.exhaustBuffer = true + + // Reset the data and msgBuffer from the master record + reader.data = []byte{} + reader.msgBuffer = make([]api.TcpReaderDataMsg, len(reader.msgBufferMaster)) + copy(reader.msgBuffer, reader.msgBufferMaster) + + // Reset the read progress + reader.progress.Reset() } func (reader *tcpReader) Read(p []byte) (int, error) { var msg api.TcpReaderDataMsg + if reader.exhaustBuffer && len(reader.data) == 0 { + if len(reader.msgBuffer) > 0 { + // Pop first message + msg, reader.msgBuffer = reader.msgBuffer[0], reader.msgBuffer[1:] + + // Get the bytes + reader.data = msg.GetBytes() + reader.captureTime = msg.GetTimestamp() + + // Set exhaustBuffer to false if we exhaust the msgBuffer + if len(reader.msgBuffer) == 0 { + reader.exhaustBuffer = false + } + } else { + // Buffer is empty + reader.exhaustBuffer = false + } + } + ok := true for ok && len(reader.data) == 0 { msg, ok = <-reader.msgQueue @@ -97,7 +126,7 @@ func (reader *tcpReader) Read(p []byte) (int, error) { reader.data = msg.GetBytes() reader.captureTime = msg.GetTimestamp() if !reader.isProtocolIdentified() { - reader.pastData = append(reader.pastData, reader.data...) + reader.msgBufferMaster = append(reader.msgBufferMaster, msg) } } diff --git a/tap/tcp_stream.go b/tap/tcp_stream.go index 8a30441f7..35c25f5ad 100644 --- a/tap/tcp_stream.go +++ b/tap/tcp_stream.go @@ -69,8 +69,12 @@ func (t *tcpStream) addReqResMatcher(reqResMatcher api.RequestResponseMatcher) { func (t *tcpStream) SetProtocol(protocol *api.Protocol) { t.protocol = protocol - t.client.pastData = []byte{} - t.server.pastData = []byte{} + + // Clean the buffers + t.client.msgBuffer = []api.TcpReaderDataMsg{} + t.client.msgBufferMaster = []api.TcpReaderDataMsg{} + t.server.msgBuffer = []api.TcpReaderDataMsg{} + t.server.msgBufferMaster = []api.TcpReaderDataMsg{} } func (t *tcpStream) GetOrigin() api.Capture {