1
0
mirror of https://github.com/rancher/norman.git synced 2025-08-25 10:28:37 +00:00

Merge pull request #304 from daxmc99/remove_remotedialer

Remove remotedialer
This commit is contained in:
Darren Shepherd 2019-09-06 16:02:28 -07:00 committed by GitHub
commit 25c20af174
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 0 additions and 1516 deletions

View File

@ -1,57 +0,0 @@
package remotedialer
import (
"context"
"io/ioutil"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
type ConnectAuthorizer func(proto, address string) bool
func ClientConnect(wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, onConnect func(context.Context) error) {
if err := connectToProxy(wsURL, headers, auth, dialer, onConnect); err != nil {
logrus.WithError(err).Error("Failed to connect to proxy")
time.Sleep(time.Duration(5) * time.Second)
}
}
func connectToProxy(proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context) error) error {
logrus.WithField("url", proxyURL).Info("Connecting to proxy")
if dialer == nil {
dialer = &websocket.Dialer{HandshakeTimeout: 10 * time.Second}
}
ws, resp, err := dialer.Dial(proxyURL, headers)
if err != nil {
if resp == nil {
logrus.WithError(err).Errorf("Failed to connect to proxy. Empty dialer response")
} else {
rb, err2 := ioutil.ReadAll(resp.Body)
if err2 != nil {
logrus.WithError(err).Errorf("Failed to connect to proxy. Response status: %v - %v. Couldn't read response body (err: %v)", resp.StatusCode, resp.Status, err2)
} else {
logrus.WithError(err).Errorf("Failed to connect to proxy. Response status: %v - %v. Response body: %s", resp.StatusCode, resp.Status, rb)
}
}
return err
}
defer ws.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if onConnect != nil {
if err := onConnect(ctx); err != nil {
return err
}
}
session := NewClientSession(auth, ws)
_, err = session.Serve()
session.Close()
return err
}

View File

@ -1,34 +0,0 @@
// +build !windows
package main
import (
"flag"
"net/http"
"github.com/rancher/norman/pkg/remotedialer"
"github.com/sirupsen/logrus"
)
var (
addr string
id string
debug bool
)
func main() {
flag.StringVar(&addr, "connect", "ws://localhost:8123/connect", "Address to connect to")
flag.StringVar(&id, "id", "foo", "Client ID")
flag.BoolVar(&debug, "debug", true, "Debug logging")
flag.Parse()
if debug {
logrus.SetLevel(logrus.DebugLevel)
}
headers := http.Header{
"X-Tunnel-ID": []string{id},
}
remotedialer.ClientConnect(addr, headers, nil, func(string, string) bool { return true }, nil)
}

View File

@ -1,62 +0,0 @@
package remotedialer
import (
"io"
"net"
"sync"
"time"
)
func clientDial(dialer Dialer, conn *connection, message *message) {
defer conn.Close()
var (
netConn net.Conn
err error
)
if dialer == nil {
netDialer := &net.Dialer{
Timeout: time.Duration(message.deadline) * time.Millisecond,
KeepAlive: 30 * time.Second,
}
netConn, err = netDialer.Dial(message.proto, message.address)
} else {
netConn, err = dialer(message.proto, message.address)
}
if err != nil {
conn.tunnelClose(err)
return
}
defer netConn.Close()
pipe(conn, netConn)
}
func pipe(client *connection, server net.Conn) {
wg := sync.WaitGroup{}
wg.Add(1)
close := func(err error) error {
if err == nil {
err = io.EOF
}
client.doTunnelClose(err)
server.Close()
return err
}
go func() {
defer wg.Done()
_, err := io.Copy(server, client)
close(err)
}()
_, err := io.Copy(client, server)
err = close(err)
wg.Wait()
// Write tunnel error after no more I/O is happening, just incase messages get out of order
client.writeErr(err)
}

View File

