From 66c2cfd3b11dabb8204acccb023be02d1de17328 Mon Sep 17 00:00:00 2001 From: Adriano Sela Aviles Date: Fri, 15 Mar 2024 10:29:37 -0700 Subject: [PATCH] Support for Serving on Generic net.Listener --- proxy/proxy.go | 77 ++++++++++++++++++++++++++++++++++++++---------- server/server.go | 8 +++-- 2 files changed, 66 insertions(+), 19 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 7651197..a18482e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -4,6 +4,7 @@ import ( "net" "path" "strconv" + "sync" "time" "github.com/amitbet/vncproxy/client" @@ -16,12 +17,13 @@ import ( ) type VncProxy struct { - TCPListeningURL string // empty = not listening on tcp - WsListeningURL string // empty = not listening on ws - RecordingDir string // empty = no recording - ProxyVncPassword string //empty = no auth - SingleSession *VncSession // to be used when not using sessions - UsingSessions bool //false = single session - defined in the var above + NetListener net.Listener // nil = not listening on a generic net.Listener + TCPListeningURL string // empty = not listening on tcp + WsListeningURL string // empty = not listening on ws + RecordingDir string // empty = no recording + ProxyVncPassword string //empty = no auth + SingleSession *VncSession // to be used when not using sessions + UsingSessions bool //false = single session - defined in the var above sessionManager *SessionManager } @@ -172,7 +174,7 @@ func (vp *VncProxy) StartListening() { secHandlers := []server.SecurityHandler{&server.ServerAuthNone{}} if vp.ProxyVncPassword != "" { - secHandlers = []server.SecurityHandler{&server.ServerAuthVNC{vp.ProxyVncPassword}} + secHandlers = []server.SecurityHandler{&server.ServerAuthVNC{Pass: vp.ProxyVncPassword}} } cfg := &server.ServerConfig{ SecurityHandlers: secHandlers, @@ -186,19 +188,62 @@ func (vp *VncProxy) StartListening() { UseDummySession: !vp.UsingSessions, } - if vp.TCPListeningURL != "" && vp.WsListeningURL != "" { - logger.Infof("running two listeners: tcp port: %s, ws url: %s", vp.TCPListeningURL, vp.WsListeningURL) - - go server.WsServe(vp.WsListeningURL, cfg) - server.TcpServe(vp.TCPListeningURL, cfg) + if vp.countListeners() == 0 { + logger.Error("no listeners configured on VncProxy") + return } + var wg sync.WaitGroup + if vp.WsListeningURL != "" { - logger.Infof("running ws listener url: %s", vp.WsListeningURL) - server.WsServe(vp.WsListeningURL, cfg) + wg.Add(1) + + go func() { + defer wg.Done() + defer logger.Info("ws listener stopped") + + logger.Infof("running ws listener url: %s", vp.WsListeningURL) + server.WsServe(vp.WsListeningURL, cfg) + }() } + if vp.TCPListeningURL != "" { - logger.Infof("running tcp listener on port: %s", vp.TCPListeningURL) - server.TcpServe(vp.TCPListeningURL, cfg) + wg.Add(1) + + go func() { + defer wg.Done() + defer logger.Info("tcp listener stopped") + + logger.Infof("running tcp listener on port: %s", vp.TCPListeningURL) + server.TcpServe(vp.TCPListeningURL, cfg) + }() } + + if vp.NetListener != nil { + wg.Add(1) + + go func() { + defer wg.Done() + defer logger.Info("generic net.Listener stopped") + + logger.Info("running generic net.Listener") + server.NetListenerServe(vp.NetListener, cfg) + }() + } + + wg.Wait() +} + +func (vp *VncProxy) countListeners() int { + count := 0 + if vp.TCPListeningURL != "" { + count++ + } + if vp.WsListeningURL != "" { + count++ + } + if vp.NetListener != nil { + count++ + } + return count } diff --git a/server/server.go b/server/server.go index f90b7d9..7743e06 100644 --- a/server/server.go +++ b/server/server.go @@ -1,10 +1,10 @@ package server import ( - "fmt" "io" "log" "net" + "github.com/amitbet/vncproxy/common" ) @@ -61,6 +61,10 @@ func TcpServe(url string, cfg *ServerConfig) error { if err != nil { log.Fatalf("Error listen. %v", err) } + return NetListenerServe(ln, cfg) +} + +func NetListenerServe(ln net.Listener, cfg *ServerConfig) error { for { c, err := ln.Accept() if err != nil { @@ -68,7 +72,6 @@ func TcpServe(url string, cfg *ServerConfig) error { } go attachNewServerConn(c, cfg, "dummySession") } - return nil } func attachNewServerConn(c io.ReadWriter, cfg *ServerConfig, sessionId string) error { @@ -79,7 +82,6 @@ func attachNewServerConn(c io.ReadWriter, cfg *ServerConfig, sessionId string) e } if err := ServerVersionHandler(cfg, conn); err != nil { - fmt.Errorf("err: %v\n", err) conn.Close() return err }