* some debugging and fixing in the proxy (still needs some work)

* added a logger and stringers for message types
This commit is contained in:
betzalel 2017-07-11 16:50:06 +03:00
parent 092f92264a
commit 18bef62b79
28 changed files with 1340 additions and 726 deletions

File diff suppressed because it is too large Load Diff

25
.vscode/launch.json vendored
View File

@ -2,7 +2,7 @@
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Launch Test", "name": "Launch Server Test",
"type": "go", "type": "go",
"request": "launch", "request": "launch",
"mode": "test", "mode": "test",
@ -24,6 +24,29 @@
}, },
"showLog": true "showLog": true
}, },
{
"name": "Launch Proxy Test",
"type": "go",
"request": "launch",
"mode": "test",
"remotePath": "",
"port": 2345,
"program": "${workspaceRoot}/proxy",
"args": [
"-test.v"
],
"osx": {
"env": {
//"GOPATH": "/Users/amitbet/Dropbox/go"
}
},
"windows": {
"env": {
//"GOPATH": "${env.USERPROFILE}\\Dropbox\\go"
}
},
"showLog": true
},
{ {
"name": "Launch", "name": "Launch",
"type": "go", "type": "go",

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"unicode" "unicode"
"vncproxy/common" "vncproxy/common"
"vncproxy/logger"
) )
// A ServerMessage implements a message sent from the server to the client. // A ServerMessage implements a message sent from the server to the client.
@ -52,7 +53,7 @@ type ClientConn struct {
// SetPixelFormat method. // SetPixelFormat method.
PixelFormat common.PixelFormat PixelFormat common.PixelFormat
Listener common.SegmentConsumer Listeners *common.MultiListener
} }
// A ClientConfig structure is used to configure a ClientConn. After // A ClientConfig structure is used to configure a ClientConn. After
@ -80,20 +81,26 @@ type ClientConfig struct {
ServerMessages []common.ServerMessage ServerMessages []common.ServerMessage
} }
func Client(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
conn := &ClientConn{ conn := &ClientConn{
conn: c, conn: c,
config: cfg, config: cfg,
Listeners: &common.MultiListener{},
} }
return conn, nil
}
func (conn *ClientConn) Connect() error {
if err := conn.handshake(); err != nil { if err := conn.handshake(); err != nil {
logger.Errorf("ClientConn.Connect error: %v", err)
conn.Close() conn.Close()
return nil, err return err
} }
go conn.mainLoop() go conn.mainLoop()
return conn, nil return nil
} }
func (c *ClientConn) Close() error { func (c *ClientConn) Close() error {
@ -360,11 +367,12 @@ func (c *ClientConn) handshake() error {
// 7.1.2 Security Handshake from server // 7.1.2 Security Handshake from server
var numSecurityTypes uint8 var numSecurityTypes uint8
if err = binary.Read(c.conn, binary.BigEndian, &numSecurityTypes); err != nil { if err = binary.Read(c.conn, binary.BigEndian, &numSecurityTypes); err != nil {
return fmt.Errorf("Error reading security types: %v", err)
return err return err
} }
if numSecurityTypes == 0 { if numSecurityTypes == 0 {
return fmt.Errorf("no security types: %s", c.readErrorReason()) return fmt.Errorf("Error: no security types: %s", c.readErrorReason())
} }
securityTypes := make([]uint8, numSecurityTypes) securityTypes := make([]uint8, numSecurityTypes)
@ -455,9 +463,8 @@ FindAuth:
PixelFormat: c.PixelFormat, PixelFormat: c.PixelFormat,
} }
rfbSeg := &common.RfbSegment{SegmentType: common.SegmentServerInitMessage, Message: &srvInit} rfbSeg := &common.RfbSegment{SegmentType: common.SegmentServerInitMessage, Message: &srvInit}
c.Listener.Consume(rfbSeg)
return nil return c.Listeners.Consume(rfbSeg)
} }
// mainLoop reads messages sent from the server and routes them to the // mainLoop reads messages sent from the server and routes them to the
@ -465,7 +472,7 @@ FindAuth:
func (c *ClientConn) mainLoop() { func (c *ClientConn) mainLoop() {
defer c.Close() defer c.Close()
reader := &common.RfbReadHelper{Reader: c.conn, Listener: c.Listener} reader := &common.RfbReadHelper{Reader: c.conn, Listeners: c.Listeners}
// Build the map of available server messages // Build the map of available server messages
typeMap := make(map[uint8]common.ServerMessage) typeMap := make(map[uint8]common.ServerMessage)
@ -486,6 +493,10 @@ func (c *ClientConn) mainLoop() {
} }
} }
defer func(){
logger.Warn("ClientConn.MainLoop: exiting!")
}()
for { for {
var messageType uint8 var messageType uint8
if err := binary.Read(c.conn, binary.BigEndian, &messageType); err != nil { if err := binary.Read(c.conn, binary.BigEndian, &messageType); err != nil {
@ -497,6 +508,7 @@ func (c *ClientConn) mainLoop() {
// Unsupported message type! Bad! // Unsupported message type! Bad!
break break
} }
logger.Debugf("ClientConn.MainLoop: got ServerMessage:%s", common.ServerMessageType(messageType))
reader.SendMessageSeparator(common.ServerMessageType(messageType)) reader.SendMessageSeparator(common.ServerMessageType(messageType))
reader.PublishBytes([]byte{byte(messageType)}) reader.PublishBytes([]byte{byte(messageType)})
@ -504,6 +516,7 @@ func (c *ClientConn) mainLoop() {
if err != nil { if err != nil {
break break
} }
logger.Debugf("ClientConn.MainLoop: read & parsed ServerMessage:%s, %v", parsedMsg.Type(), parsedMsg)
if c.config.ServerMessageCh == nil { if c.config.ServerMessageCh == nil {
continue continue

View File

@ -29,7 +29,7 @@ func readPixelFormat(r io.Reader, result *common.PixelFormat) error {
if pfBoolByte != 0 { if pfBoolByte != 0 {
// Big endian is true // Big endian is true
result.BigEndian = true result.BigEndian = 1
} }
if err := binary.Read(brPF, binary.BigEndian, &pfBoolByte); err != nil { if err := binary.Read(brPF, binary.BigEndian, &pfBoolByte); err != nil {
@ -38,7 +38,7 @@ func readPixelFormat(r io.Reader, result *common.PixelFormat) error {
if pfBoolByte != 0 { if pfBoolByte != 0 {
// True Color is true. So we also have to read all the color max & shifts. // True Color is true. So we also have to read all the color max & shifts.
result.TrueColor = true result.TrueColor = 1
if err := binary.Read(brPF, binary.BigEndian, &result.RedMax); err != nil { if err := binary.Read(brPF, binary.BigEndian, &result.RedMax); err != nil {
return err return err
@ -82,7 +82,7 @@ func writePixelFormat(format *common.PixelFormat) ([]byte, error) {
} }
var boolByte byte var boolByte byte
if format.BigEndian { if format.BigEndian == 1 {
boolByte = 1 boolByte = 1
} else { } else {
boolByte = 0 boolByte = 0
@ -93,7 +93,7 @@ func writePixelFormat(format *common.PixelFormat) ([]byte, error) {
return nil, err return nil, err
} }
if format.TrueColor { if format.TrueColor == 1 {
boolByte = 1 boolByte = 1
} else { } else {
boolByte = 0 boolByte = 0
@ -106,7 +106,7 @@ func writePixelFormat(format *common.PixelFormat) ([]byte, error) {
// If we have true color enabled then we have to fill in the rest of the // If we have true color enabled then we have to fill in the rest of the
// structure with the color values. // structure with the color values.
if format.TrueColor { if format.TrueColor == 1 {
if err := binary.Write(&buf, binary.BigEndian, format.RedMax); err != nil { if err := binary.Write(&buf, binary.BigEndian, format.RedMax); err != nil {
return nil, err return nil, err
} }

View File

@ -7,6 +7,8 @@ import (
"io" "io"
"vncproxy/common" "vncproxy/common"
"vncproxy/encodings" "vncproxy/encodings"
"vncproxy/logger"
"strings"
) )
// FramebufferUpdateMessage consists of a sequence of rectangles of // FramebufferUpdateMessage consists of a sequence of rectangles of
@ -49,13 +51,13 @@ func (fbm *FramebufferUpdateMessage) Read(c common.IClientConn, r *common.RfbRea
// We must always support the raw encoding // We must always support the raw encoding
rawEnc := new(encodings.RawEncoding) rawEnc := new(encodings.RawEncoding)
encMap[rawEnc.Type()] = rawEnc encMap[rawEnc.Type()] = rawEnc
fmt.Printf("numrects= %d\n", numRects) logger.Debugf("numrects= %d", numRects)
rects := make([]common.Rectangle, numRects) rects := make([]common.Rectangle, numRects)
for i := uint16(0); i < numRects; i++ { for i := uint16(0); i < numRects; i++ {
fmt.Printf("###############rect################: %d\n", i) logger.Debugf("###############rect################: %d\n", i)
var encodingType int32 var encodingTypeInt int32
r.SendRectSeparator(-1) r.SendRectSeparator(-1)
rect := &rects[i] rect := &rects[i]
data := []interface{}{ data := []interface{}{
@ -63,34 +65,45 @@ func (fbm *FramebufferUpdateMessage) Read(c common.IClientConn, r *common.RfbRea
&rect.Y, &rect.Y,
&rect.Width, &rect.Width,
&rect.Height, &rect.Height,
&encodingType, &encodingTypeInt,
} }
for _, val := range data { for _, val := range data {
if err := binary.Read(r, binary.BigEndian, val); err != nil { if err := binary.Read(r, binary.BigEndian, val); err != nil {
fmt.Printf("err: %v\n", err) logger.Errorf("err: %v", err)
return nil, err return nil, err
} }
} }
jBytes, _ := json.Marshal(data) jBytes, _ := json.Marshal(data)
fmt.Printf("rect hdr data: %s\n", string(jBytes)) encType := common.EncodingType(encodingTypeInt)
//fmt.Printf(" encoding type: %d", encodingType)
enc, ok := encMap[encodingType]
if !ok {
return nil, fmt.Errorf("unsupported encoding type: %d\n", encodingType)
}
logger.Debugf("rect hdr data: enctype=%s, data: %s\n", encType, string(jBytes))
enc, supported := encMap[encodingTypeInt]
if supported {
var err error var err error
rect.Enc, err = enc.Read(c.CurrentPixelFormat(), rect, r) rect.Enc, err = enc.Read(c.CurrentPixelFormat(), rect, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
if strings.Contains(encType.String(), "Pseudo") {
rect.Enc = &encodings.PseudoEncoding{encodingTypeInt}
} else {
return nil, fmt.Errorf("unsupported encoding type: %d, %s", encodingTypeInt, encType)
}
}
} }
return &FramebufferUpdateMessage{rects}, nil return &FramebufferUpdateMessage{rects}, nil
} }
// SetColorMapEntriesMessage is sent by the server to set values into // SetColorMapEntriesMessage is sent by the server to set values into
// the color map. This message will automatically update the color map // the color map. This message will automatically update the color map
// for the associated connection, but contains the color change data // for the associated connection, but contains the color change data

View File

@ -1,14 +0,0 @@
package common
type Logger interface {
Debug(v ...interface{})
Debugf(format string, v ...interface{})
Info(v ...interface{})
Infof(format string, v ...interface{})
Warn(v ...interface{})
Warnf(format string, v ...interface{})
Error(v ...interface{})
Errorf(format string, v ...interface{})
Fatal(v ...interface{})
Fatalf(format string, v ...interface{})
}

View File

@ -1,6 +1,9 @@
package common package common
import "io" import (
"io"
)
type ClientMessageType uint8 type ClientMessageType uint8
@ -54,3 +57,21 @@ type ClientMessage interface {
Read(io.Reader) (ClientMessage, error) Read(io.Reader) (ClientMessage, error)
Write(io.Writer) error Write(io.Writer) error
} }
func (cmt ClientMessageType) String() string {
switch cmt {
case SetPixelFormatMsgType:
return "SetPixelFormatMsgType"
case SetEncodingsMsgType:
return "SetEncodingsMsgType"
case FramebufferUpdateRequestMsgType:
return "FramebufferUpdateRequestMsgType"
case KeyEventMsgType:
return "KeyEventMsgType"
case PointerEventMsgType:
return "PointerEventMsgType"
case ClientCutTextMsgType:
return "ClientCutTextMsgType"
}
return ""
}

View File

@ -21,6 +21,124 @@ type Encoding interface {
// EncodingType represents a known VNC encoding type. // EncodingType represents a known VNC encoding type.
type EncodingType int32 type EncodingType int32
func (enct EncodingType) String() string {
switch enct {
case EncRaw:
return "EncRaw"
case EncCopyRect:
return "EncCopyRect"
case EncRRE:
return "EncRRE"
case EncCoRRE:
return "EncCoRRE"
case EncHextile:
return "EncHextile"
case EncZlib:
return "EncZlib"
case EncTight:
return "EncTight"
case EncZlibHex:
return "EncZlibHex"
case EncUltra1:
return "EncUltra1"
case EncUltra2:
return "EncUltra2"
case EncJPEG:
return "EncJPEG"
case EncJRLE:
return "EncJRLE"
case EncTRLE:
return "EncTRLE"
case EncZRLE:
return "EncZRLE"
case EncJPEGQualityLevelPseudo10:
return "EncJPEGQualityLevelPseudo10"
case EncJPEGQualityLevelPseudo9:
return "EncJPEGQualityLevelPseudo9"
case EncJPEGQualityLevelPseudo8:
return "EncJPEGQualityLevelPseudo8"
case EncJPEGQualityLevelPseudo7:
return "EncJPEGQualityLevelPseudo7"
case EncJPEGQualityLevelPseudo6:
return "EncJPEGQualityLevelPseudo6"
case EncJPEGQualityLevelPseudo5:
return "EncJPEGQualityLevelPseudo5"
case EncJPEGQualityLevelPseudo4:
return "EncJPEGQualityLevelPseudo4"
case EncJPEGQualityLevelPseudo3:
return "EncJPEGQualityLevelPseudo3"
case EncJPEGQualityLevelPseudo2:
return "EncJPEGQualityLevelPseudo2"
case EncJPEGQualityLevelPseudo1:
return "EncJPEGQualityLevelPseudo1"
case EncColorPseudo:
return "EncColorPseudo"
case EncDesktopSizePseudo:
return "EncDesktopSizePseudo"
case EncLastRectPseudo:
return "EncLastRectPseudo"
case EncCompressionLevel10:
return "EncCompressionLevel10"
case EncCompressionLevel9:
return "EncCompressionLevel9"
case EncCompressionLevel8:
return "EncCompressionLevel8"
case EncCompressionLevel7:
return "EncCompressionLevel7"
case EncCompressionLevel6:
return "EncCompressionLevel6"
case EncCompressionLevel5:
return "EncCompressionLevel5"
case EncCompressionLevel4:
return "EncCompressionLevel4"
case EncCompressionLevel3:
return "EncCompressionLevel3"
case EncCompressionLevel2:
return "EncCompressionLevel2"
case EncCompressionLevel1:
return "EncCompressionLevel1"
case EncQEMUPointerMotionChangePseudo:
return "EncQEMUPointerMotionChangePseudo"
case EncQEMUExtendedKeyEventPseudo:
return "EncQEMUExtendedKeyEventPseudo"
case EncTightPng:
return "EncTightPng"
case EncExtendedDesktopSizePseudo:
return "EncExtendedDesktopSizePseudo"
case EncXvpPseudo:
return "EncXvpPseudo"
case EncFencePseudo:
return "EncFencePseudo"
case EncContinuousUpdatesPseudo:
return "EncContinuousUpdatesPseudo"
case EncClientRedirect:
return "EncClientRedirect"
case EncTightPNGBase64:
return "EncTightPNGBase64"
case EncTightDiffComp:
return "EncTightDiffComp"
case EncVMWDefineCursor:
return "EncVMWDefineCursor"
case EncVMWCursorState:
return "EncVMWCursorState"
case EncVMWCursorPosition:
return "EncVMWCursorPosition"
case EncVMWTypematicInfo:
return "EncVMWTypematicInfo"
case EncVMWLEDState:
return "EncVMWLEDState"
case EncVMWServerPush2:
return "EncVMWServerPush2"
case EncVMWServerCaps:
return "EncVMWServerCaps"
case EncVMWFrameStamp:
return "EncVMWFrameStamp"
case EncOffscreenCopyRect:
return "EncOffscreenCopyRect"
}
return ""
}
const ( const (
EncRaw EncodingType = 0 EncRaw EncodingType = 0
EncCopyRect EncodingType = 1 EncCopyRect EncodingType = 1
@ -86,8 +204,8 @@ const (
type PixelFormat struct { type PixelFormat struct {
BPP uint8 BPP uint8
Depth uint8 Depth uint8
BigEndian bool BigEndian uint8
TrueColor bool TrueColor uint8
RedMax uint16 RedMax uint16
GreenMax uint16 GreenMax uint16
BlueMax uint16 BlueMax uint16
@ -110,7 +228,7 @@ func (format *PixelFormat) WriteTo(w io.Writer) error {
} }
var boolByte byte var boolByte byte
if format.BigEndian { if format.BigEndian == 1 {
boolByte = 1 boolByte = 1
} else { } else {
boolByte = 0 boolByte = 0
@ -121,7 +239,7 @@ func (format *PixelFormat) WriteTo(w io.Writer) error {
return err return err
} }
if format.TrueColor { if format.TrueColor == 1 {
boolByte = 1 boolByte = 1
} else { } else {
boolByte = 0 boolByte = 0
@ -134,7 +252,7 @@ func (format *PixelFormat) WriteTo(w io.Writer) error {
// If we have true color enabled then we have to fill in the rest of the // If we have true color enabled then we have to fill in the rest of the
// structure with the color values. // structure with the color values.
if format.TrueColor { if format.TrueColor == 1 {
if err := binary.Write(&buf, binary.BigEndian, format.RedMax); err != nil { if err := binary.Write(&buf, binary.BigEndian, format.RedMax); err != nil {
return err return err
} }
@ -165,19 +283,19 @@ func (format *PixelFormat) WriteTo(w io.Writer) error {
} }
func NewPixelFormat(bpp uint8) *PixelFormat { func NewPixelFormat(bpp uint8) *PixelFormat {
bigEndian := false bigEndian := 0
// rgbMax := uint16(math.Exp2(float64(bpp))) - 1 // rgbMax := uint16(math.Exp2(float64(bpp))) - 1
rMax := uint16(255) rMax := uint16(255)
gMax := uint16(255) gMax := uint16(255)
bMax := uint16(255) bMax := uint16(255)
var ( var (
tc = true tc = 1
rs, gs, bs uint8 rs, gs, bs uint8
depth uint8 depth uint8
) )
switch bpp { switch bpp {
case 8: case 8:
tc = false tc = 0
depth = 8 depth = 8
rs, gs, bs = 0, 0, 0 rs, gs, bs = 0, 0, 0
case 16: case 16:
@ -189,5 +307,5 @@ func NewPixelFormat(bpp uint8) *PixelFormat {
rs, gs, bs = 16, 8, 0 rs, gs, bs = 16, 8, 0
} }
return &PixelFormat{bpp, depth, bigEndian, tc, rMax, gMax, bMax, rs, gs, bs} return &PixelFormat{bpp, depth, uint8(bigEndian), uint8(tc), rMax, gMax, bMax, rs, gs, bs}
} }

20
common/multiListener.go Normal file
View File

@ -0,0 +1,20 @@
package common
type MultiListener struct {
listeners []SegmentConsumer
}
func (m *MultiListener) AddListener(listener SegmentConsumer) {
m.listeners = append(m.listeners, listener)
}
func (m *MultiListener) Consume(seg *RfbSegment) error {
for _, li := range m.listeners {
err := li.Consume(seg)
if err != nil {
return err
}
}
return nil
}

View File

@ -2,8 +2,8 @@ package common
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"vncproxy/logger"
) )
var TightMinToCompress = 12 var TightMinToCompress = 12
@ -19,6 +19,24 @@ const (
type SegmentType int type SegmentType int
func (seg SegmentType ) String() string {
switch seg {
case SegmentBytes:
return "SegmentBytes"
case SegmentMessageSeparator:
return "SegmentMessageSeparator"
case SegmentRectSeparator:
return "SegmentRectSeparator"
case SegmentFullyParsedClientMessage:
return "SegmentFullyParsedClientMessage"
case SegmentFullyParsedServerMessage:
return "SegmentFullyParsedServerMessage"
case SegmentServerInitMessage:
return "SegmentServerInitMessage"
}
return ""
}
type RfbSegment struct { type RfbSegment struct {
Bytes []byte Bytes []byte
SegmentType SegmentType SegmentType SegmentType
@ -32,7 +50,7 @@ type SegmentConsumer interface {
type RfbReadHelper struct { type RfbReadHelper struct {
io.Reader io.Reader
Listener SegmentConsumer Listeners *MultiListener
} }
func (r *RfbReadHelper) ReadDiscrete(p []byte) (int, error) { func (r *RfbReadHelper) ReadDiscrete(p []byte) (int, error) {
@ -41,27 +59,17 @@ func (r *RfbReadHelper) ReadDiscrete(p []byte) (int, error) {
func (r *RfbReadHelper) SendRectSeparator(upcomingRectType int) error { func (r *RfbReadHelper) SendRectSeparator(upcomingRectType int) error {
seg := &RfbSegment{SegmentType: SegmentRectSeparator, UpcomingObjectType: upcomingRectType} seg := &RfbSegment{SegmentType: SegmentRectSeparator, UpcomingObjectType: upcomingRectType}
if r.Listener != nil { return r.Listeners.Consume(seg)
return nil
}
return r.Listener.Consume(seg)
} }
func (r *RfbReadHelper) SendMessageSeparator(upcomingMessageType ServerMessageType) error { func (r *RfbReadHelper) SendMessageSeparator(upcomingMessageType ServerMessageType) error {
seg := &RfbSegment{SegmentType: SegmentMessageSeparator, UpcomingObjectType: int(upcomingMessageType)} seg := &RfbSegment{SegmentType: SegmentMessageSeparator, UpcomingObjectType: int(upcomingMessageType)}
if r.Listener == nil { return r.Listeners.Consume(seg)
return nil
}
return r.Listener.Consume(seg)
} }
func (r *RfbReadHelper) PublishBytes(p []byte) error { func (r *RfbReadHelper) PublishBytes(p []byte) error {
seg := &RfbSegment{Bytes: p, SegmentType: SegmentBytes} seg := &RfbSegment{Bytes: p, SegmentType: SegmentBytes}
if r.Listener == nil { return r.Listeners.Consume(seg)
return nil
}
return r.Listener.Consume(seg)
} }
func (r *RfbReadHelper) Read(p []byte) (n int, err error) { func (r *RfbReadHelper) Read(p []byte) (n int, err error) {
@ -71,13 +79,11 @@ func (r *RfbReadHelper) Read(p []byte) (n int, err error) {
} }
//write the bytes to the Listener for further processing //write the bytes to the Listener for further processing
seg := &RfbSegment{Bytes: p, SegmentType: SegmentBytes} seg := &RfbSegment{Bytes: p, SegmentType: SegmentBytes}
if r.Listener == nil { err = r.Listeners.Consume(seg)
return 0, nil
}
r.Listener.Consume(seg)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return readLen, err return readLen, err
} }
@ -97,7 +103,7 @@ func (r *RfbReadHelper) ReadUint8() (uint8, error) {
if err := binary.Read(r, binary.BigEndian, &myUint); err != nil { if err := binary.Read(r, binary.BigEndian, &myUint); err != nil {
return 0, err return 0, err
} }
//fmt.Printf("myUint=%d", myUint)
return myUint, nil return myUint, nil
} }
func (r *RfbReadHelper) ReadUint16() (uint16, error) { func (r *RfbReadHelper) ReadUint16() (uint16, error) {
@ -105,7 +111,7 @@ func (r *RfbReadHelper) ReadUint16() (uint16, error) {
if err := binary.Read(r, binary.BigEndian, &myUint); err != nil { if err := binary.Read(r, binary.BigEndian, &myUint); err != nil {
return 0, err return 0, err
} }
//fmt.Printf("myUint=%d", myUint)
return myUint, nil return myUint, nil
} }
func (r *RfbReadHelper) ReadUint32() (uint32, error) { func (r *RfbReadHelper) ReadUint32() (uint32, error) {
@ -113,7 +119,7 @@ func (r *RfbReadHelper) ReadUint32() (uint32, error) {
if err := binary.Read(r, binary.BigEndian, &myUint); err != nil { if err := binary.Read(r, binary.BigEndian, &myUint); err != nil {
return 0, err return 0, err
} }
//fmt.Printf("myUint=%d", myUint)
return myUint, nil return myUint, nil
} }
func (r *RfbReadHelper) ReadCompactLen() (int, error) { func (r *RfbReadHelper) ReadCompactLen() (int, error) {
@ -144,7 +150,7 @@ func (r *RfbReadHelper) ReadTightData(dataSize int) ([]byte, error) {
return r.ReadBytes(int(dataSize)) return r.ReadBytes(int(dataSize))
} }
zlibDataLen, err := r.ReadCompactLen() zlibDataLen, err := r.ReadCompactLen()
fmt.Printf("compactlen=%d\n", zlibDataLen) logger.Debugf("compactlen=%d", zlibDataLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -24,6 +24,20 @@ const (
ServerCutText ServerCutText
) )
func (typ ServerMessageType) String() string {
switch typ {
case FramebufferUpdate:
return "FramebufferUpdate"
case SetColourMapEntries:
return "SetColourMapEntries"
case Bell:
return "Bell"
case ServerCutText:
return "ServerCutText"
}
return ""
}
type ServerInit struct { type ServerInit struct {
FBWidth, FBHeight uint16 FBWidth, FBHeight uint16
PixelFormat PixelFormat PixelFormat PixelFormat

View File

@ -1,6 +1,9 @@
package encodings package encodings
import "vncproxy/common" import (
"vncproxy/common"
"vncproxy/logger"
)
const ( const (
HextileRaw = 1 HextileRaw = 1
@ -34,16 +37,16 @@ func (z *HextileEncoding) Read(pixelFmt *common.PixelFormat, rect *common.Rectan
//handle Hextile Subrect(tx, ty, tw, th): //handle Hextile Subrect(tx, ty, tw, th):
subencoding, err := r.ReadUint8() subencoding, err := r.ReadUint8()
//fmt.Printf("hextile reader tile: (%d,%d) subenc=%d\n", ty, tx, subencoding) //logger.Debugf("hextile reader tile: (%d,%d) subenc=%d\n", ty, tx, subencoding)
if err != nil { if err != nil {
//fmt.Printf("error in hextile reader: %v\n", err) logger.Errorf("HextileEncoding.Read: error in hextile reader: %v", err)
return nil, err return nil, err
} }
if (subencoding & HextileRaw) != 0 { if (subencoding & HextileRaw) != 0 {
//ReadRawRect(c, rect, r) //ReadRawRect(c, rect, r)
r.ReadBytes(tw * th * bytesPerPixel) r.ReadBytes(tw * th * bytesPerPixel)
//fmt.Printf("hextile reader: HextileRaw\n") //logger.Debug("hextile reader: HextileRaw\n")
continue continue
} }
if (subencoding & HextileBackgroundSpecified) != 0 { if (subencoding & HextileBackgroundSpecified) != 0 {
@ -53,10 +56,10 @@ func (z *HextileEncoding) Read(pixelFmt *common.PixelFormat, rect *common.Rectan
r.ReadBytes(int(bytesPerPixel)) r.ReadBytes(int(bytesPerPixel))
} }
if (subencoding & HextileAnySubrects) == 0 { if (subencoding & HextileAnySubrects) == 0 {
//fmt.Printf("hextile reader: no Subrects\n") //logger.Debug("hextile reader: no Subrects")
continue continue
} }
//fmt.Printf("hextile reader: handling Subrects\n")
nSubrects, err := r.ReadUint8() nSubrects, err := r.ReadUint8()
if err != nil { if err != nil {
return nil, err return nil, err

15
encodings/enc-pseudo.go Normal file
View File

@ -0,0 +1,15 @@
package encodings
import "vncproxy/common"
type PseudoEncoding struct {
Typ int32
}
func (pe *PseudoEncoding ) Type() int32{
return pe.Typ
}
func (pe *PseudoEncoding) Read(*common.PixelFormat, *common.Rectangle, *common.RfbReadHelper) (common.Encoding, error){
return pe, nil
}

View File

@ -21,7 +21,7 @@ const (
type TightEncoding struct { type TightEncoding struct {
//output io.Writer //output io.Writer
logger common.Logger //logger common.Logger
} }
// func (t *TightEncoding) SetOutput(output io.Writer) { // func (t *TightEncoding) SetOutput(output io.Writer) {

View File

@ -3,6 +3,7 @@ package encodings
import ( import (
"fmt" "fmt"
"vncproxy/common" "vncproxy/common"
"vncproxy/logger"
) )
type TightPngEncoding struct { type TightPngEncoding struct {
@ -16,15 +17,15 @@ func (t *TightPngEncoding) Read(pixelFmt *common.PixelFormat, rect *common.Recta
//var subencoding uint8 //var subencoding uint8
compctl, err := r.ReadUint8() compctl, err := r.ReadUint8()
if err != nil { if err != nil {
fmt.Printf("error in handling tight encoding: %v\n", err) logger.Errorf("error in handling tight encoding: %v", err)
return nil, err return nil, err
} }
fmt.Printf("bytesPixel= %d, subencoding= %d\n", bytesPixel, compctl) logger.Debugf("bytesPixel= %d, subencoding= %d", bytesPixel, compctl)
//move it to position (remove zlib flush commands) //move it to position (remove zlib flush commands)
compType := compctl >> 4 & 0x0F compType := compctl >> 4 & 0x0F
fmt.Printf("afterSHL:%d\n", compType) logger.Debugf("afterSHL:%d", compType)
switch compType { switch compType {
case TightPNG: case TightPNG:
len, err := r.ReadCompactLen() len, err := r.ReadCompactLen()

52
logger/logger.go Normal file
View File

@ -0,0 +1,52 @@
package logger
import "fmt"
type Logger interface {
Debug(v ...interface{})
Debugf(format string, v ...interface{})
Info(v ...interface{})
Infof(format string, v ...interface{})
Warn(v ...interface{})
Warnf(format string, v ...interface{})
Error(v ...interface{})
Errorf(format string, v ...interface{})
Fatal(v ...interface{})
Fatalf(format string, v ...interface{})
}
func Debug(v ...interface{}) {
fmt.Print("[Debug] ")
fmt.Println(v...)
}
func Debugf(format string, v ...interface{}) {
fmt.Printf("[Debug] "+format+"\n", v...)
}
func Info(v ...interface{}) {
fmt.Print("[Info] ")
fmt.Println(v...)
}
func Infof(format string, v ...interface{}) {
fmt.Printf("[Info] "+format+"\n", v...)
}
func Warn(v ...interface{}) {
fmt.Print("[Warn] ")
fmt.Println(v...)
}
func Warnf(format string, v ...interface{}) {
fmt.Printf("[Warn] "+format+"\n", v...)
}
func Error(v ...interface{}) {
fmt.Print("[Error] ")
fmt.Println(v...)
}
func Errorf(format string, v ...interface{}) {
fmt.Printf("[Error] "+format+"\n", v...)
}
func Fatal(v ...interface{}) {
fmt.Print("[Fatal] ")
fmt.Println(v...)
}
func Fatalf(format string, v ...interface{}) {
fmt.Printf("[Fatal] "+format+"\n", v)
}

23
main.go
View File

@ -1,22 +1,22 @@
package main package main
import ( import (
"fmt"
"net" "net"
"time" "time"
"vncproxy/client" "vncproxy/client"
"vncproxy/common" "vncproxy/common"
"vncproxy/encodings" "vncproxy/encodings"
"vncproxy/logger"
listeners "vncproxy/tee-listeners" listeners "vncproxy/tee-listeners"
) )
func main() { func main() {
//fmt.Println("")
//nc, err := net.Dial("tcp", "192.168.1.101:5903") //nc, err := net.Dial("tcp", "192.168.1.101:5903")
nc, err := net.Dial("tcp", "localhost:5903") nc, err := net.Dial("tcp", "localhost:5903")
if err != nil { if err != nil {
fmt.Printf("error connecting to vnc server: %s", err) logger.Errorf("error connecting to vnc server: %s", err)
} }
var noauth client.ClientAuthNone var noauth client.ClientAuthNone
authArr := []client.ClientAuth{&client.PasswordAuth{Password: "Ch_#!T@8"}, &noauth} authArr := []client.ClientAuth{&client.PasswordAuth{Password: "Ch_#!T@8"}, &noauth}
@ -25,23 +25,22 @@ func main() {
rec := listeners.NewRecorder("c:/Users/betzalel/recording.rbs") rec := listeners.NewRecorder("c:/Users/betzalel/recording.rbs")
split := &listeners.MultiListener{} clientConn, err := client.NewClientConn(nc,
split.AddListener(rec)
clientConn, err := client.Client(nc,
&client.ClientConfig{ &client.ClientConfig{
Auth: authArr, Auth: authArr,
ServerMessageCh: vncSrvMessagesChan, ServerMessageCh: vncSrvMessagesChan,
Exclusive: true, Exclusive: true,
}) })
clientConn.Listener = split
clientConn.Listeners.AddListener(rec)
clientConn.Connect()
if err != nil { if err != nil {
fmt.Printf("error creating client: %s", err) logger.Errorf("error creating client: %s", err)
} }
// err = clientConn.FramebufferUpdateRequest(false, 0, 0, 1024, 768) // err = clientConn.FramebufferUpdateRequest(false, 0, 0, 1024, 768)
// if err != nil { // if err != nil {
// fmt.Printf("error requesting fb update: %s\n", err) // logger.Errorf("error requesting fb update: %s", err)
// } // }
tight := encodings.TightEncoding{} tight := encodings.TightEncoding{}
@ -62,7 +61,7 @@ func main() {
for { for {
err = clientConn.FramebufferUpdateRequest(true, 0, 0, 1280, 800) err = clientConn.FramebufferUpdateRequest(true, 0, 0, 1280, 800)
if err != nil { if err != nil {
fmt.Printf("error requesting fb update: %s\n", err) logger.Errorf("error requesting fb update: %s", err)
} }
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
@ -70,7 +69,7 @@ func main() {
//go func() { //go func() {
for msg := range vncSrvMessagesChan { for msg := range vncSrvMessagesChan {
fmt.Printf("message type: %d, content: %v\n", msg.Type(), msg) logger.Debugf("message type: %d, content: %v\n", msg.Type(), msg)
} }
//}() //}()

View File

@ -1,7 +1,6 @@
package proxy package proxy
import ( import (
"fmt"
"log" "log"
"net" "net"
"path" "path"
@ -10,6 +9,7 @@ import (
"vncproxy/client" "vncproxy/client"
"vncproxy/common" "vncproxy/common"
"vncproxy/encodings" "vncproxy/encodings"
"vncproxy/logger"
"vncproxy/server" "vncproxy/server"
listeners "vncproxy/tee-listeners" listeners "vncproxy/tee-listeners"
) )
@ -20,16 +20,16 @@ type VncProxy struct {
recordingDir string // empty = no recording recordingDir string // empty = no recording
proxyPassword string // empty = no auth proxyPassword string // empty = no auth
targetServersPassword string //empty = no auth targetServersPassword string //empty = no auth
singleSession *VncSession // to be used when not using sessions SingleSession *VncSession // to be used when not using sessions
usingSessions bool //false = single session - defined in the var above UsingSessions bool //false = single session - defined in the var above
sessionManager *SessionManager sessionManager *SessionManager
} }
func (vp *VncProxy) connectToVncServer(targetServerUrl string) (*client.ClientConn, error) { func (vp *VncProxy) createClientConnection(targetServerUrl string) (*client.ClientConn, error) {
nc, err := net.Dial("tcp", targetServerUrl) nc, err := net.Dial("tcp", targetServerUrl)
if err != nil { if err != nil {
fmt.Printf("error connecting to vnc server: %s", err) logger.Errorf("error connecting to vnc server: %s", err)
return nil, err return nil, err
} }
@ -38,14 +38,7 @@ func (vp *VncProxy) connectToVncServer(targetServerUrl string) (*client.ClientCo
vncSrvMessagesChan := make(chan common.ServerMessage) vncSrvMessagesChan := make(chan common.ServerMessage)
//rec := listeners.NewRecorder("recording.rbs") clientConn, err := client.NewClientConn(nc,
// split := &listeners.MultiListener{}
// for _, listener := range rfbListeners {
// split.AddListener(listener)
// }
clientConn, err := client.Client(nc,
&client.ClientConfig{ &client.ClientConfig{
Auth: authArr, Auth: authArr,
ServerMessageCh: vncSrvMessagesChan, ServerMessageCh: vncSrvMessagesChan,
@ -54,71 +47,89 @@ func (vp *VncProxy) connectToVncServer(targetServerUrl string) (*client.ClientCo
//clientConn.Listener = split //clientConn.Listener = split
if err != nil { if err != nil {
fmt.Printf("error creating client: %s", err) logger.Errorf("error creating client: %s", err)
return nil, err return nil, err
} }
tight := encodings.TightEncoding{}
tightPng := encodings.TightPngEncoding{}
rre := encodings.RREEncoding{}
zlib := encodings.ZLibEncoding{}
zrle := encodings.ZRLEEncoding{}
cpyRect := encodings.CopyRectEncoding{}
coRRE := encodings.CoRREEncoding{}
hextile := encodings.HextileEncoding{}
clientConn.SetEncodings([]common.Encoding{&cpyRect, &tightPng, &tight, &hextile, &coRRE, &rre, &zlib, &zrle})
return clientConn, nil return clientConn, nil
} }
// if sessions not enabled, will always return the configured target server (only one) // if sessions not enabled, will always return the configured target server (only one)
func (vp *VncProxy) getTargetServerFromSession(sessionId string) (*VncSession, error) { func (vp *VncProxy) getTargetServerFromSession(sessionId string) (*VncSession, error) {
if !vp.usingSessions { if !vp.UsingSessions {
return vp.singleSession, nil if vp.SingleSession == nil {
logger.Errorf("SingleSession is empty, use sessions or populate the SingleSession member of the VncProxy struct.")
}
return vp.SingleSession, nil
} }
return vp.sessionManager.GetSession(sessionId) return vp.sessionManager.GetSession(sessionId)
} }
func (vp *VncProxy) newServerConnHandler(cfg *server.ServerConfig, sconn *server.ServerConn, rfbListeners []common.SegmentConsumer) error { func (vp *VncProxy) newServerConnHandler(cfg *server.ServerConfig, sconn *server.ServerConn) error {
recFile := "recording" + strconv.FormatInt(time.Now().Unix(), 10) + ".rbs" recFile := "recording" + strconv.FormatInt(time.Now().Unix(), 10) + ".rbs"
recPath := path.Join(vp.recordingDir, recFile) recPath := path.Join(vp.recordingDir, recFile)
rec := listeners.NewRecorder(recPath) rec := listeners.NewRecorder(recPath)
session, err := vp.getTargetServerFromSession(sconn.SessionId) session, err := vp.getTargetServerFromSession(sconn.SessionId)
if err != nil { if err != nil {
fmt.Printf("Proxy.newServerConnHandler can't get session: %d\n", sconn.SessionId) logger.Errorf("Proxy.newServerConnHandler can't get session: %d", sconn.SessionId)
return err return err
} }
serverSplitter := &listeners.MultiListener{} // for _, l := range rfbListeners {
for _, l := range rfbListeners { // sconn.Listeners.AddListener(l)
serverSplitter.AddListener(l) // }
sconn.Listeners.AddListener(rec)
//clientSplitter := &common.MultiListener{}
cconn, err := vp.createClientConnection(session.TargetHostname + ":" + session.TargetPort)
if err != nil {
logger.Errorf("Proxy.newServerConnHandler error creating connection: %s", err)
return err
} }
serverSplitter.AddListener(rec) cconn.Listeners.AddListener(rec)
sconn.Listener = serverSplitter //cconn.Listener = clientSplitter
clientSplitter := &listeners.MultiListener{}
clientSplitter.AddListener(rec)
cconn, err := vp.connectToVncServer(session.TargetHostname + ":" + session.TargetPort)
cconn.Listener = clientSplitter
//creating cross-listeners between server and client parts to pass messages through the proxy: //creating cross-listeners between server and client parts to pass messages through the proxy:
// gets the bytes from the actual vnc server on the env (client part of the proxy) // gets the bytes from the actual vnc server on the env (client part of the proxy)
// and writes them through the server socket to the vnc-client // and writes them through the server socket to the vnc-client
serverMsgRepeater := &listeners.WriteTo{sconn, "vnc-client bound"} serverMsgRepeater := &listeners.WriteTo{sconn, "vnc-client-bound"}
clientSplitter.AddListener(serverMsgRepeater) cconn.Listeners.AddListener(serverMsgRepeater)
// gets the messages from the server part (from vnc-client), // gets the messages from the server part (from vnc-client),
// and write through the client to the actual vnc-server // and write through the client to the actual vnc-server
clientMsgRepeater := &listeners.WriteTo{cconn, "vnc-server bound"} clientMsgRepeater := &listeners.WriteTo{cconn, "vnc-server-bound"}
serverSplitter.AddListener(clientMsgRepeater) sconn.Listeners.AddListener(clientMsgRepeater)
err = cconn.Connect()
if err != nil {
logger.Errorf("Proxy.newServerConnHandler error connecting to client: %s", err)
return err
}
encs := []common.Encoding{
&encodings.RawEncoding{},
&encodings.TightEncoding{},
//encodings.TightPngEncoding{},
//encodings.RREEncoding{},
//encodings.ZLibEncoding{},
//encodings.ZRLEEncoding{},
//encodings.CopyRectEncoding{},
//encodings.CoRREEncoding{},
//encodings.HextileEncoding{},
}
err = cconn.SetEncodings(encs)
if err != nil {
logger.Errorf("Proxy.newServerConnHandler error connecting to client: %s", err)
return err
}
return nil return nil
} }
func (vp *VncProxy) StartListening(rfbListeners []common.SegmentConsumer) { func (vp *VncProxy) StartListening() {
//chServer := make(chan common.ClientMessage) //chServer := make(chan common.ClientMessage)
chClient := make(chan common.ServerMessage) chClient := make(chan common.ServerMessage)
@ -137,10 +148,12 @@ func (vp *VncProxy) StartListening(rfbListeners []common.SegmentConsumer) {
DesktopName: []byte("workDesk"), DesktopName: []byte("workDesk"),
Height: uint16(768), Height: uint16(768),
Width: uint16(1024), Width: uint16(1024),
NewConnHandler: func(cfg *server.ServerConfig, conn *server.ServerConn) error { NewConnHandler: vp.newServerConnHandler,
vp.newServerConnHandler(cfg, conn, rfbListeners) UseDummySession: !vp.UsingSessions,
return nil // func(cfg *server.ServerConfig, conn *server.ServerConn) error {
}, // vp.newServerConnHandler(cfg, conn)
// return nil
// },
} }
if vp.wsListeningUrl != "" { if vp.wsListeningUrl != "" {

24
proxy/proxy_test.go Normal file
View File

@ -0,0 +1,24 @@
package proxy
import "testing"
func TestProxy(t *testing.T) {
//create default session if required
proxy := &VncProxy{
wsListeningUrl: "http://localhost:7777/", // empty = not listening on ws
recordingDir: "c:\\vncRec", // empty = no recording
targetServersPassword: "Ch_#!T@8", //empty = no auth
SingleSession: &VncSession{
TargetHostname: "localhost",
TargetPort: "5903",
TargetPassword: "vncPass",
ID: "dummySession",
Status: SessionStatusActive,
Type: SessionTypeRecordingProxy,
}, // to be used when not using sessions
UsingSessions: false, //false = single session - defined in the var above
}
proxy.StartListening()
}

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"sync" "sync"
"vncproxy/common" "vncproxy/common"
"vncproxy/logger"
) )
type ServerConn struct { type ServerConn struct {
@ -40,7 +41,7 @@ type ServerConn struct {
pixelFormat *common.PixelFormat pixelFormat *common.PixelFormat
// a consumer for the parsed messages, to allow for recording and proxy // a consumer for the parsed messages, to allow for recording and proxy
Listener common.SegmentConsumer Listeners *common.MultiListener
SessionId string SessionId string
@ -70,6 +71,7 @@ func NewServerConn(c io.ReadWriter, cfg *ServerConfig) (*ServerConn, error) {
pixelFormat: cfg.PixelFormat, pixelFormat: cfg.PixelFormat,
fbWidth: cfg.Width, fbWidth: cfg.Width,
fbHeight: cfg.Height, fbHeight: cfg.Height,
Listeners: &common.MultiListener{},
}, nil }, nil
} }
@ -183,7 +185,7 @@ func (c *ServerConn) handle() error {
for { for {
select { select {
case msg := <-c.cfg.ServerMessageCh: case msg := <-c.cfg.ServerMessageCh:
fmt.Printf("%v", msg) logger.Debugf("%v", msg)
// if err = msg.Write(c); err != nil { // if err = msg.Write(c); err != nil {
// return err // return err
// } // }
@ -204,25 +206,32 @@ func (c *ServerConn) handle() error {
default: default:
var messageType common.ClientMessageType var messageType common.ClientMessageType
if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { if err := binary.Read(c, binary.BigEndian, &messageType); err != nil {
fmt.Printf("Error: %v\n", err) logger.Errorf("Error: %v", err)
return err return err
} }
msg, ok := clientMessages[messageType] msg, ok := clientMessages[messageType]
if !ok { if !ok {
return fmt.Errorf("unsupported message-type: %v", messageType) return fmt.Errorf("ServerConn.Handle: unsupported message-type: %v", messageType)
} }
parsedMsg, err := msg.Read(c) parsedMsg, err := msg.Read(c)
if err != nil {
logger.Errorf("srv err %s", err.Error())
return err
}
seg := &common.RfbSegment{ seg := &common.RfbSegment{
SegmentType: common.SegmentFullyParsedClientMessage, SegmentType: common.SegmentFullyParsedClientMessage,
Message: parsedMsg, Message: parsedMsg,
} }
c.Listener.Consume(seg) err = c.Listeners.Consume(seg)
if err != nil { if err != nil {
fmt.Printf("srv err %s\n", err.Error()) logger.Errorf("ServerConn.Handle: listener consume err %s", err.Error())
return err return err
} }
fmt.Printf("message:%s, %v\n", parsedMsg.Type(), parsedMsg)
logger.Debugf("ServerConn.Handle got ClientMessage: %s, %v", parsedMsg.Type(), parsedMsg)
//c.cfg.ClientMessageCh <- parsedMsg //c.cfg.ClientMessageCh <- parsedMsg
} }
} }

View File

@ -60,7 +60,7 @@ type ServerConfig struct {
DesktopName []byte DesktopName []byte
Height uint16 Height uint16
Width uint16 Width uint16
UseDummySession bool
//handler to allow for registering for messages, this can't be a channel //handler to allow for registering for messages, this can't be a channel
//because of the websockets handler function which will kill the connection on exit if conn.handle() is run on another thread //because of the websockets handler function which will kill the connection on exit if conn.handle() is run on another thread
NewConnHandler ServerHandler NewConnHandler ServerHandler
@ -70,7 +70,7 @@ func wsHandlerFunc(ws io.ReadWriter, cfg *ServerConfig, sessionId string) {
// header := ws.Request().Header // header := ws.Request().Header
// url := ws.Request().URL // url := ws.Request().URL
// //stam := header.Get("Origin") // //stam := header.Get("Origin")
// fmt.Printf("header: %v\nurl: %v\n", header, url) // logger.Debugf("header: %v\nurl: %v", header, url)
// io.Copy(ws, ws) // io.Copy(ws, ws)
err := attachNewServerConn(ws, cfg, sessionId) err := attachNewServerConn(ws, cfg, sessionId)
@ -96,7 +96,7 @@ func TcpServe(url string, cfg *ServerConfig) error {
if err != nil { if err != nil {
return err return err
} }
go attachNewServerConn(c, cfg, "tcpDummySession") go attachNewServerConn(c, cfg, "dummySession")
// if err != nil { // if err != nil {
// return err // return err
// } // }
@ -131,7 +131,11 @@ func attachNewServerConn(c io.ReadWriter, cfg *ServerConfig, sessionId string) e
conn.Close() conn.Close()
return err return err
} }
conn.SessionId = sessionId conn.SessionId = sessionId
if cfg.UseDummySession {
conn.SessionId = "dummySession"
}
cfg.NewConnHandler(cfg, conn) cfg.NewConnHandler(cfg, conn)
//go here will kill ws connections //go here will kill ws connections

View File

@ -1,10 +1,10 @@
package server package server
import ( import (
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"vncproxy/logger"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -31,7 +31,7 @@ func (wsServer *WsServer) Listen(urlStr string, handlerFunc WsHandler) {
} }
url, err := url.Parse(urlStr) url, err := url.Parse(urlStr)
if err != nil { if err != nil {
fmt.Println("error while parsing url: ", err) logger.Errorf("error while parsing url: ", err)
} }
// http.HandleFunc(url.Path, // http.HandleFunc(url.Path,

View File

@ -1,11 +1,11 @@
package server package server
import ( import (
"fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"vncproxy/logger"
"bytes" "bytes"
@ -92,7 +92,7 @@ func (wsServer *WsServer1) Listen(urlStr string, handlerFunc WsHandler) {
} }
url, err := url.Parse(urlStr) url, err := url.Parse(urlStr)
if err != nil { if err != nil {
fmt.Println("error while parsing url: ", err) logger.Errorf("error while parsing url: ", err)
} }
http.HandleFunc(url.Path, handleConnection) http.HandleFunc(url.Path, handleConnection)

View File

@ -1,33 +1 @@
package server package server
// import (
// "fmt"
// "io"
// "net/http"
// "testing"
// "golang.org/x/net/websocket"
// )
// func TestWsServer(t *testing.T) {
// server := WsServer{}
// server.Listen(":8090")
// }
// // Echo the data received on the WebSocket.
// func EchoHandler(ws *websocket.Conn) {
// header := ws.Request().Header
// url := ws.Request().URL
// //stam := header.Get("Origin")
// fmt.Printf("header: %v\nurl: %v\n", header, url)
// io.Copy(ws, ws)
// }
// // This example demonstrates a trivial echo server.
// func TestGoWsServer(t *testing.T) {
// http.Handle("/", websocket.Handler(EchoHandler))
// err := http.ListenAndServe(":11111", nil)
// if err != nil {
// panic("ListenAndServe: " + err.Error())
// }
// }

View File

@ -1,22 +0,0 @@
package listeners
import "vncproxy/common"
type MultiListener struct {
listeners []common.SegmentConsumer
}
func (m *MultiListener) AddListener(listener common.SegmentConsumer) {
m.listeners = append(m.listeners, listener)
}
func (m *MultiListener) Consume(seg *common.RfbSegment) error {
for _, li := range m.listeners {
//fmt.Println(li)
err := li.Consume(seg)
if err != nil {
return err
}
}
return nil
}

View File

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"os" "os"
"time" "time"
"vncproxy/common" "vncproxy/common"
"vncproxy/logger"
"vncproxy/server" "vncproxy/server"
) )
@ -15,7 +15,7 @@ type Recorder struct {
//common.BytesListener //common.BytesListener
RBSFileName string RBSFileName string
writer *os.File writer *os.File
logger common.Logger //logger common.Logger
startTime int startTime int
buffer bytes.Buffer buffer bytes.Buffer
serverInitMessage *common.ServerInit serverInitMessage *common.ServerInit
@ -38,7 +38,7 @@ func NewRecorder(saveFilePath string) *Recorder {
rec.writer, err = os.OpenFile(saveFilePath, os.O_RDWR|os.O_CREATE, 0755) rec.writer, err = os.OpenFile(saveFilePath, os.O_RDWR|os.O_CREATE, 0755)
if err != nil { if err != nil {
fmt.Printf("unable to open file: %s, error: %v", saveFilePath, err) logger.Errorf("unable to open file: %s, error: %v", saveFilePath, err)
return nil return nil
} }
@ -54,18 +54,6 @@ func NewRecorder(saveFilePath string) *Recorder {
return &rec return &rec
} }
// func (rec *Recorder) startSession(desktopName string, fbWidth uint16, fbHeight uint16) error {
// err := rec.writeStartSession(desktopName, fbWidth, fbHeight)
// if err != nil {
// fmt.Printf("Recorder was unable to write StartSession to file error: %v", err)
// return nil
// }
// return nil
// }
const versionMsg_3_3 = "RFB 003.003\n" const versionMsg_3_3 = "RFB 003.003\n"
const versionMsg_3_7 = "RFB 003.007\n" const versionMsg_3_7 = "RFB 003.007\n"
const versionMsg_3_8 = "RFB 003.008\n" const versionMsg_3_8 = "RFB 003.008\n"
@ -115,27 +103,33 @@ func (r *Recorder) writeStartSession(initMsg *common.ServerInit) error {
} }
func (r *Recorder) Consume(data *common.RfbSegment) error { func (r *Recorder) Consume(data *common.RfbSegment) error {
//using async writes so if chan buffer overflows, proxy will not be affected //using async writes so if chan buffer overflows, proxy will not be affected
select { select {
case r.segmentChan <- data: case r.segmentChan <- data:
default: default:
fmt.Println("error: recorder queue is full") logger.Error("error: recorder queue is full")
} }
return nil return nil
} }
func (r *Recorder) HandleRfbSegment(data *common.RfbSegment) error { func (r *Recorder) HandleRfbSegment(data *common.RfbSegment) error {
defer func() {
if r := recover(); r != nil {
logger.Error("Recovered in HandleRfbSegment: ", r)
}
}()
switch data.SegmentType { switch data.SegmentType {
case common.SegmentMessageSeparator: case common.SegmentMessageSeparator:
if !r.sessionStartWritten { if !r.sessionStartWritten {
logger.Debugf("Recorder.HandleRfbSegment: writing start session segment: %v",r.serverInitMessage)
r.writeStartSession(r.serverInitMessage) r.writeStartSession(r.serverInitMessage)
} }
switch common.ServerMessageType(data.UpcomingObjectType) { switch common.ServerMessageType(data.UpcomingObjectType) {
case common.FramebufferUpdate: case common.FramebufferUpdate:
logger.Debugf("Recorder.HandleRfbSegment: saving FramebufferUpdate segment")
r.writeToDisk() r.writeToDisk()
case common.SetColourMapEntries: case common.SetColourMapEntries:
case common.Bell: case common.Bell:
@ -145,6 +139,7 @@ func (r *Recorder) HandleRfbSegment(data *common.RfbSegment) error {
} }
case common.SegmentRectSeparator: case common.SegmentRectSeparator:
logger.Debugf("Recorder.HandleRfbSegment: writing start rect start")
r.writeToDisk() r.writeToDisk()
case common.SegmentBytes: case common.SegmentBytes:
_, err := r.buffer.Write(data.Bytes) _, err := r.buffer.Write(data.Bytes)
@ -157,9 +152,10 @@ func (r *Recorder) HandleRfbSegment(data *common.RfbSegment) error {
switch clientMsg.Type() { switch clientMsg.Type() {
case common.SetPixelFormatMsgType: case common.SetPixelFormatMsgType:
clientMsg := data.Message.(*server.SetPixelFormat) clientMsg := data.Message.(*server.SetPixelFormat)
logger.Debugf("Recorder.HandleRfbSegment: client message %v", *clientMsg)
r.serverInitMessage.PixelFormat = clientMsg.PF r.serverInitMessage.PixelFormat = clientMsg.PF
default: default:
return errors.New("unknown client message type:" + string(data.UpcomingObjectType)) //return errors.New("unknown client message type:" + string(data.UpcomingObjectType))
} }
default: default:
@ -180,11 +176,11 @@ func (r *Recorder) writeToDisk() error {
paddedSize := (bytesLen + 3) & 0x7FFFFFFC paddedSize := (bytesLen + 3) & 0x7FFFFFFC
paddingSize := paddedSize - bytesLen paddingSize := paddedSize - bytesLen
fmt.Printf("paddedSize=%d paddingSize=%d bytesLen=%d", paddedSize, paddingSize, bytesLen) //logger.Debugf("paddedSize=%d paddingSize=%d bytesLen=%d", paddedSize, paddingSize, bytesLen)
//write buffer padded to 32bit //write buffer padded to 32bit
_, err := r.buffer.WriteTo(r.writer) _, err := r.buffer.WriteTo(r.writer)
padding := make([]byte, paddingSize) padding := make([]byte, paddingSize)
fmt.Printf("padding=%v ", padding) //logger.Debugf("padding=%v ", padding)
binary.Write(r.writer, binary.BigEndian, padding) binary.Write(r.writer, binary.BigEndian, padding)

View File

@ -1,9 +1,9 @@
package listeners package listeners
import ( import (
"errors"
"io" "io"
"vncproxy/common" "vncproxy/common"
"vncproxy/logger"
) )
type WriteTo struct { type WriteTo struct {
@ -12,17 +12,26 @@ type WriteTo struct {
} }
func (p *WriteTo) Consume(seg *common.RfbSegment) error { func (p *WriteTo) Consume(seg *common.RfbSegment) error {
logger.Debugf("WriteTo.Consume ("+p.Name+"): sending segment type=%s", seg.SegmentType)
switch seg.SegmentType { switch seg.SegmentType {
case common.SegmentMessageSeparator: case common.SegmentMessageSeparator:
case common.SegmentRectSeparator: case common.SegmentRectSeparator:
case common.SegmentBytes: case common.SegmentBytes:
_, err := p.Writer.Write(seg.Bytes) _, err := p.Writer.Write(seg.Bytes)
if (err != nil) {
logger.Errorf("WriteTo.Consume ("+p.Name+" SegmentBytes): problem writing to port: %s", err)
}
return err return err
case common.SegmentFullyParsedClientMessage: case common.SegmentFullyParsedClientMessage:
clientMsg := seg.Message.(common.ClientMessage) clientMsg := seg.Message.(common.ClientMessage)
clientMsg.Write(p.Writer) err := clientMsg.Write(p.Writer)
if (err != nil) {
logger.Errorf("WriteTo.Consume ("+p.Name+" SegmentFullyParsedClientMessage): problem writing to port: %s", err)
}
return err
default: default:
return errors.New("undefined RfbSegment type") //return errors.New("WriteTo.Consume: undefined RfbSegment type")
} }
return nil return nil
} }