@ -1,188 +0,0 @@
package remotedialer
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
)
type connection struct {
sync.Mutex
ctx context.Context
cancel func()
err error
writeDeadline time.Time
buf chan []byte
readBuf []byte
addr addr
session *Session
connID int64
}
func newConnection(connID int64, session *Session, proto, address string) *connection {
c := &connection{
addr: addr{
proto: proto,
address: address,
},
connID: connID,
session: session,
buf: make(chan []byte, 1024),
}
return c
}
func (c *connection) tunnelClose(err error) {
c.writeErr(err)
c.doTunnelClose(err)
}
func (c *connection) doTunnelClose(err error) {
c.Lock()
defer c.Unlock()
if c.err != nil {
return
}
c.err = err
if c.err == nil {
c.err = io.ErrClosedPipe
}
close(c.buf)
}
func (c *connection) tunnelWriter() io.Writer {
return chanWriter{conn: c, C: c.buf}
}
func (c *connection) Close() error {
c.session.closeConnection(c.connID, io.EOF)
return nil
}
func (c *connection) copyData(b []byte) int {
n := copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
return n
}
func (c *connection) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
n := c.copyData(b)
if n > 0 {
return n, nil
}
next, ok := <-c.buf
if !ok {
err := io.EOF
c.Lock()
if c.err != nil {
err = c.err
}
c.Unlock()
return 0, err
}
c.readBuf = next
n = c.copyData(b)
return n, nil
}
func (c *connection) Write(b []byte) (int, error) {
c.Lock()
if c.err != nil {
defer c.Unlock()
return 0, c.err
}
c.Unlock()
deadline := int64(0)
if !c.writeDeadline.IsZero() {
deadline = c.writeDeadline.Sub(time.Now()).Nanoseconds() / 1000000
}
return c.session.writeMessage(newMessage(c.connID, deadline, b))
}
func (c *connection) writeErr(err error) {
if err != nil {
c.session.writeMessage(newErrorMessage(c.connID, err))
}
}
func (c *connection) LocalAddr() net.Addr {
return c.addr
}
func (c *connection) RemoteAddr() net.Addr {
return c.addr
}
func (c *connection) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}
func (c *connection) SetReadDeadline(t time.Time) error {
return nil
}
func (c *connection) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t
return nil
}
type addr struct {
proto string
address string
}
func (a addr) Network() string {
return a.proto
}
func (a addr) String() string {
return a.address
}
type chanWriter struct {
conn *connection
C chan []byte
}
func (c chanWriter) Write(buf []byte) (int, error) {
c.conn.Lock()
defer c.conn.Unlock()
if c.conn.err != nil {
return 0, c.conn.err
}
newBuf := make([]byte, len(buf))
copy(newBuf, buf)
buf = newBuf
select {
// must copy the buffer
case c.C <- buf:
return len(buf), nil
default:
select {
case c.C <- buf:
return len(buf), nil
case <-time.After(15 * time.Second):
return 0, errors.New("backed up reader")
}
}
}

View File

@ -1,28 +0,0 @@
package remotedialer
import (
"net"
"time"
)
type Dialer func(network, address string) (net.Conn, error)
func (s *Server) HasSession(clientKey string) bool {
_, err := s.sessions.getDialer(clientKey, 0)
return err == nil
}
func (s *Server) Dial(clientKey string, deadline time.Duration, proto, address string) (net.Conn, error) {
d, err := s.sessions.getDialer(clientKey, deadline)
if err != nil {
return nil, err
}
return d(proto, address)
}
func (s *Server) Dialer(clientKey string, deadline time.Duration) Dialer {
return func(proto, address string) (net.Conn, error) {
return s.Dial(clientKey, deadline, proto, address)
}
}

View File

@ -1,29 +0,0 @@
package main
import (
"flag"
"fmt"
"math/rand"
"net/http"
"sync/atomic"
"time"
)
var (
counter int64
listen string
)
func main() {
flag.StringVar(&listen, "listen", ":8125", "Listen address")
flag.Parse()
fmt.Println("listening ", listen)
http.ListenAndServe(listen, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := atomic.AddInt64(&counter, 1)
fmt.Println("request", next)
time.Sleep(time.Duration(rand.Intn(300)) * time.Millisecond)
rw.Write([]byte("HI"))
}))
}

