diff --git a/pkg/remotedialer/client.go b/pkg/remotedialer/client.go deleted file mode 100644 index 57dcdefb..00000000 --- a/pkg/remotedialer/client.go +++ /dev/null @@ -1,57 +0,0 @@ -package remotedialer - -import ( - "context" - "io/ioutil" - "net/http" - "time" - - "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" -) - -type ConnectAuthorizer func(proto, address string) bool - -func ClientConnect(wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, onConnect func(context.Context) error) { - if err := connectToProxy(wsURL, headers, auth, dialer, onConnect); err != nil { - logrus.WithError(err).Error("Failed to connect to proxy") - time.Sleep(time.Duration(5) * time.Second) - } -} - -func connectToProxy(proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context) error) error { - logrus.WithField("url", proxyURL).Info("Connecting to proxy") - - if dialer == nil { - dialer = &websocket.Dialer{HandshakeTimeout: 10 * time.Second} - } - ws, resp, err := dialer.Dial(proxyURL, headers) - if err != nil { - if resp == nil { - logrus.WithError(err).Errorf("Failed to connect to proxy. Empty dialer response") - } else { - rb, err2 := ioutil.ReadAll(resp.Body) - if err2 != nil { - logrus.WithError(err).Errorf("Failed to connect to proxy. Response status: %v - %v. Couldn't read response body (err: %v)", resp.StatusCode, resp.Status, err2) - } else { - logrus.WithError(err).Errorf("Failed to connect to proxy. Response status: %v - %v. Response body: %s", resp.StatusCode, resp.Status, rb) - } - } - return err - } - defer ws.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if onConnect != nil { - if err := onConnect(ctx); err != nil { - return err - } - } - - session := NewClientSession(auth, ws) - _, err = session.Serve() - session.Close() - return err -} diff --git a/pkg/remotedialer/client/main.go b/pkg/remotedialer/client/main.go deleted file mode 100644 index 9685e475..00000000 --- a/pkg/remotedialer/client/main.go +++ /dev/null @@ -1,34 +0,0 @@ -// +build !windows - -package main - -import ( - "flag" - "net/http" - - "github.com/rancher/norman/pkg/remotedialer" - "github.com/sirupsen/logrus" -) - -var ( - addr string - id string - debug bool -) - -func main() { - flag.StringVar(&addr, "connect", "ws://localhost:8123/connect", "Address to connect to") - flag.StringVar(&id, "id", "foo", "Client ID") - flag.BoolVar(&debug, "debug", true, "Debug logging") - flag.Parse() - - if debug { - logrus.SetLevel(logrus.DebugLevel) - } - - headers := http.Header{ - "X-Tunnel-ID": []string{id}, - } - - remotedialer.ClientConnect(addr, headers, nil, func(string, string) bool { return true }, nil) -} diff --git a/pkg/remotedialer/client_dialer.go b/pkg/remotedialer/client_dialer.go deleted file mode 100644 index 319f0031..00000000 --- a/pkg/remotedialer/client_dialer.go +++ /dev/null @@ -1,62 +0,0 @@ -package remotedialer - -import ( - "io" - "net" - "sync" - "time" -) - -func clientDial(dialer Dialer, conn *connection, message *message) { - defer conn.Close() - - var ( - netConn net.Conn - err error - ) - - if dialer == nil { - netDialer := &net.Dialer{ - Timeout: time.Duration(message.deadline) * time.Millisecond, - KeepAlive: 30 * time.Second, - } - netConn, err = netDialer.Dial(message.proto, message.address) - } else { - netConn, err = dialer(message.proto, message.address) - } - - if err != nil { - conn.tunnelClose(err) - return - } - defer netConn.Close() - - pipe(conn, netConn) -} - -func pipe(client *connection, server net.Conn) { - wg := sync.WaitGroup{} - wg.Add(1) - - close := func(err error) error { - if err == nil { - err = io.EOF - } - client.doTunnelClose(err) - server.Close() - return err - } - - go func() { - defer wg.Done() - _, err := io.Copy(server, client) - close(err) - }() - - _, err := io.Copy(client, server) - err = close(err) - wg.Wait() - - // Write tunnel error after no more I/O is happening, just incase messages get out of order - client.writeErr(err) -} diff --git a/pkg/remotedialer/connection.go b/pkg/remotedialer/connection.go deleted file mode 100644 index f9e6b29a..00000000 --- a/pkg/remotedialer/connection.go +++ /dev/null @@ -1,188 +0,0 @@ -package remotedialer - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" -) - -type connection struct { - sync.Mutex - - ctx context.Context - cancel func() - err error - writeDeadline time.Time - buf chan []byte - readBuf []byte - addr addr - session *Session - connID int64 -} - -func newConnection(connID int64, session *Session, proto, address string) *connection { - c := &connection{ - addr: addr{ - proto: proto, - address: address, - }, - connID: connID, - session: session, - buf: make(chan []byte, 1024), - } - return c -} - -func (c *connection) tunnelClose(err error) { - c.writeErr(err) - c.doTunnelClose(err) -} - -func (c *connection) doTunnelClose(err error) { - c.Lock() - defer c.Unlock() - - if c.err != nil { - return - } - - c.err = err - if c.err == nil { - c.err = io.ErrClosedPipe - } - - close(c.buf) -} - -func (c *connection) tunnelWriter() io.Writer { - return chanWriter{conn: c, C: c.buf} -} - -func (c *connection) Close() error { - c.session.closeConnection(c.connID, io.EOF) - return nil -} - -func (c *connection) copyData(b []byte) int { - n := copy(b, c.readBuf) - c.readBuf = c.readBuf[n:] - return n -} - -func (c *connection) Read(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - - n := c.copyData(b) - if n > 0 { - return n, nil - } - - next, ok := <-c.buf - if !ok { - err := io.EOF - c.Lock() - if c.err != nil { - err = c.err - } - c.Unlock() - return 0, err - } - - c.readBuf = next - n = c.copyData(b) - return n, nil -} - -func (c *connection) Write(b []byte) (int, error) { - c.Lock() - if c.err != nil { - defer c.Unlock() - return 0, c.err - } - c.Unlock() - - deadline := int64(0) - if !c.writeDeadline.IsZero() { - deadline = c.writeDeadline.Sub(time.Now()).Nanoseconds() / 1000000 - } - return c.session.writeMessage(newMessage(c.connID, deadline, b)) -} - -func (c *connection) writeErr(err error) { - if err != nil { - c.session.writeMessage(newErrorMessage(c.connID, err)) - } -} - -func (c *connection) LocalAddr() net.Addr { - return c.addr -} - -func (c *connection) RemoteAddr() net.Addr { - return c.addr -} - -func (c *connection) SetDeadline(t time.Time) error { - if err := c.SetReadDeadline(t); err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -func (c *connection) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *connection) SetWriteDeadline(t time.Time) error { - c.writeDeadline = t - return nil -} - -type addr struct { - proto string - address string -} - -func (a addr) Network() string { - return a.proto -} - -func (a addr) String() string { - return a.address -} - -type chanWriter struct { - conn *connection - C chan []byte -} - -func (c chanWriter) Write(buf []byte) (int, error) { - c.conn.Lock() - defer c.conn.Unlock() - - if c.conn.err != nil { - return 0, c.conn.err - } - - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - buf = newBuf - - select { - // must copy the buffer - case c.C <- buf: - return len(buf), nil - default: - select { - case c.C <- buf: - return len(buf), nil - case <-time.After(15 * time.Second): - return 0, errors.New("backed up reader") - } - } -} diff --git a/pkg/remotedialer/dialer.go b/pkg/remotedialer/dialer.go deleted file mode 100644 index b2a8350e..00000000 --- a/pkg/remotedialer/dialer.go +++ /dev/null @@ -1,28 +0,0 @@ -package remotedialer - -import ( - "net" - "time" -) - -type Dialer func(network, address string) (net.Conn, error) - -func (s *Server) HasSession(clientKey string) bool { - _, err := s.sessions.getDialer(clientKey, 0) - return err == nil -} - -func (s *Server) Dial(clientKey string, deadline time.Duration, proto, address string) (net.Conn, error) { - d, err := s.sessions.getDialer(clientKey, deadline) - if err != nil { - return nil, err - } - - return d(proto, address) -} - -func (s *Server) Dialer(clientKey string, deadline time.Duration) Dialer { - return func(proto, address string) (net.Conn, error) { - return s.Dial(clientKey, deadline, proto, address) - } -} diff --git a/pkg/remotedialer/dummy/main.go b/pkg/remotedialer/dummy/main.go deleted file mode 100644 index 71313949..00000000 --- a/pkg/remotedialer/dummy/main.go +++ /dev/null @@ -1,29 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "math/rand" - "net/http" - "sync/atomic" - "time" -) - -var ( - counter int64 - listen string -) - -func main() { - flag.StringVar(&listen, "listen", ":8125", "Listen address") - flag.Parse() - - fmt.Println("listening ", listen) - http.ListenAndServe(listen, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - next := atomic.AddInt64(&counter, 1) - fmt.Println("request", next) - - time.Sleep(time.Duration(rand.Intn(300)) * time.Millisecond) - rw.Write([]byte("HI")) - })) -} diff --git a/pkg/remotedialer/message.go b/pkg/remotedialer/message.go deleted file mode 100644 index 0572ee5b..00000000 --- a/pkg/remotedialer/message.go +++ /dev/null @@ -1,220 +0,0 @@ -package remotedialer - -import ( - "bufio" - "encoding/binary" - "errors" - "fmt" - "io" - "io/ioutil" - "math/rand" - "strings" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" -) - -const ( - Data messageType = iota + 1 - Connect - Error - AddClient - RemoveClient -) - -var ( - idCounter int64 -) - -func init() { - r := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) - idCounter = r.Int63() -} - -type messageType int64 - -type message struct { - id int64 - err error - connID int64 - deadline int64 - messageType messageType - bytes []byte - body io.Reader - proto string - address string -} - -func nextid() int64 { - return atomic.AddInt64(&idCounter, 1) -} - -func newMessage(connID int64, deadline int64, bytes []byte) *message { - return &message{ - id: nextid(), - connID: connID, - deadline: deadline, - messageType: Data, - bytes: bytes, - } -} - -func newConnect(connID int64, deadline time.Duration, proto, address string) *message { - return &message{ - id: nextid(), - connID: connID, - deadline: deadline.Nanoseconds() / 1000000, - messageType: Connect, - bytes: []byte(fmt.Sprintf("%s/%s", proto, address)), - proto: proto, - address: address, - } -} - -func newErrorMessage(connID int64, err error) *message { - return &message{ - id: nextid(), - err: err, - connID: connID, - messageType: Error, - bytes: []byte(err.Error()), - } -} - -func newAddClient(client string) *message { - return &message{ - id: nextid(), - messageType: AddClient, - address: client, - bytes: []byte(client), - } -} - -func newRemoveClient(client string) *message { - return &message{ - id: nextid(), - messageType: RemoveClient, - address: client, - bytes: []byte(client), - } -} - -func newServerMessage(reader io.Reader) (*message, error) { - buf := bufio.NewReader(reader) - - id, err := binary.ReadVarint(buf) - if err != nil { - return nil, err - } - - connID, err := binary.ReadVarint(buf) - if err != nil { - return nil, err - } - - mType, err := binary.ReadVarint(buf) - if err != nil { - return nil, err - } - - m := &message{ - id: id, - messageType: messageType(mType), - connID: connID, - body: buf, - } - - if m.messageType == Data || m.messageType == Connect { - deadline, err := binary.ReadVarint(buf) - if err != nil { - return nil, err - } - m.deadline = deadline - } - - if m.messageType == Connect { - bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100)) - if err != nil { - return nil, err - } - parts := strings.SplitN(string(bytes), "/", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("failed to parse connect address") - } - m.proto = parts[0] - m.address = parts[1] - m.bytes = bytes - } else if m.messageType == AddClient || m.messageType == RemoveClient { - bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100)) - if err != nil { - return nil, err - } - m.address = string(bytes) - m.bytes = bytes - } - - return m, nil -} - -func (m *message) Err() error { - if m.err != nil { - return m.err - } - bytes, err := ioutil.ReadAll(io.LimitReader(m.body, 100)) - if err != nil { - return err - } - - str := string(bytes) - if str == "EOF" { - m.err = io.EOF - } else { - m.err = errors.New(str) - } - return m.err -} - -func (m *message) Bytes() []byte { - return append(m.header(), m.bytes...) -} - -func (m *message) header() []byte { - buf := make([]byte, 24) - offset := 0 - offset += binary.PutVarint(buf[offset:], m.id) - offset += binary.PutVarint(buf[offset:], m.connID) - offset += binary.PutVarint(buf[offset:], int64(m.messageType)) - if m.messageType == Data || m.messageType == Connect { - offset += binary.PutVarint(buf[offset:], m.deadline) - } - return buf[:offset] -} - -func (m *message) Read(p []byte) (int, error) { - return m.body.Read(p) -} - -func (m *message) WriteTo(wsConn *wsConn) (int, error) { - err := wsConn.WriteMessage(websocket.BinaryMessage, m.Bytes()) - return len(m.bytes), err -} - -func (m *message) String() string { - switch m.messageType { - case Data: - if m.body == nil { - return fmt.Sprintf("%d DATA [%d]: %d bytes: %s", m.id, m.connID, len(m.bytes), string(m.bytes)) - } - return fmt.Sprintf("%d DATA [%d]: buffered", m.id, m.connID) - case Error: - return fmt.Sprintf("%d ERROR [%d]: %s", m.id, m.connID, m.Err()) - case Connect: - return fmt.Sprintf("%d CONNECT [%d]: %s/%s deadline %d", m.id, m.connID, m.proto, m.address, m.deadline) - case AddClient: - return fmt.Sprintf("%d ADDCLIENT [%s]", m.id, m.address) - case RemoveClient: - return fmt.Sprintf("%d REMOVECLIENT [%s]", m.id, m.address) - } - return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType) -} diff --git a/pkg/remotedialer/peer.go b/pkg/remotedialer/peer.go deleted file mode 100644 index 452ea57d..00000000 --- a/pkg/remotedialer/peer.go +++ /dev/null @@ -1,121 +0,0 @@ -package remotedialer - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" -) - -var ( - Token = "X-API-Tunnel-Token" - ID = "X-API-Tunnel-ID" -) - -func (s *Server) AddPeer(url, id, token string) { - if s.PeerID == "" || s.PeerToken == "" { - return - } - - ctx, cancel := context.WithCancel(context.Background()) - peer := peer{ - url: url, - id: id, - token: token, - cancel: cancel, - } - - logrus.Infof("Adding peer %s, %s", url, id) - - s.peerLock.Lock() - defer s.peerLock.Unlock() - - if p, ok := s.peers[id]; ok { - if p.equals(peer) { - return - } - p.cancel() - } - - s.peers[id] = peer - go peer.start(ctx, s) -} - -func (s *Server) RemovePeer(id string) { - s.peerLock.Lock() - defer s.peerLock.Unlock() - - if p, ok := s.peers[id]; ok { - logrus.Infof("Removing peer %s", id) - p.cancel() - } - delete(s.peers, id) -} - -type peer struct { - url, id, token string - cancel func() -} - -func (p peer) equals(other peer) bool { - return p.url == other.url && - p.id == other.id && - p.token == other.token -} - -func (p *peer) start(ctx context.Context, s *Server) { - headers := http.Header{ - ID: {s.PeerID}, - Token: {s.PeerToken}, - } - - dialer := &websocket.Dialer{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - HandshakeTimeout: 10 * time.Second, - } - -outer: - for { - select { - case <-ctx.Done(): - break outer - default: - } - - ws, _, err := dialer.Dial(p.url, headers) - if err != nil { - logrus.Errorf("Failed to connect to peer %s [local ID=%s]: %v", p.url, s.PeerID, err) - time.Sleep(5 * time.Second) - continue - } - - session := NewClientSession(func(string, string) bool { return true }, ws) - session.dialer = func(network, address string) (net.Conn, error) { - parts := strings.SplitN(network, "::", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid clientKey/proto: %s", network) - } - return s.Dial(parts[0], 15*time.Second, parts[1], address) - } - - s.sessions.addListener(session) - _, err = session.Serve() - s.sessions.removeListener(session) - session.Close() - - if err != nil { - logrus.Errorf("Failed to serve peer connection %s: %v", p.id, err) - } - - ws.Close() - time.Sleep(5 * time.Second) - } -} diff --git a/pkg/remotedialer/server.go b/pkg/remotedialer/server.go deleted file mode 100644 index dc87202d..00000000 --- a/pkg/remotedialer/server.go +++ /dev/null @@ -1,97 +0,0 @@ -package remotedialer - -import ( - "net/http" - "sync" - "time" - - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -var ( - errFailedAuth = errors.New("failed authentication") - errWrongMessageType = errors.New("wrong websocket message type") -) - -type Authorizer func(req *http.Request) (clientKey string, authed bool, err error) -type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error) - -func DefaultErrorWriter(rw http.ResponseWriter, req *http.Request, code int, err error) { - rw.Write([]byte(err.Error())) - rw.WriteHeader(code) -} - -type Server struct { - PeerID string - PeerToken string - authorizer Authorizer - errorWriter ErrorWriter - sessions *sessionManager - peers map[string]peer - peerLock sync.Mutex -} - -func New(auth Authorizer, errorWriter ErrorWriter) *Server { - return &Server{ - peers: map[string]peer{}, - authorizer: auth, - errorWriter: errorWriter, - sessions: newSessionManager(), - } -} - -func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - clientKey, authed, peer, err := s.auth(req) - if err != nil { - s.errorWriter(rw, req, 400, err) - return - } - if !authed { - s.errorWriter(rw, req, 401, errFailedAuth) - return - } - - logrus.Infof("Handling backend connection request [%s]", clientKey) - - upgrader := websocket.Upgrader{ - HandshakeTimeout: 5 * time.Second, - CheckOrigin: func(r *http.Request) bool { return true }, - Error: s.errorWriter, - } - - wsConn, err := upgrader.Upgrade(rw, req, nil) - if err != nil { - s.errorWriter(rw, req, 400, errors.Wrapf(err, "Error during upgrade for host [%v]", clientKey)) - return - } - - session := s.sessions.add(clientKey, wsConn, peer) - defer s.sessions.remove(session) - - // Don't need to associate req.Context() to the Session, it will cancel otherwise - code, err := session.Serve() - if err != nil { - // Hijacked so we can't write to the client - logrus.Infof("error in remotedialer server [%d]: %v", code, err) - } -} - -func (s *Server) auth(req *http.Request) (clientKey string, authed, peer bool, err error) { - id := req.Header.Get(ID) - token := req.Header.Get(Token) - if id != "" && token != "" { - // peer authentication - s.peerLock.Lock() - p, ok := s.peers[id] - s.peerLock.Unlock() - - if ok && p.token == token { - return id, true, true, nil - } - } - - id, authed, err = s.authorizer(req) - return id, authed, false, err -} diff --git a/pkg/remotedialer/server/main.go b/pkg/remotedialer/server/main.go deleted file mode 100644 index 8db53272..00000000 --- a/pkg/remotedialer/server/main.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/gorilla/mux" - "github.com/rancher/norman/pkg/remotedialer" - "github.com/sirupsen/logrus" -) - -var ( - clients = map[string]*http.Client{} - l sync.Mutex - counter int64 -) - -func authorizer(req *http.Request) (string, bool, error) { - id := req.Header.Get("x-tunnel-id") - return id, id != "", nil -} - -func Client(server *remotedialer.Server, rw http.ResponseWriter, req *http.Request) { - timeout := req.URL.Query().Get("timeout") - if timeout == "" { - timeout = "15" - } - - vars := mux.Vars(req) - clientKey := vars["id"] - url := fmt.Sprintf("%s://%s%s", vars["scheme"], vars["host"], vars["path"]) - client := getClient(server, clientKey, timeout) - - id := atomic.AddInt64(&counter, 1) - logrus.Infof("[%03d] REQ t=%s %s", id, timeout, url) - - resp, err := client.Get(url) - if err != nil { - logrus.Errorf("[%03d] REQ ERR t=%s %s: %v", id, timeout, url, err) - remotedialer.DefaultErrorWriter(rw, req, 500, err) - return - } - defer resp.Body.Close() - - logrus.Infof("[%03d] REQ OK t=%s %s", id, timeout, url) - rw.WriteHeader(resp.StatusCode) - io.Copy(rw, resp.Body) - logrus.Infof("[%03d] REQ DONE t=%s %s", id, timeout, url) -} - -func getClient(server *remotedialer.Server, clientKey, timeout string) *http.Client { - l.Lock() - defer l.Unlock() - - key := fmt.Sprintf("%s/%s", clientKey, timeout) - client := clients[key] - if client != nil { - return client - } - - dialer := server.Dialer(clientKey, 15*time.Second) - client = &http.Client{ - Transport: &http.Transport{ - Dial: dialer, - }, - } - if timeout != "" { - t, err := strconv.Atoi(timeout) - if err == nil { - client.Timeout = time.Duration(t) * time.Second - } - } - - clients[key] = client - return client -} - -func main() { - var ( - addr string - peerID string - peerToken string - peers string - debug bool - ) - flag.StringVar(&addr, "listen", ":8123", "Listen address") - flag.StringVar(&peerID, "id", "", "Peer ID") - flag.StringVar(&peerToken, "token", "", "Peer Token") - flag.StringVar(&peers, "peers", "", "Peers format id:token:url,id:token:url") - flag.BoolVar(&debug, "debug", false, "Enable debug logging") - flag.Parse() - - if debug { - logrus.SetLevel(logrus.DebugLevel) - remotedialer.PrintTunnelData = true - } - - handler := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter) - handler.PeerToken = peerToken - handler.PeerID = peerID - - for _, peer := range strings.Split(peers, ",") { - parts := strings.SplitN(strings.TrimSpace(peer), ":", 3) - if len(parts) != 3 { - continue - } - handler.AddPeer(parts[2], parts[0], parts[1]) - } - - router := mux.NewRouter() - router.Handle("/connect", handler) - router.HandleFunc("/client/{id}/{scheme}/{host}{path:.*}", func(rw http.ResponseWriter, req *http.Request) { - Client(handler, rw, req) - }) - - fmt.Println("Listening on ", addr) - http.ListenAndServe(addr, router) -} diff --git a/pkg/remotedialer/session.go b/pkg/remotedialer/session.go deleted file mode 100644 index 5f440d4d..00000000 --- a/pkg/remotedialer/session.go +++ /dev/null @@ -1,303 +0,0 @@ -package remotedialer - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" -) - -type Session struct { - sync.Mutex - - nextConnID int64 - clientKey string - sessionKey int64 - conn *wsConn - conns map[int64]*connection - remoteClientKeys map[string]map[int]bool - auth ConnectAuthorizer - pingCancel context.CancelFunc - pingWait sync.WaitGroup - dialer Dialer - client bool -} - -// PrintTunnelData No tunnel logging by default -var PrintTunnelData bool - -func init() { - if os.Getenv("CATTLE_TUNNEL_DATA_DEBUG") == "true" { - PrintTunnelData = true - } -} - -func NewClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *Session { - return &Session{ - clientKey: "client", - conn: newWSConn(conn), - conns: map[int64]*connection{}, - auth: auth, - client: true, - } -} - -func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session { - return &Session{ - nextConnID: 1, - clientKey: clientKey, - sessionKey: sessionKey, - conn: newWSConn(conn), - conns: map[int64]*connection{}, - remoteClientKeys: map[string]map[int]bool{}, - } -} - -func (s *Session) startPings() { - ctx, cancel := context.WithCancel(context.Background()) - s.pingCancel = cancel - s.pingWait.Add(1) - - go func() { - defer s.pingWait.Done() - - t := time.NewTicker(PingWriteInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-t.C: - s.conn.Lock() - if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil { - logrus.WithError(err).Error("Error writing ping") - } - logrus.Debug("Wrote ping") - s.conn.Unlock() - } - } - }() -} - -func (s *Session) stopPings() { - if s.pingCancel == nil { - return - } - - s.pingCancel() - s.pingWait.Wait() -} - -func (s *Session) Serve() (int, error) { - if s.client { - s.startPings() - } - - for { - msType, reader, err := s.conn.NextReader() - if err != nil { - return 400, err - } - - if msType != websocket.BinaryMessage { - return 400, errWrongMessageType - } - - if err := s.serveMessage(reader); err != nil { - return 500, err - } - } -} - -func (s *Session) serveMessage(reader io.Reader) error { - message, err := newServerMessage(reader) - if err != nil { - return err - } - - if PrintTunnelData { - logrus.Debug("REQUEST ", message) - } - - if message.messageType == Connect { - if s.auth == nil || !s.auth(message.proto, message.address) { - return errors.New("connect not allowed") - } - s.clientConnect(message) - return nil - } - - s.Lock() - if message.messageType == AddClient && s.remoteClientKeys != nil { - err := s.addRemoteClient(message.address) - s.Unlock() - return err - } else if message.messageType == RemoveClient { - err := s.removeRemoteClient(message.address) - s.Unlock() - return err - } - conn := s.conns[message.connID] - s.Unlock() - - if conn == nil { - if message.messageType == Data { - err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID) - newErrorMessage(message.connID, err).WriteTo(s.conn) - } - return nil - } - - switch message.messageType { - case Data: - if _, err := io.Copy(conn.tunnelWriter(), message); err != nil { - s.closeConnection(message.connID, err) - } - case Error: - s.closeConnection(message.connID, message.Err()) - } - - return nil -} - -func parseAddress(address string) (string, int, error) { - parts := strings.SplitN(address, "/", 2) - if len(parts) != 2 { - return "", 0, errors.New("not / separated") - } - v, err := strconv.Atoi(parts[1]) - return parts[0], v, err -} - -func (s *Session) addRemoteClient(address string) error { - clientKey, sessionKey, err := parseAddress(address) - if err != nil { - return fmt.Errorf("invalid remote Session %s: %v", address, err) - } - - keys := s.remoteClientKeys[clientKey] - if keys == nil { - keys = map[int]bool{} - s.remoteClientKeys[clientKey] = keys - } - keys[int(sessionKey)] = true - - if PrintTunnelData { - logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) - } - - return nil -} - -func (s *Session) removeRemoteClient(address string) error { - clientKey, sessionKey, err := parseAddress(address) - if err != nil { - return fmt.Errorf("invalid remote Session %s: %v", address, err) - } - - keys := s.remoteClientKeys[clientKey] - delete(keys, int(sessionKey)) - if len(keys) == 0 { - delete(s.remoteClientKeys, clientKey) - } - - if PrintTunnelData { - logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) - } - - return nil -} - -func (s *Session) closeConnection(connID int64, err error) { - s.Lock() - conn := s.conns[connID] - delete(s.conns, connID) - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() - - if conn != nil { - conn.tunnelClose(err) - } -} - -func (s *Session) clientConnect(message *message) { - conn := newConnection(message.connID, s, message.proto, message.address) - - s.Lock() - s.conns[message.connID] = conn - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() - - go clientDial(s.dialer, conn, message) -} - -func (s *Session) serverConnect(deadline time.Duration, proto, address string) (net.Conn, error) { - connID := atomic.AddInt64(&s.nextConnID, 1) - conn := newConnection(connID, s, proto, address) - - s.Lock() - s.conns[connID] = conn - if PrintTunnelData { - logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) - } - s.Unlock() - - _, err := s.writeMessage(newConnect(connID, deadline, proto, address)) - if err != nil { - s.closeConnection(connID, err) - return nil, err - } - - return conn, err -} - -func (s *Session) writeMessage(message *message) (int, error) { - if PrintTunnelData { - logrus.Debug("WRITE ", message) - } - return message.WriteTo(s.conn) -} - -func (s *Session) Close() { - s.Lock() - defer s.Unlock() - - s.stopPings() - - for _, connection := range s.conns { - connection.tunnelClose(errors.New("tunnel disconnect")) - } - - s.conns = map[int64]*connection{} -} - -func (s *Session) sessionAdded(clientKey string, sessionKey int64) { - client := fmt.Sprintf("%s/%d", clientKey, sessionKey) - _, err := s.writeMessage(newAddClient(client)) - if err != nil { - s.conn.conn.Close() - } -} - -func (s *Session) sessionRemoved(clientKey string, sessionKey int64) { - client := fmt.Sprintf("%s/%d", clientKey, sessionKey) - _, err := s.writeMessage(newRemoveClient(client)) - if err != nil { - s.conn.conn.Close() - } -} diff --git a/pkg/remotedialer/session_manager.go b/pkg/remotedialer/session_manager.go deleted file mode 100644 index ae26b412..00000000 --- a/pkg/remotedialer/session_manager.go +++ /dev/null @@ -1,137 +0,0 @@ -package remotedialer - -import ( - "fmt" - "math/rand" - "net" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type sessionListener interface { - sessionAdded(clientKey string, sessionKey int64) - sessionRemoved(clientKey string, sessionKey int64) -} - -type sessionManager struct { - sync.Mutex - clients map[string][]*Session - peers map[string][]*Session - listeners map[sessionListener]bool -} - -func newSessionManager() *sessionManager { - return &sessionManager{ - clients: map[string][]*Session{}, - peers: map[string][]*Session{}, - listeners: map[sessionListener]bool{}, - } -} - -func toDialer(s *Session, prefix string, deadline time.Duration) Dialer { - return func(proto, address string) (net.Conn, error) { - if prefix == "" { - return s.serverConnect(deadline, proto, address) - } - return s.serverConnect(deadline, prefix+"::"+proto, address) - } -} - -func (sm *sessionManager) removeListener(listener sessionListener) { - sm.Lock() - defer sm.Unlock() - - delete(sm.listeners, listener) -} - -func (sm *sessionManager) addListener(listener sessionListener) { - sm.Lock() - defer sm.Unlock() - - sm.listeners[listener] = true - - for k, sessions := range sm.clients { - for _, session := range sessions { - listener.sessionAdded(k, session.sessionKey) - } - } - - for k, sessions := range sm.peers { - for _, session := range sessions { - listener.sessionAdded(k, session.sessionKey) - } - } -} - -func (sm *sessionManager) getDialer(clientKey string, deadline time.Duration) (Dialer, error) { - sm.Lock() - defer sm.Unlock() - - sessions := sm.clients[clientKey] - if len(sessions) > 0 { - return toDialer(sessions[0], "", deadline), nil - } - - for _, sessions := range sm.peers { - for _, session := range sessions { - session.Lock() - keys := session.remoteClientKeys[clientKey] - session.Unlock() - if len(keys) > 0 { - return toDialer(session, clientKey, deadline), nil - } - } - } - - return nil, fmt.Errorf("failed to find Session for client %s", clientKey) -} - -func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *Session { - sessionKey := rand.Int63() - session := newSession(sessionKey, clientKey, conn) - - sm.Lock() - defer sm.Unlock() - - if peer { - sm.peers[clientKey] = append(sm.peers[clientKey], session) - } else { - sm.clients[clientKey] = append(sm.clients[clientKey], session) - } - - for l := range sm.listeners { - l.sessionAdded(clientKey, session.sessionKey) - } - - return session -} - -func (sm *sessionManager) remove(s *Session) { - sm.Lock() - defer sm.Unlock() - - for _, store := range []map[string][]*Session{sm.clients, sm.peers} { - var newSessions []*Session - - for _, v := range store[s.clientKey] { - if v.sessionKey == s.sessionKey { - continue - } - newSessions = append(newSessions, v) - } - - if len(newSessions) == 0 { - delete(store, s.clientKey) - } else { - store[s.clientKey] = newSessions - } - } - - for l := range sm.listeners { - l.sessionRemoved(s.clientKey, s.sessionKey) - } - - s.Close() -} diff --git a/pkg/remotedialer/session_windows.go b/pkg/remotedialer/session_windows.go deleted file mode 100644 index e80462ac..00000000 --- a/pkg/remotedialer/session_windows.go +++ /dev/null @@ -1,57 +0,0 @@ -package remotedialer - -import ( - "context" - "time" - - "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" -) - -func (s *Session) startPingsWhileWindows(rootCtx context.Context) { - ctx, cancel := context.WithCancel(rootCtx) - s.pingCancel = cancel - s.pingWait.Add(1) - - go func() { - defer s.pingWait.Done() - - t := time.NewTicker(PingWriteInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-t.C: - s.conn.Lock() - if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil { - logrus.WithError(err).Error("Error writing ping") - } - logrus.Debug("Wrote ping") - s.conn.Unlock() - } - } - }() -} - -func (s *Session) ServeWhileWindows(ctx context.Context) (int, error) { - if s.client { - s.startPingsWhileWindows(ctx) - } - - for { - msType, reader, err := s.conn.NextReader() - if err != nil { - return 400, err - } - - if msType != websocket.BinaryMessage { - return 400, errWrongMessageType - } - - if err := s.serveMessage(reader); err != nil { - return 500, err - } - } -} diff --git a/pkg/remotedialer/types.go b/pkg/remotedialer/types.go deleted file mode 100644 index f1056cd9..00000000 --- a/pkg/remotedialer/types.go +++ /dev/null @@ -1,11 +0,0 @@ -package remotedialer - -import ( - "time" -) - -var ( - PingWaitDuration = time.Duration(10 * time.Second) - PingWriteInterval = time.Duration(5 * time.Second) - MaxRead = 8192 -) diff --git a/pkg/remotedialer/wsconn.go b/pkg/remotedialer/wsconn.go deleted file mode 100644 index 3e519e5d..00000000 --- a/pkg/remotedialer/wsconn.go +++ /dev/null @@ -1,47 +0,0 @@ -package remotedialer - -import ( - "io" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type wsConn struct { - sync.Mutex - conn *websocket.Conn -} - -func newWSConn(conn *websocket.Conn) *wsConn { - w := &wsConn{ - conn: conn, - } - w.setupDeadline() - return w -} - -func (w *wsConn) WriteMessage(messageType int, data []byte) error { - w.Lock() - defer w.Unlock() - w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration)) - return w.conn.WriteMessage(messageType, data) -} - -func (w *wsConn) NextReader() (int, io.Reader, error) { - return w.conn.NextReader() -} - -func (w *wsConn) setupDeadline() { - w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) - w.conn.SetPingHandler(func(string) error { - w.Lock() - w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second)) - w.Unlock() - return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) - }) - w.conn.SetPongHandler(func(string) error { - return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) - }) - -}