View File

@ -1,220 +0,0 @@
package remotedialer
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"strings"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
const (
Data messageType = iota + 1
Connect
Error
AddClient
RemoveClient
)
var (
idCounter int64
)
func init() {
r := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
idCounter = r.Int63()
}
type messageType int64
type message struct {
id int64
err error
connID int64
deadline int64
messageType messageType
bytes []byte
body io.Reader
proto string
address string
}
func nextid() int64 {
return atomic.AddInt64(&idCounter, 1)
}
func newMessage(connID int64, deadline int64, bytes []byte) *message {
return &message{
id: nextid(),
connID: connID,
deadline: deadline,
messageType: Data,
bytes: bytes,
}
}
func newConnect(connID int64, deadline time.Duration, proto, address string) *message {
return &message{
id: nextid(),
connID: connID,
deadline: deadline.Nanoseconds() / 1000000,
messageType: Connect,
bytes: []byte(fmt.Sprintf("%s/%s", proto, address)),
proto: proto,
address: address,
}
}
func newErrorMessage(connID int64, err error) *message {
return &message{
id: nextid(),
err: err,
connID: connID,
messageType: Error,
bytes: []byte(err.Error()),
}
}
func newAddClient(client string) *message {
return &message{
id: nextid(),
messageType: AddClient,
address: client,
bytes: []byte(client),
}
}
func newRemoveClient(client string) *message {
return &message{
id: nextid(),
messageType: RemoveClient,
address: client,
bytes: []byte(client),
}
}
func newServerMessage(reader io.Reader) (*message, error) {
buf := bufio.NewReader(reader)
id, err := binary.ReadVarint(buf)
if err != nil {
return nil, err
}
connID, err := binary.ReadVarint(buf)
if err != nil {
return nil, err
}
mType, err := binary.ReadVarint(buf)
if err != nil {
return nil, err
}
m := &message{
id: id,
messageType: messageType(mType),
connID: connID,
body: buf,
}
if m.messageType == Data || m.messageType == Connect {
deadline, err := binary.ReadVarint(buf)
if err != nil {
return nil, err
}
m.deadline = deadline
}
if m.messageType == Connect {
bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100))
if err != nil {
return nil, err
}
parts := strings.SplitN(string(bytes), "/", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("failed to parse connect address")
}
m.proto = parts[0]
m.address = parts[1]
m.bytes = bytes
} else if m.messageType == AddClient || m.messageType == RemoveClient {
bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100))
if err != nil {
return nil, err
}
m.address = string(bytes)
m.bytes = bytes
}
return m, nil
}
func (m *message) Err() error {
if m.err != nil {
return m.err
}
bytes, err := ioutil.ReadAll(io.LimitReader(m.body, 100))
if err != nil {
return err
}
str := string(bytes)
if str == "EOF" {
m.err = io.EOF
} else {
m.err = errors.New(str)
}
return m.err
}
func (m *message) Bytes() []byte {
return append(m.header(), m.bytes...)
}
func (m *message) header() []byte {
buf := make([]byte, 24)
offset := 0
offset += binary.PutVarint(buf[offset:], m.id)
offset += binary.PutVarint(buf[offset:], m.connID)
offset += binary.PutVarint(buf[offset:], int64(m.messageType))
if m.messageType == Data || m.messageType == Connect {
offset += binary.PutVarint(buf[offset:], m.deadline)
}
return buf[:offset]
}
func (m *message) Read(p []byte) (int, error) {
return m.body.Read(p)
}
func (m *message) WriteTo(wsConn *wsConn) (int, error) {
err := wsConn.WriteMessage(websocket.BinaryMessage, m.Bytes())
return len(m.bytes), err
}
func (m *message) String() string {
switch m.messageType {
case Data:
if m.body == nil {
return fmt.Sprintf("%d DATA [%d]: %d bytes: %s", m.id, m.connID, len(m.bytes), string(m.bytes))
}
return fmt.Sprintf("%d DATA [%d]: buffered", m.id, m.connID)
case Error:
return fmt.Sprintf("%d ERROR [%d]: %s", m.id, m.connID, m.Err())
case Connect:
return fmt.Sprintf("%d CONNECT [%d]: %s/%s deadline %d", m.id, m.connID, m.proto, m.address, m.deadline)
case AddClient:
return fmt.Sprintf("%d ADDCLIENT [%s]", m.id, m.address)
case RemoveClient:
return fmt.Sprintf("%d REMOVECLIENT [%s]", m.id, m.address)
}
return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType)
}

View File

@ -1,121 +0,0 @@
package remotedialer
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
var (
Token = "X-API-Tunnel-Token"
ID = "X-API-Tunnel-ID"
)
func (s *Server) AddPeer(url, id, token string) {
if s.PeerID == "" || s.PeerToken == "" {
return
}
ctx, cancel := context.WithCancel(context.Background())
peer := peer{
url: url,
id: id,
token: token,
cancel: cancel,
}
logrus.Infof("Adding peer %s, %s", url, id)
s.peerLock.Lock()
defer s.peerLock.Unlock()
if p, ok := s.peers[id]; ok {
if p.equals(peer) {
return
}
p.cancel()
}
s.peers[id] = peer
go peer.start(ctx, s)
}
func (s *Server) RemovePeer(id string) {
s.peerLock.Lock()
defer s.peerLock.Unlock()
if p, ok := s.peers[id]; ok {
logrus.Infof("Removing peer %s", id)
p.cancel()
}
delete(s.peers, id)
}
type peer struct {
url, id, token string
cancel func()
}
func (p peer) equals(other peer) bool {
return p.url == other.url &&
p.id == other.id &&
p.token == other.token
}
func (p *peer) start(ctx context.Context, s *Server) {
headers := http.Header{
ID: {s.PeerID},
Token: {s.PeerToken},
}
dialer := &websocket.Dialer{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
HandshakeTimeout: 10 * time.Second,
}
outer:
for {
select {
case <-ctx.Done():
break outer
default:
}
ws, _, err := dialer.Dial(p.url, headers)
if err != nil {
logrus.Errorf("Failed to connect to peer %s [local ID=%s]: %v", p.url, s.PeerID, err)
time.Sleep(5 * time.Second)
continue
}
session := NewClientSession(func(string, string) bool { return true }, ws)
session.dialer = func(network, address string) (net.Conn, error) {
parts := strings.SplitN(network, "::", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid clientKey/proto: %s", network)
}
return s.Dial(parts[0], 15*time.Second, parts[1], address)
}
s.sessions.addListener(session)
_, err = session.Serve()
s.sessions.removeListener(session)
session.Close()
if err != nil {
logrus.Errorf("Failed to serve peer connection %s: %v", p.id, err)
}
ws.Close()
time.Sleep(5 * time.Second)
}
}

View File

@ -1,97 +0,0 @@
package remotedialer
import (
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
var (
errFailedAuth = errors.New("failed authentication")
errWrongMessageType = errors.New("wrong websocket message type")
)
type Authorizer func(req *http.Request) (clientKey string, authed bool, err error)
type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error)
func DefaultErrorWriter(rw http.ResponseWriter, req *http.Request, code int, err error) {
rw.Write([]byte(err.Error()))
rw.WriteHeader(code)
}
type Server struct {
PeerID string
PeerToken string
authorizer Authorizer
errorWriter ErrorWriter
sessions *sessionManager
peers map[string]peer
peerLock sync.Mutex
}
func New(auth Authorizer, errorWriter ErrorWriter) *Server {
return &Server{
peers: map[string]peer{},
authorizer: auth,
errorWriter: errorWriter,
sessions: newSessionManager(),
}
}
func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
clientKey, authed, peer, err := s.auth(req)
if err != nil {
s.errorWriter(rw, req, 400, err)
return
}
if !authed {
s.errorWriter(rw, req, 401, errFailedAuth)
return
}
logrus.Infof("Handling backend connection request [%s]", clientKey)
upgrader := websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
CheckOrigin: func(r *http.Request) bool { return true },
Error: s.errorWriter,
}
wsConn, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
s.errorWriter(rw, req, 400, errors.Wrapf(err, "Error during upgrade for host [%v]", clientKey))
return
}
session := s.sessions.add(clientKey, wsConn, peer)
defer s.sessions.remove(session)
// Don't need to associate req.Context() to the Session, it will cancel otherwise
code, err := session.Serve()
if err != nil {
// Hijacked so we can't write to the client
logrus.Infof("error in remotedialer server [%d]: %v", code, err)
}
}
func (s *Server) auth(req *http.Request) (clientKey string, authed, peer bool, err error) {
id := req.Header.Get(ID)
token := req.Header.Get(Token)
if id != "" && token != "" {
// peer authentication
s.peerLock.Lock()
p, ok := s.peers[id]
s.peerLock.Unlock()
if ok && p.token == token {
return id, true, true, nil
}
}
id, authed, err = s.authorizer(req)
return id, authed, false, err
}

View File

@ -1,125 +0,0 @@
package main
import (
"flag"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/mux"
"github.com/rancher/norman/pkg/remotedialer"
"github.com/sirupsen/logrus"
)
var (
clients = map[string]*http.Client{}
l sync.Mutex
counter int64
)
func authorizer(req *http.Request) (string, bool, error) {
id := req.Header.Get("x-tunnel-id")
return id, id != "", nil
}
func Client(server *remotedialer.Server, rw http.ResponseWriter, req *http.Request) {
timeout := req.URL.Query().Get("timeout")
if timeout == "" {
timeout = "15"
}
vars := mux.Vars(req)
clientKey := vars["id"]
url := fmt.Sprintf("%s://%s%s", vars["scheme"], vars["host"], vars["path"])
client := getClient(server, clientKey, timeout)
id := atomic.AddInt64(&counter, 1)
logrus.Infof("[%03d] REQ t=%s %s", id, timeout, url)
resp, err := client.Get(url)
if err != nil {
logrus.Errorf("[%03d] REQ ERR t=%s %s: %v", id, timeout, url, err)
remotedialer.DefaultErrorWriter(rw, req, 500, err)
return
}
defer resp.Body.Close()
logrus.Infof("[%03d] REQ OK t=%s %s", id, timeout, url)
rw.WriteHeader(resp.StatusCode)
io.Copy(rw, resp.Body)
logrus.Infof("[%03d] REQ DONE t=%s %s", id, timeout, url)
}
func getClient(server *remotedialer.Server, clientKey, timeout string) *http.Client {
l.Lock()
defer l.Unlock()
key := fmt.Sprintf("%s/%s", clientKey, timeout)
client := clients[key]
if client != nil {
return client
}
dialer := server.Dialer(clientKey, 15*time.Second)
client = &http.Client{
Transport: &http.Transport{
Dial: dialer,
},
}
if timeout != "" {
t, err := strconv.Atoi(timeout)
if err == nil {
client.Timeout = time.Duration(t) * time.Second
}
}
clients[key] = client
return client
}
func main() {
var (
addr string
peerID string
peerToken string
peers string
debug bool
)
flag.StringVar(&addr, "listen", ":8123", "Listen address")
flag.StringVar(&peerID, "id", "", "Peer ID")
flag.StringVar(&peerToken, "token", "", "Peer Token")
flag.StringVar(&peers, "peers", "", "Peers format id:token:url,id:token:url")
flag.BoolVar(&debug, "debug", false, "Enable debug logging")
flag.Parse()
if debug {
logrus.SetLevel(logrus.DebugLevel)
remotedialer.PrintTunnelData = true
}
handler := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter)
handler.PeerToken = peerToken
handler.PeerID = peerID
for _, peer := range strings.Split(peers, ",") {
parts := strings.SplitN(strings.TrimSpace(peer), ":", 3)
if len(parts) != 3 {
continue
}
handler.AddPeer(parts[2], parts[0], parts[1])
}
router := mux.NewRouter()
router.Handle("/connect", handler)
router.HandleFunc("/client/{id}/{scheme}/{host}{path:.*}", func(rw http.ResponseWriter, req *http.Request) {
Client(handler, rw, req)
})
fmt.Println("Listening on ", addr)
http.ListenAndServe(addr, router)
}

View File

@ -1,303 +0,0 @@
package remotedialer
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
type Session struct {
sync.Mutex
nextConnID int64
clientKey string
sessionKey int64
conn *wsConn
conns map[int64]*connection
remoteClientKeys map[string]map[int]bool
auth ConnectAuthorizer
pingCancel context.CancelFunc
pingWait sync.WaitGroup
dialer Dialer
client bool
}
// PrintTunnelData No tunnel logging by default
var PrintTunnelData bool
func init() {
if os.Getenv("CATTLE_TUNNEL_DATA_DEBUG") == "true" {
PrintTunnelData = true
}
}
func NewClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *Session {
return &Session{
clientKey: "client",
conn: newWSConn(conn),
conns: map[int64]*connection{},
auth: auth,
client: true,
}
}
func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *Session {
return &Session{
nextConnID: 1,
clientKey: clientKey,
sessionKey: sessionKey,
conn: newWSConn(conn),
conns: map[int64]*connection{},
remoteClientKeys: map[string]map[int]bool{},
}
}
func (s *Session) startPings() {
ctx, cancel := context.WithCancel(context.Background())
s.pingCancel = cancel
s.pingWait.Add(1)
go func() {
defer s.pingWait.Done()
t := time.NewTicker(PingWriteInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
s.conn.Lock()
if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil {
logrus.WithError(err).Error("Error writing ping")
}
logrus.Debug("Wrote ping")
s.conn.Unlock()
}
}
}()
}
func (s *Session) stopPings() {
if s.pingCancel == nil {
return
}
s.pingCancel()
s.pingWait.Wait()
}
func (s *Session) Serve() (int, error) {
if s.client {
s.startPings()
}
for {
msType, reader, err := s.conn.NextReader()
if err != nil {
return 400, err
}
if msType != websocket.BinaryMessage {
return 400, errWrongMessageType
}
if err := s.serveMessage(reader); err != nil {
return 500, err
}
}
}
func (s *Session) serveMessage(reader io.Reader) error {
message, err := newServerMessage(reader)
if err != nil {
return err
}
if PrintTunnelData {
logrus.Debug("REQUEST ", message)
}
if message.messageType == Connect {
if s.auth == nil || !s.auth(message.proto, message.address) {
return errors.New("connect not allowed")
}
s.clientConnect(message)
return nil
}
s.Lock()
if message.messageType == AddClient && s.remoteClientKeys != nil {
err := s.addRemoteClient(message.address)
s.Unlock()
return err
} else if message.messageType == RemoveClient {
err := s.removeRemoteClient(message.address)
s.Unlock()
return err
}
conn := s.conns[message.connID]
s.Unlock()
if conn == nil {
if message.messageType == Data {
err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID)
newErrorMessage(message.connID, err).WriteTo(s.conn)
}
return nil
}
switch message.messageType {
case Data:
if _, err := io.Copy(conn.tunnelWriter(), message); err != nil {
s.closeConnection(message.connID, err)
}
case Error:
s.closeConnection(message.connID, message.Err())
}
return nil
}
func parseAddress(address string) (string, int, error) {
parts := strings.SplitN(address, "/", 2)
if len(parts) != 2 {
return "", 0, errors.New("not / separated")
}
v, err := strconv.Atoi(parts[1])
return parts[0], v, err
}
func (s *Session) addRemoteClient(address string) error {
clientKey, sessionKey, err := parseAddress(address)
if err != nil {
return fmt.Errorf("invalid remote Session %s: %v", address, err)
}
keys := s.remoteClientKeys[clientKey]
if keys == nil {
keys = map[int]bool{}
s.remoteClientKeys[clientKey] = keys
}
keys[int(sessionKey)] = true
if PrintTunnelData {
logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
}
return nil
}
func (s *Session) removeRemoteClient(address string) error {
clientKey, sessionKey, err := parseAddress(address)
if err != nil {
return fmt.Errorf("invalid remote Session %s: %v", address, err)
}
keys := s.remoteClientKeys[clientKey]
delete(keys, int(sessionKey))
if len(keys) == 0 {
delete(s.remoteClientKeys, clientKey)
}
if PrintTunnelData {
logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
}
return nil
}
func (s *Session) closeConnection(connID int64, err error) {
s.Lock()
conn := s.conns[connID]
delete(s.conns, connID)
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()
if conn != nil {
conn.tunnelClose(err)
}
}
func (s *Session) clientConnect(message *message) {
conn := newConnection(message.connID, s, message.proto, message.address)
s.Lock()
s.conns[message.connID] = conn
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()
go clientDial(s.dialer, conn, message)
}
func (s *Session) serverConnect(deadline time.Duration, proto, address string) (net.Conn, error) {
connID := atomic.AddInt64(&s.nextConnID, 1)
conn := newConnection(connID, s, proto, address)
s.Lock()
s.conns[connID] = conn
if PrintTunnelData {
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
}
s.Unlock()
_, err := s.writeMessage(newConnect(connID, deadline, proto, address))
if err != nil {
s.closeConnection(connID, err)
return nil, err
}
return conn, err
}
func (s *Session) writeMessage(message *message) (int, error) {
if PrintTunnelData {
logrus.Debug("WRITE ", message)
}
return message.WriteTo(s.conn)
}
func (s *Session) Close() {
s.Lock()
defer s.Unlock()
s.stopPings()
for _, connection := range s.conns {
connection.tunnelClose(errors.New("tunnel disconnect"))
}
s.conns = map[int64]*connection{}
}
func (s *Session) sessionAdded(clientKey string, sessionKey int64) {
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
_, err := s.writeMessage(newAddClient(client))
if err != nil {
s.conn.conn.Close()
}
}
func (s *Session) sessionRemoved(clientKey string, sessionKey int64) {
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
_, err := s.writeMessage(newRemoveClient(client))
if err != nil {
s.conn.conn.Close()
}
}

View File

@ -1,137 +0,0 @@
package remotedialer
import (
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/gorilla/websocket"
)
type sessionListener interface {
sessionAdded(clientKey string, sessionKey int64)
sessionRemoved(clientKey string, sessionKey int64)
}
type sessionManager struct {
sync.Mutex
clients map[string][]*Session
peers map[string][]*Session
listeners map[sessionListener]bool
}
func newSessionManager() *sessionManager {
return &sessionManager{
clients: map[string][]*Session{},
peers: map[string][]*Session{},
listeners: map[sessionListener]bool{},
}
}
func toDialer(s *Session, prefix string, deadline time.Duration) Dialer {
return func(proto, address string) (net.Conn, error) {
if prefix == "" {
return s.serverConnect(deadline, proto, address)
}
return s.serverConnect(deadline, prefix+"::"+proto, address)
}
}
func (sm *sessionManager) removeListener(listener sessionListener) {
sm.Lock()
defer sm.Unlock()
delete(sm.listeners, listener)
}
func (sm *sessionManager) addListener(listener sessionListener) {
sm.Lock()
defer sm.Unlock()
sm.listeners[listener] = true
for k, sessions := range sm.clients {
for _, session := range sessions {
listener.sessionAdded(k, session.sessionKey)
}
}
for k, sessions := range sm.peers {
for _, session := range sessions {
listener.sessionAdded(k, session.sessionKey)
}
}
}
func (sm *sessionManager) getDialer(clientKey string, deadline time.Duration) (Dialer, error) {
sm.Lock()
defer sm.Unlock()
sessions := sm.clients[clientKey]
if len(sessions) > 0 {
return toDialer(sessions[0], "", deadline), nil
}
for _, sessions := range sm.peers {
for _, session := range sessions {
session.Lock()
keys := session.remoteClientKeys[clientKey]
session.Unlock()
if len(keys) > 0 {
return toDialer(session, clientKey, deadline), nil
}
}
}
return nil, fmt.Errorf("failed to find Session for client %s", clientKey)
}
func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *Session {
sessionKey := rand.Int63()
session := newSession(sessionKey, clientKey, conn)
sm.Lock()
defer sm.Unlock()
if peer {
sm.peers[clientKey] = append(sm.peers[clientKey], session)
} else {
sm.clients[clientKey] = append(sm.clients[clientKey], session)
}
for l := range sm.listeners {
l.sessionAdded(clientKey, session.sessionKey)
}
return session
}
func (sm *sessionManager) remove(s *Session) {
sm.Lock()
defer sm.Unlock()
for _, store := range []map[string][]*Session{sm.clients, sm.peers} {
var newSessions []*Session
for _, v := range store[s.clientKey] {
if v.sessionKey == s.sessionKey {
continue
}
newSessions = append(newSessions, v)
}
if len(newSessions) == 0 {
delete(store, s.clientKey)
} else {
store[s.clientKey] = newSessions
}
}
for l := range sm.listeners {
l.sessionRemoved(s.clientKey, s.sessionKey)
}
s.Close()
}

View File

@ -1,57 +0,0 @@
package remotedialer
import (
"context"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
func (s *Session) startPingsWhileWindows(rootCtx context.Context) {
ctx, cancel := context.WithCancel(rootCtx)
s.pingCancel = cancel
s.pingWait.Add(1)
go func() {
defer s.pingWait.Done()
t := time.NewTicker(PingWriteInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
s.conn.Lock()
if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil {
logrus.WithError(err).Error("Error writing ping")
}
logrus.Debug("Wrote ping")
s.conn.Unlock()
}
}
}()
}
func (s *Session) ServeWhileWindows(ctx context.Context) (int, error) {
if s.client {
s.startPingsWhileWindows(ctx)
}
for {
msType, reader, err := s.conn.NextReader()
if err != nil {
return 400, err
}
if msType != websocket.BinaryMessage {
return 400, errWrongMessageType
}
if err := s.serveMessage(reader); err != nil {
return 500, err
}
}
}

View File

@ -1,11 +0,0 @@
package remotedialer
import (
"time"
)
var (
PingWaitDuration = time.Duration(10 * time.Second)
PingWriteInterval = time.Duration(5 * time.Second)
MaxRead = 8192
)

View File

@ -1,47 +0,0 @@
package remotedialer
import (
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
type wsConn struct {
sync.Mutex
conn *websocket.Conn
}
func newWSConn(conn *websocket.Conn) *wsConn {
w := &wsConn{
conn: conn,
}
w.setupDeadline()
return w
}
func (w *wsConn) WriteMessage(messageType int, data []byte) error {
w.Lock()
defer w.Unlock()
w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration))
return w.conn.WriteMessage(messageType, data)
}
func (w *wsConn) NextReader() (int, io.Reader, error) {
return w.conn.NextReader()
}
func (w *wsConn) setupDeadline() {
w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
w.conn.SetPingHandler(func(string) error {
w.Lock()
w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second))
w.Unlock()
return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
})
w.conn.SetPongHandler(func(string) error {
return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
})
}