diff --git a/pkg/k8scheck/wait.go b/pkg/k8scheck/wait.go new file mode 100644 index 00000000..4fdb2bba --- /dev/null +++ b/pkg/k8scheck/wait.go @@ -0,0 +1,33 @@ +package k8scheck + +import ( + "context" + "fmt" + "time" + + "github.com/sirupsen/logrus" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +func Wait(ctx context.Context, config rest.Config) error { + client, err := kubernetes.NewForConfig(&config) + if err != nil { + return err + } + + for { + _, err := client.Discovery().ServerVersion() + if err == nil { + break + } + logrus.Infof("Waiting for server to become available: %v", err) + select { + case <-ctx.Done(): + return fmt.Errorf("startup canceled") + case <-time.After(2 * time.Second): + } + } + + return nil +} diff --git a/pkg/kwrapper/etcd/etcd.go b/pkg/kwrapper/etcd/etcd.go new file mode 100644 index 00000000..a94fc4f9 --- /dev/null +++ b/pkg/kwrapper/etcd/etcd.go @@ -0,0 +1,67 @@ +// +build !no_etcd + +package etcd + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "os" + "strings" + "time" + + "github.com/coreos/etcd/etcdmain" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +func RunETCD(ctx context.Context) ([]string, error) { + endpoint := "http://localhost:2379" + go runEtcd(ctx, []string{"--data-dir=./etcd"}) + + if err := checkEtcd(endpoint); err != nil { + return nil, errors.Wrap(err, "waiting on etcd") + } + + return []string{endpoint}, nil +} + +func checkEtcd(endpoint string) error { + ht := &http.Transport{} + client := http.Client{ + Transport: ht, + } + defer ht.CloseIdleConnections() + + for i := 0; ; i++ { + resp, err := client.Get(endpoint + "/health") + if err != nil { + if i > 1 { + logrus.Infof("Waiting on etcd startup: %v", err) + } + time.Sleep(time.Second) + continue + } + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + if i > 1 { + logrus.Infof("Waiting on etcd startup: status %d", resp.StatusCode) + } + time.Sleep(time.Second) + continue + } + + break + } + + return nil +} + +func runEtcd(ctx context.Context, args []string) { + os.Args = args + logrus.Info("Running ", strings.Join(args, " ")) + etcdmain.Main() + logrus.Errorf("etcd exited") +} diff --git a/pkg/kwrapper/etcd/etcd_none.go b/pkg/kwrapper/etcd/etcd_none.go new file mode 100644 index 00000000..9f749644 --- /dev/null +++ b/pkg/kwrapper/etcd/etcd_none.go @@ -0,0 +1,11 @@ +// +build no_etcd + +package etcd + +import ( + "context" +) + +func RunETCD(ctx context.Context) ([]string, error) { + return nil, nil +} diff --git a/pkg/kwrapper/k8s/config.go b/pkg/kwrapper/k8s/config.go new file mode 100644 index 00000000..13a3bfbc --- /dev/null +++ b/pkg/kwrapper/k8s/config.go @@ -0,0 +1,51 @@ +package k8s + +import ( + "context" + "fmt" + "os" + + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" +) + +func Auto(ctx context.Context) (bool, context.Context, *rest.Config, error) { + return GetConfig(ctx, "auto", os.Getenv("KUBECONFIG")) +} + +func GetConfig(ctx context.Context, k8sMode string, kubeConfig string) (bool, context.Context, *rest.Config, error) { + var ( + cfg *rest.Config + err error + ) + + switch k8sMode { + case "auto": + return getAuto(ctx, kubeConfig) + case "embedded": + return getEmbedded(ctx) + case "external": + cfg, err = getExternal(kubeConfig) + default: + return false, nil, nil, fmt.Errorf("invalid k8s-mode %s", k8sMode) + } + + return false, ctx, cfg, err +} + +func getAuto(ctx context.Context, kubeConfig string) (bool, context.Context, *rest.Config, error) { + if kubeConfig != "" { + cfg, err := getExternal(kubeConfig) + return false, ctx, cfg, err + } + + if config, err := rest.InClusterConfig(); err == nil { + return false, ctx, config, nil + } + + return getEmbedded(ctx) +} + +func getExternal(kubeConfig string) (*rest.Config, error) { + return clientcmd.BuildConfigFromFlags("", kubeConfig) +} diff --git a/pkg/kwrapper/k8s/config_k3s.go b/pkg/kwrapper/k8s/config_k3s.go new file mode 100644 index 00000000..244bdb78 --- /dev/null +++ b/pkg/kwrapper/k8s/config_k3s.go @@ -0,0 +1,37 @@ +package k8s + +import ( + "context" + "net" + "net/http" + + "github.com/rancher/norman/pkg/remotedialer" + "github.com/rancher/norman/pkg/resolvehome" + "k8s.io/kubernetes/pkg/wrapper/server" +) + +func NewK3sConfig(ctx context.Context, dataDir string, authorizer remotedialer.Authorizer) (context.Context, *server.ServerConfig, http.Handler, error) { + dataDir, err := resolvehome.Resolve(dataDir) + if err != nil { + return ctx, nil, nil, err + } + + listenIP := net.ParseIP("127.0.0.1") + _, clusterIPNet, _ := net.ParseCIDR("10.42.0.0/16") + _, serviceIPNet, _ := net.ParseCIDR("10.43.0.0/16") + + sc := &server.ServerConfig{ + AdvertiseIP: &listenIP, + AdvertisePort: 6444, + PublicHostname: "localhost", + ListenAddr: listenIP, + ListenPort: 6443, + ClusterIPRange: *clusterIPNet, + ServiceIPRange: *serviceIPNet, + UseTokenCA: true, + DataDir: dataDir, + } + + ctx = SetK3sConfig(ctx, sc) + return ctx, sc, newTunnel(authorizer), nil +} diff --git a/pkg/kwrapper/k8s/embedded_none.go b/pkg/kwrapper/k8s/embedded_none.go new file mode 100644 index 00000000..e9b20429 --- /dev/null +++ b/pkg/kwrapper/k8s/embedded_none.go @@ -0,0 +1,14 @@ +// +build !k3s + +package k8s + +import ( + "context" + "fmt" + + "k8s.io/client-go/rest" +) + +func getEmbedded(ctx context.Context) (bool, context.Context, *rest.Config, error) { + return false, ctx, nil, fmt.Errorf("embedded support is not compiled in, rebuild with -tags k8s") +} diff --git a/pkg/kwrapper/k8s/k3s.go b/pkg/kwrapper/k8s/k3s.go new file mode 100644 index 00000000..7096280b --- /dev/null +++ b/pkg/kwrapper/k8s/k3s.go @@ -0,0 +1,43 @@ +// +build k3s + +package k8s + +import ( + "context" + "os" + + "github.com/rancher/norman/pkg/kwrapper/etcd" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/kubernetes/pkg/wrapper/server" +) + +func getEmbedded(ctx context.Context) (bool, context.Context, *rest.Config, error) { + sc, ok := ctx.Value(serverConfig).(*server.ServerConfig) + if !ok { + ctx, sc, _, err = NewK3sConfig(ctx, "./k3s", nil) + if err != nil { + return false, ctx, nil, err + } + sc.NoScheduler = false + } + + if len(sc.ETCDEndpoints) == 0 { + etcdEndpoints, err := etcd.RunETCD(ctx) + if err != nil { + return ctx, nil, nil, err + } + sc.ETCDEndpoints = etcdEndpoints + } + + err := server.Server(ctx, sc) + if err != nil { + return false, ctx, nil, err + } + + os.Setenv("KUBECONFIG", sc.KubeConfig) + restConfig, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + &clientcmd.ClientConfigLoadingRules{ExplicitPath: sc.KubeConfig}, &clientcmd.ConfigOverrides{}).ClientConfig() + + return true, ctx, restConfig, err +} diff --git a/pkg/kwrapper/k8s/k3s_context.go b/pkg/kwrapper/k8s/k3s_context.go new file mode 100644 index 00000000..e1b26caa --- /dev/null +++ b/pkg/kwrapper/k8s/k3s_context.go @@ -0,0 +1,13 @@ +package k8s + +import ( + "context" +) + +var serverConfig configKey + +type configKey struct{} + +func SetK3sConfig(ctx context.Context, conf interface{}) context.Context { + return context.WithValue(ctx, serverConfig, conf) +} diff --git a/pkg/kwrapper/k8s/tunnel.go b/pkg/kwrapper/k8s/tunnel.go new file mode 100644 index 00000000..fffb998c --- /dev/null +++ b/pkg/kwrapper/k8s/tunnel.go @@ -0,0 +1,16 @@ +package k8s + +import ( + "net/http" + + "github.com/rancher/norman/pkg/remotedialer" +) + +func newTunnel(authorizer remotedialer.Authorizer) http.Handler { + if authorizer == nil { + return nil + } + server := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter) + setupK3s(server) + return server +} diff --git a/pkg/kwrapper/k8s/tunnel_k3s.go b/pkg/kwrapper/k8s/tunnel_k3s.go new file mode 100644 index 00000000..fc1b25aa --- /dev/null +++ b/pkg/kwrapper/k8s/tunnel_k3s.go @@ -0,0 +1,26 @@ +// +build k3s + +package k8s + +import ( + "context" + "net" + "time" + + "github.com/rancher/norman/pkg/kv" + "github.com/rancher/norman/pkg/remotedialer" + utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/kubernetes/cmd/kube-apiserver/app" +) + +func setupK3s(tunnelServer *remotedialer.Server) { + app.DefaultProxyDialerFn = utilnet.DialFunc(func(_ context.Context, network, address string) (net.Conn, error) { + _, port, _ := net.SplitHostPort(address) + addr := "127.0.0.1" + if port != "" { + addr += ":" + port + } + nodeName, _ := kv.Split(address, ":") + return tunnelServer.Dial(nodeName, 15*time.Second, "tcp", addr) + }) +} diff --git a/pkg/kwrapper/k8s/tunnel_none.go b/pkg/kwrapper/k8s/tunnel_none.go new file mode 100644 index 00000000..bb46acc7 --- /dev/null +++ b/pkg/kwrapper/k8s/tunnel_none.go @@ -0,0 +1,8 @@ +// +build !k3s + +package k8s + +import "github.com/rancher/norman/pkg/remotedialer" + +func setupK3s(tunnelServer *remotedialer.Server) { +} diff --git a/pkg/kwrapper/kubectl/main.go b/pkg/kwrapper/kubectl/main.go new file mode 100644 index 00000000..5201eae2 --- /dev/null +++ b/pkg/kwrapper/kubectl/main.go @@ -0,0 +1,39 @@ +package kubectl + +import ( + goflag "flag" + "fmt" + "math/rand" + "os" + "time" + + "github.com/docker/docker/pkg/reexec" + "github.com/spf13/pflag" + utilflag "k8s.io/apiserver/pkg/util/flag" + "k8s.io/apiserver/pkg/util/logs" + "k8s.io/kubernetes/pkg/kubectl/cmd" +) + +func init() { + reexec.Register("kubectl", Main) +} + +func Main() { + rand.Seed(time.Now().UTC().UnixNano()) + + command := cmd.NewDefaultKubectlCommand() + + // TODO: once we switch everything over to Cobra commands, we can go back to calling + // utilflag.InitFlags() (by removing its pflag.Parse() call). For now, we have to set the + // normalize func and add the go flag set by hand. + pflag.CommandLine.SetNormalizeFunc(utilflag.WordSepNormalizeFunc) + pflag.CommandLine.AddGoFlagSet(goflag.CommandLine) + // utilflag.InitFlags() + logs.InitLogs() + defer logs.FlushLogs() + + if err := command.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} diff --git a/pkg/remotedialer/client.go b/pkg/remotedialer/client.go new file mode 100644 index 00000000..fbae03ae --- /dev/null +++ b/pkg/remotedialer/client.go @@ -0,0 +1,47 @@ +package remotedialer + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" +) + +type ConnectAuthorizer func(proto, address string) bool + +func ClientConnect(wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, onConnect func(context.Context) error) { + if err := connectToProxy(wsURL, headers, auth, dialer, onConnect); err != nil { + logrus.WithError(err).Error("Failed to connect to proxy") + time.Sleep(time.Duration(5) * time.Second) + } +} + +func connectToProxy(proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context) error) error { + logrus.WithField("url", proxyURL).Info("Connecting to proxy") + + if dialer == nil { + dialer = &websocket.Dialer{} + } + ws, _, err := dialer.Dial(proxyURL, headers) + if err != nil { + logrus.WithError(err).Error("Failed to connect to proxy") + return err + } + defer ws.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if onConnect != nil { + if err := onConnect(ctx); err != nil { + return err + } + } + + session := newClientSession(auth, ws) + _, err = session.serve() + session.Close() + return err +} diff --git a/pkg/remotedialer/client/main.go b/pkg/remotedialer/client/main.go new file mode 100644 index 00000000..9685e475 --- /dev/null +++ b/pkg/remotedialer/client/main.go @@ -0,0 +1,34 @@ +// +build !windows + +package main + +import ( + "flag" + "net/http" + + "github.com/rancher/norman/pkg/remotedialer" + "github.com/sirupsen/logrus" +) + +var ( + addr string + id string + debug bool +) + +func main() { + flag.StringVar(&addr, "connect", "ws://localhost:8123/connect", "Address to connect to") + flag.StringVar(&id, "id", "foo", "Client ID") + flag.BoolVar(&debug, "debug", true, "Debug logging") + flag.Parse() + + if debug { + logrus.SetLevel(logrus.DebugLevel) + } + + headers := http.Header{ + "X-Tunnel-ID": []string{id}, + } + + remotedialer.ClientConnect(addr, headers, nil, func(string, string) bool { return true }, nil) +} diff --git a/pkg/remotedialer/client_dialer.go b/pkg/remotedialer/client_dialer.go new file mode 100644 index 00000000..1ab6615a --- /dev/null +++ b/pkg/remotedialer/client_dialer.go @@ -0,0 +1,58 @@ +package remotedialer + +import ( + "io" + "net" + "sync" + "time" +) + +func clientDial(dialer Dialer, conn *connection, message *message) { + defer conn.Close() + + var ( + netConn net.Conn + err error + ) + + if dialer == nil { + netConn, err = net.DialTimeout(message.proto, message.address, time.Duration(message.deadline)*time.Millisecond) + } else { + netConn, err = dialer(message.proto, message.address) + } + + if err != nil { + conn.tunnelClose(err) + return + } + defer netConn.Close() + + pipe(conn, netConn) +} + +func pipe(client *connection, server net.Conn) { + wg := sync.WaitGroup{} + wg.Add(1) + + close := func(err error) error { + if err == nil { + err = io.EOF + } + client.doTunnelClose(err) + server.Close() + return err + } + + go func() { + defer wg.Done() + _, err := io.Copy(server, client) + close(err) + }() + + _, err := io.Copy(client, server) + err = close(err) + wg.Wait() + + // Write tunnel error after no more I/O is happening, just incase messages get out of order + client.writeErr(err) +} diff --git a/pkg/remotedialer/client_windows.go b/pkg/remotedialer/client_windows.go new file mode 100644 index 00000000..b64a82de --- /dev/null +++ b/pkg/remotedialer/client_windows.go @@ -0,0 +1,89 @@ +package remotedialer + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/rancher/rancher/pkg/rkenodeconfigclient" + "github.com/rancher/rancher/pkg/rkeworker" + "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" +) + +func ClientConnectWhileWindows(ctx context.Context, wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, blockingOnConnect func(context.Context) error) int64 { + if err := connectToProxyWhileWindows(ctx, wsURL, headers, auth, dialer, blockingOnConnect); err != nil { + errMsg := err.Error() + + switch err { + case websocket.ErrBadHandshake: + return 403 + case rkeworker.ErrHyperKubePSScriptAgentRetry: + logrus.Warn("This connection try to touch proxy again: ", errMsg) + return 302 + default: + if e, ok := err.(*rkenodeconfigclient.ErrNodeOrClusterNotFound); ok { + logrus.Warn("Can't connect to the registered " + e.ErrorOccursType() + ", terminating gracefully") + return 503 + } + + logrus.Error("Failed to connect to proxy: ", errMsg) + } + + return 500 + } + + return 200 +} + +func connectToProxyWhileWindows(rootContext context.Context, proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, blockingOnConnect func(context.Context) error) error { + if dialer == nil { + dialer = &websocket.Dialer{} + } + + eg, ctx := errgroup.WithContext(rootContext) + + if blockingOnConnect != nil { + eg.Go(func() error { + return blockingOnConnect(ctx) + }) + } + + eg.Go(func() error { + reconnectCount := 0 + + for { + err := func() error { + ws, _, err := dialer.Dial(proxyURL, headers) + if err != nil { + return err + } + defer ws.Close() + + session := newClientSession(auth, ws) + _, err = session.serveWhileWindows(ctx) + session.Close() + return err + }() + if err != nil { + if reconnectCount < 10 { + errMsg := err.Error() + if strings.HasSuffix(errMsg, "An existing connection was forcibly closed by the remote host.") || + strings.HasSuffix(errMsg, "An established connection was aborted by the software in your host machine.") || + strings.HasSuffix(errMsg, "A socket operation was attempted to an unreachable network.") { + time.Sleep(5 * time.Second) + + reconnectCount += 1 + continue + } + } + } + + return err + } + }) + + return eg.Wait() +} diff --git a/pkg/remotedialer/connection.go b/pkg/remotedialer/connection.go new file mode 100644 index 00000000..0888ce70 --- /dev/null +++ b/pkg/remotedialer/connection.go @@ -0,0 +1,188 @@ +package remotedialer + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" +) + +type connection struct { + sync.Mutex + + ctx context.Context + cancel func() + err error + writeDeadline time.Time + buf chan []byte + readBuf []byte + addr addr + session *session + connID int64 +} + +func newConnection(connID int64, session *session, proto, address string) *connection { + c := &connection{ + addr: addr{ + proto: proto, + address: address, + }, + connID: connID, + session: session, + buf: make(chan []byte, 1024), + } + return c +} + +func (c *connection) tunnelClose(err error) { + c.writeErr(err) + c.doTunnelClose(err) +} + +func (c *connection) doTunnelClose(err error) { + c.Lock() + defer c.Unlock() + + if c.err != nil { + return + } + + c.err = err + if c.err == nil { + c.err = io.ErrClosedPipe + } + + close(c.buf) +} + +func (c *connection) tunnelWriter() io.Writer { + return chanWriter{conn: c, C: c.buf} +} + +func (c *connection) Close() error { + c.session.closeConnection(c.connID, io.EOF) + return nil +} + +func (c *connection) copyData(b []byte) int { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n +} + +func (c *connection) Read(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + + n := c.copyData(b) + if n > 0 { + return n, nil + } + + next, ok := <-c.buf + if !ok { + err := io.EOF + c.Lock() + if c.err != nil { + err = c.err + } + c.Unlock() + return 0, err + } + + c.readBuf = next + n = c.copyData(b) + return n, nil +} + +func (c *connection) Write(b []byte) (int, error) { + c.Lock() + if c.err != nil { + defer c.Unlock() + return 0, c.err + } + c.Unlock() + + deadline := int64(0) + if !c.writeDeadline.IsZero() { + deadline = c.writeDeadline.Sub(time.Now()).Nanoseconds() / 1000000 + } + return c.session.writeMessage(newMessage(c.connID, deadline, b)) +} + +func (c *connection) writeErr(err error) { + if err != nil { + c.session.writeMessage(newErrorMessage(c.connID, err)) + } +} + +func (c *connection) LocalAddr() net.Addr { + return c.addr +} + +func (c *connection) RemoteAddr() net.Addr { + return c.addr +} + +func (c *connection) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +func (c *connection) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *connection) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +type addr struct { + proto string + address string +} + +func (a addr) Network() string { + return a.proto +} + +func (a addr) String() string { + return a.address +} + +type chanWriter struct { + conn *connection + C chan []byte +} + +func (c chanWriter) Write(buf []byte) (int, error) { + c.conn.Lock() + defer c.conn.Unlock() + + if c.conn.err != nil { + return 0, c.conn.err + } + + newBuf := make([]byte, len(buf)) + copy(newBuf, buf) + buf = newBuf + + select { + // must copy the buffer + case c.C <- buf: + return len(buf), nil + default: + select { + case c.C <- buf: + return len(buf), nil + case <-time.After(15 * time.Second): + return 0, errors.New("backed up reader") + } + } +} diff --git a/pkg/remotedialer/dialer.go b/pkg/remotedialer/dialer.go new file mode 100644 index 00000000..b2a8350e --- /dev/null +++ b/pkg/remotedialer/dialer.go @@ -0,0 +1,28 @@ +package remotedialer + +import ( + "net" + "time" +) + +type Dialer func(network, address string) (net.Conn, error) + +func (s *Server) HasSession(clientKey string) bool { + _, err := s.sessions.getDialer(clientKey, 0) + return err == nil +} + +func (s *Server) Dial(clientKey string, deadline time.Duration, proto, address string) (net.Conn, error) { + d, err := s.sessions.getDialer(clientKey, deadline) + if err != nil { + return nil, err + } + + return d(proto, address) +} + +func (s *Server) Dialer(clientKey string, deadline time.Duration) Dialer { + return func(proto, address string) (net.Conn, error) { + return s.Dial(clientKey, deadline, proto, address) + } +} diff --git a/pkg/remotedialer/dummy/main.go b/pkg/remotedialer/dummy/main.go new file mode 100644 index 00000000..71313949 --- /dev/null +++ b/pkg/remotedialer/dummy/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "net/http" + "sync/atomic" + "time" +) + +var ( + counter int64 + listen string +) + +func main() { + flag.StringVar(&listen, "listen", ":8125", "Listen address") + flag.Parse() + + fmt.Println("listening ", listen) + http.ListenAndServe(listen, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + next := atomic.AddInt64(&counter, 1) + fmt.Println("request", next) + + time.Sleep(time.Duration(rand.Intn(300)) * time.Millisecond) + rw.Write([]byte("HI")) + })) +} diff --git a/pkg/remotedialer/message.go b/pkg/remotedialer/message.go new file mode 100644 index 00000000..0572ee5b --- /dev/null +++ b/pkg/remotedialer/message.go @@ -0,0 +1,220 @@ +package remotedialer + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "strings" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" +) + +const ( + Data messageType = iota + 1 + Connect + Error + AddClient + RemoveClient +) + +var ( + idCounter int64 +) + +func init() { + r := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + idCounter = r.Int63() +} + +type messageType int64 + +type message struct { + id int64 + err error + connID int64 + deadline int64 + messageType messageType + bytes []byte + body io.Reader + proto string + address string +} + +func nextid() int64 { + return atomic.AddInt64(&idCounter, 1) +} + +func newMessage(connID int64, deadline int64, bytes []byte) *message { + return &message{ + id: nextid(), + connID: connID, + deadline: deadline, + messageType: Data, + bytes: bytes, + } +} + +func newConnect(connID int64, deadline time.Duration, proto, address string) *message { + return &message{ + id: nextid(), + connID: connID, + deadline: deadline.Nanoseconds() / 1000000, + messageType: Connect, + bytes: []byte(fmt.Sprintf("%s/%s", proto, address)), + proto: proto, + address: address, + } +} + +func newErrorMessage(connID int64, err error) *message { + return &message{ + id: nextid(), + err: err, + connID: connID, + messageType: Error, + bytes: []byte(err.Error()), + } +} + +func newAddClient(client string) *message { + return &message{ + id: nextid(), + messageType: AddClient, + address: client, + bytes: []byte(client), + } +} + +func newRemoveClient(client string) *message { + return &message{ + id: nextid(), + messageType: RemoveClient, + address: client, + bytes: []byte(client), + } +} + +func newServerMessage(reader io.Reader) (*message, error) { + buf := bufio.NewReader(reader) + + id, err := binary.ReadVarint(buf) + if err != nil { + return nil, err + } + + connID, err := binary.ReadVarint(buf) + if err != nil { + return nil, err + } + + mType, err := binary.ReadVarint(buf) + if err != nil { + return nil, err + } + + m := &message{ + id: id, + messageType: messageType(mType), + connID: connID, + body: buf, + } + + if m.messageType == Data || m.messageType == Connect { + deadline, err := binary.ReadVarint(buf) + if err != nil { + return nil, err + } + m.deadline = deadline + } + + if m.messageType == Connect { + bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100)) + if err != nil { + return nil, err + } + parts := strings.SplitN(string(bytes), "/", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("failed to parse connect address") + } + m.proto = parts[0] + m.address = parts[1] + m.bytes = bytes + } else if m.messageType == AddClient || m.messageType == RemoveClient { + bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100)) + if err != nil { + return nil, err + } + m.address = string(bytes) + m.bytes = bytes + } + + return m, nil +} + +func (m *message) Err() error { + if m.err != nil { + return m.err + } + bytes, err := ioutil.ReadAll(io.LimitReader(m.body, 100)) + if err != nil { + return err + } + + str := string(bytes) + if str == "EOF" { + m.err = io.EOF + } else { + m.err = errors.New(str) + } + return m.err +} + +func (m *message) Bytes() []byte { + return append(m.header(), m.bytes...) +} + +func (m *message) header() []byte { + buf := make([]byte, 24) + offset := 0 + offset += binary.PutVarint(buf[offset:], m.id) + offset += binary.PutVarint(buf[offset:], m.connID) + offset += binary.PutVarint(buf[offset:], int64(m.messageType)) + if m.messageType == Data || m.messageType == Connect { + offset += binary.PutVarint(buf[offset:], m.deadline) + } + return buf[:offset] +} + +func (m *message) Read(p []byte) (int, error) { + return m.body.Read(p) +} + +func (m *message) WriteTo(wsConn *wsConn) (int, error) { + err := wsConn.WriteMessage(websocket.BinaryMessage, m.Bytes()) + return len(m.bytes), err +} + +func (m *message) String() string { + switch m.messageType { + case Data: + if m.body == nil { + return fmt.Sprintf("%d DATA [%d]: %d bytes: %s", m.id, m.connID, len(m.bytes), string(m.bytes)) + } + return fmt.Sprintf("%d DATA [%d]: buffered", m.id, m.connID) + case Error: + return fmt.Sprintf("%d ERROR [%d]: %s", m.id, m.connID, m.Err()) + case Connect: + return fmt.Sprintf("%d CONNECT [%d]: %s/%s deadline %d", m.id, m.connID, m.proto, m.address, m.deadline) + case AddClient: + return fmt.Sprintf("%d ADDCLIENT [%s]", m.id, m.address) + case RemoveClient: + return fmt.Sprintf("%d REMOVECLIENT [%s]", m.id, m.address) + } + return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType) +} diff --git a/pkg/remotedialer/peer.go b/pkg/remotedialer/peer.go new file mode 100644 index 00000000..160ab008 --- /dev/null +++ b/pkg/remotedialer/peer.go @@ -0,0 +1,120 @@ +package remotedialer + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" +) + +var ( + Token = "X-API-Tunnel-Token" + ID = "X-API-Tunnel-ID" +) + +func (s *Server) AddPeer(url, id, token string) { + if s.PeerID == "" || s.PeerToken == "" { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + peer := peer{ + url: url, + id: id, + token: token, + cancel: cancel, + } + + logrus.Infof("Adding peer %s, %s", url, id) + + s.peerLock.Lock() + defer s.peerLock.Unlock() + + if p, ok := s.peers[id]; ok { + if p.equals(peer) { + return + } + p.cancel() + } + + s.peers[id] = peer + go peer.start(ctx, s) +} + +func (s *Server) RemovePeer(id string) { + s.peerLock.Lock() + defer s.peerLock.Unlock() + + if p, ok := s.peers[id]; ok { + logrus.Infof("Removing peer %s", id) + p.cancel() + } + delete(s.peers, id) +} + +type peer struct { + url, id, token string + cancel func() +} + +func (p peer) equals(other peer) bool { + return p.url == other.url && + p.id == other.id && + p.token == other.token +} + +func (p *peer) start(ctx context.Context, s *Server) { + headers := http.Header{ + ID: {s.PeerID}, + Token: {s.PeerToken}, + } + + dialer := &websocket.Dialer{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + +outer: + for { + select { + case <-ctx.Done(): + break outer + default: + } + + ws, _, err := dialer.Dial(p.url, headers) + if err != nil { + logrus.Errorf("Failed to connect to peer %s [local ID=%s]: %v", p.url, s.PeerID, err) + time.Sleep(5 * time.Second) + continue + } + + session := newClientSession(func(string, string) bool { return true }, ws) + session.dialer = func(network, address string) (net.Conn, error) { + parts := strings.SplitN(network, "::", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid clientKey/proto: %s", network) + } + return s.Dial(parts[0], 15*time.Second, parts[1], address) + } + + s.sessions.addListener(session) + _, err = session.serve() + s.sessions.removeListener(session) + session.Close() + + if err != nil { + logrus.Errorf("Failed to serve peer connection %s: %v", p.id, err) + } + + ws.Close() + time.Sleep(5 * time.Second) + } +} diff --git a/pkg/remotedialer/server.go b/pkg/remotedialer/server.go new file mode 100644 index 00000000..e459cf39 --- /dev/null +++ b/pkg/remotedialer/server.go @@ -0,0 +1,97 @@ +package remotedialer + +import ( + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +var ( + errFailedAuth = errors.New("failed authentication") + errWrongMessageType = errors.New("wrong websocket message type") +) + +type Authorizer func(req *http.Request) (clientKey string, authed bool, err error) +type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error) + +func DefaultErrorWriter(rw http.ResponseWriter, req *http.Request, code int, err error) { + rw.Write([]byte(err.Error())) + rw.WriteHeader(code) +} + +type Server struct { + PeerID string + PeerToken string + authorizer Authorizer + errorWriter ErrorWriter + sessions *sessionManager + peers map[string]peer + peerLock sync.Mutex +} + +func New(auth Authorizer, errorWriter ErrorWriter) *Server { + return &Server{ + peers: map[string]peer{}, + authorizer: auth, + errorWriter: errorWriter, + sessions: newSessionManager(), + } +} + +func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + clientKey, authed, peer, err := s.auth(req) + if err != nil { + s.errorWriter(rw, req, 400, err) + return + } + if !authed { + s.errorWriter(rw, req, 401, errFailedAuth) + return + } + + logrus.Infof("Handling backend connection request [%s]", clientKey) + + upgrader := websocket.Upgrader{ + HandshakeTimeout: 5 * time.Second, + CheckOrigin: func(r *http.Request) bool { return true }, + Error: s.errorWriter, + } + + wsConn, err := upgrader.Upgrade(rw, req, nil) + if err != nil { + s.errorWriter(rw, req, 400, errors.Wrapf(err, "Error during upgrade for host [%v]", clientKey)) + return + } + + session := s.sessions.add(clientKey, wsConn, peer) + defer s.sessions.remove(session) + + // Don't need to associate req.Context() to the session, it will cancel otherwise + code, err := session.serve() + if err != nil { + // Hijacked so we can't write to the client + logrus.Infof("error in remotedialer server [%d]: %v", code, err) + } +} + +func (s *Server) auth(req *http.Request) (clientKey string, authed, peer bool, err error) { + id := req.Header.Get(ID) + token := req.Header.Get(Token) + if id != "" && token != "" { + // peer authentication + s.peerLock.Lock() + p, ok := s.peers[id] + s.peerLock.Unlock() + + if ok && p.token == token { + return id, true, true, nil + } + } + + id, authed, err = s.authorizer(req) + return id, authed, false, err +} diff --git a/pkg/remotedialer/server/main.go b/pkg/remotedialer/server/main.go new file mode 100644 index 00000000..8db53272 --- /dev/null +++ b/pkg/remotedialer/server/main.go @@ -0,0 +1,125 @@ +package main + +import ( + "flag" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/mux" + "github.com/rancher/norman/pkg/remotedialer" + "github.com/sirupsen/logrus" +) + +var ( + clients = map[string]*http.Client{} + l sync.Mutex + counter int64 +) + +func authorizer(req *http.Request) (string, bool, error) { + id := req.Header.Get("x-tunnel-id") + return id, id != "", nil +} + +func Client(server *remotedialer.Server, rw http.ResponseWriter, req *http.Request) { + timeout := req.URL.Query().Get("timeout") + if timeout == "" { + timeout = "15" + } + + vars := mux.Vars(req) + clientKey := vars["id"] + url := fmt.Sprintf("%s://%s%s", vars["scheme"], vars["host"], vars["path"]) + client := getClient(server, clientKey, timeout) + + id := atomic.AddInt64(&counter, 1) + logrus.Infof("[%03d] REQ t=%s %s", id, timeout, url) + + resp, err := client.Get(url) + if err != nil { + logrus.Errorf("[%03d] REQ ERR t=%s %s: %v", id, timeout, url, err) + remotedialer.DefaultErrorWriter(rw, req, 500, err) + return + } + defer resp.Body.Close() + + logrus.Infof("[%03d] REQ OK t=%s %s", id, timeout, url) + rw.WriteHeader(resp.StatusCode) + io.Copy(rw, resp.Body) + logrus.Infof("[%03d] REQ DONE t=%s %s", id, timeout, url) +} + +func getClient(server *remotedialer.Server, clientKey, timeout string) *http.Client { + l.Lock() + defer l.Unlock() + + key := fmt.Sprintf("%s/%s", clientKey, timeout) + client := clients[key] + if client != nil { + return client + } + + dialer := server.Dialer(clientKey, 15*time.Second) + client = &http.Client{ + Transport: &http.Transport{ + Dial: dialer, + }, + } + if timeout != "" { + t, err := strconv.Atoi(timeout) + if err == nil { + client.Timeout = time.Duration(t) * time.Second + } + } + + clients[key] = client + return client +} + +func main() { + var ( + addr string + peerID string + peerToken string + peers string + debug bool + ) + flag.StringVar(&addr, "listen", ":8123", "Listen address") + flag.StringVar(&peerID, "id", "", "Peer ID") + flag.StringVar(&peerToken, "token", "", "Peer Token") + flag.StringVar(&peers, "peers", "", "Peers format id:token:url,id:token:url") + flag.BoolVar(&debug, "debug", false, "Enable debug logging") + flag.Parse() + + if debug { + logrus.SetLevel(logrus.DebugLevel) + remotedialer.PrintTunnelData = true + } + + handler := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter) + handler.PeerToken = peerToken + handler.PeerID = peerID + + for _, peer := range strings.Split(peers, ",") { + parts := strings.SplitN(strings.TrimSpace(peer), ":", 3) + if len(parts) != 3 { + continue + } + handler.AddPeer(parts[2], parts[0], parts[1]) + } + + router := mux.NewRouter() + router.Handle("/connect", handler) + router.HandleFunc("/client/{id}/{scheme}/{host}{path:.*}", func(rw http.ResponseWriter, req *http.Request) { + Client(handler, rw, req) + }) + + fmt.Println("Listening on ", addr) + http.ListenAndServe(addr, router) +} diff --git a/pkg/remotedialer/session.go b/pkg/remotedialer/session.go new file mode 100644 index 00000000..a0969eda --- /dev/null +++ b/pkg/remotedialer/session.go @@ -0,0 +1,303 @@ +package remotedialer + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" +) + +type session struct { + sync.Mutex + + nextConnID int64 + clientKey string + sessionKey int64 + conn *wsConn + conns map[int64]*connection + remoteClientKeys map[string]map[int]bool + auth ConnectAuthorizer + pingCancel context.CancelFunc + pingWait sync.WaitGroup + dialer Dialer + client bool +} + +// PrintTunnelData No tunnel logging by default +var PrintTunnelData bool + +func init() { + if os.Getenv("CATTLE_TUNNEL_DATA_DEBUG") == "true" { + PrintTunnelData = true + } +} + +func newClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *session { + return &session{ + clientKey: "client", + conn: newWSConn(conn), + conns: map[int64]*connection{}, + auth: auth, + client: true, + } +} + +func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *session { + return &session{ + nextConnID: 1, + clientKey: clientKey, + sessionKey: sessionKey, + conn: newWSConn(conn), + conns: map[int64]*connection{}, + remoteClientKeys: map[string]map[int]bool{}, + } +} + +func (s *session) startPings() { + ctx, cancel := context.WithCancel(context.Background()) + s.pingCancel = cancel + s.pingWait.Add(1) + + go func() { + defer s.pingWait.Done() + + t := time.NewTicker(PingWriteInterval) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + s.conn.Lock() + if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil { + logrus.WithError(err).Error("Error writing ping") + } + logrus.Debug("Wrote ping") + s.conn.Unlock() + } + } + }() +} + +func (s *session) stopPings() { + if s.pingCancel == nil { + return + } + + s.pingCancel() + s.pingWait.Wait() +} + +func (s *session) serve() (int, error) { + if s.client { + s.startPings() + } + + for { + msType, reader, err := s.conn.NextReader() + if err != nil { + return 400, err + } + + if msType != websocket.BinaryMessage { + return 400, errWrongMessageType + } + + if err := s.serveMessage(reader); err != nil { + return 500, err + } + } +} + +func (s *session) serveMessage(reader io.Reader) error { + message, err := newServerMessage(reader) + if err != nil { + return err + } + + if PrintTunnelData { + logrus.Debug("REQUEST ", message) + } + + if message.messageType == Connect { + if s.auth == nil || !s.auth(message.proto, message.address) { + return errors.New("connect not allowed") + } + s.clientConnect(message) + return nil + } + + s.Lock() + if message.messageType == AddClient && s.remoteClientKeys != nil { + err := s.addRemoteClient(message.address) + s.Unlock() + return err + } else if message.messageType == RemoveClient { + err := s.removeRemoteClient(message.address) + s.Unlock() + return err + } + conn := s.conns[message.connID] + s.Unlock() + + if conn == nil { + if message.messageType == Data { + err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID) + newErrorMessage(message.connID, err).WriteTo(s.conn) + } + return nil + } + + switch message.messageType { + case Data: + if _, err := io.Copy(conn.tunnelWriter(), message); err != nil { + s.closeConnection(message.connID, err) + } + case Error: + s.closeConnection(message.connID, message.Err()) + } + + return nil +} + +func parseAddress(address string) (string, int, error) { + parts := strings.SplitN(address, "/", 2) + if len(parts) != 2 { + return "", 0, errors.New("not / separated") + } + v, err := strconv.Atoi(parts[1]) + return parts[0], v, err +} + +func (s *session) addRemoteClient(address string) error { + clientKey, sessionKey, err := parseAddress(address) + if err != nil { + return fmt.Errorf("invalid remote session %s: %v", address, err) + } + + keys := s.remoteClientKeys[clientKey] + if keys == nil { + keys = map[int]bool{} + s.remoteClientKeys[clientKey] = keys + } + keys[int(sessionKey)] = true + + if PrintTunnelData { + logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) + } + + return nil +} + +func (s *session) removeRemoteClient(address string) error { + clientKey, sessionKey, err := parseAddress(address) + if err != nil { + return fmt.Errorf("invalid remote session %s: %v", address, err) + } + + keys := s.remoteClientKeys[clientKey] + delete(keys, int(sessionKey)) + if len(keys) == 0 { + delete(s.remoteClientKeys, clientKey) + } + + if PrintTunnelData { + logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey) + } + + return nil +} + +func (s *session) closeConnection(connID int64, err error) { + s.Lock() + conn := s.conns[connID] + delete(s.conns, connID) + if PrintTunnelData { + logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) + } + s.Unlock() + + if conn != nil { + conn.tunnelClose(err) + } +} + +func (s *session) clientConnect(message *message) { + conn := newConnection(message.connID, s, message.proto, message.address) + + s.Lock() + s.conns[message.connID] = conn + if PrintTunnelData { + logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) + } + s.Unlock() + + go clientDial(s.dialer, conn, message) +} + +func (s *session) serverConnect(deadline time.Duration, proto, address string) (net.Conn, error) { + connID := atomic.AddInt64(&s.nextConnID, 1) + conn := newConnection(connID, s, proto, address) + + s.Lock() + s.conns[connID] = conn + if PrintTunnelData { + logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns)) + } + s.Unlock() + + _, err := s.writeMessage(newConnect(connID, deadline, proto, address)) + if err != nil { + s.closeConnection(connID, err) + return nil, err + } + + return conn, err +} + +func (s *session) writeMessage(message *message) (int, error) { + if PrintTunnelData { + logrus.Debug("WRITE ", message) + } + return message.WriteTo(s.conn) +} + +func (s *session) Close() { + s.Lock() + defer s.Unlock() + + s.stopPings() + + for _, connection := range s.conns { + connection.tunnelClose(errors.New("tunnel disconnect")) + } + + s.conns = map[int64]*connection{} +} + +func (s *session) sessionAdded(clientKey string, sessionKey int64) { + client := fmt.Sprintf("%s/%d", clientKey, sessionKey) + _, err := s.writeMessage(newAddClient(client)) + if err != nil { + s.conn.conn.Close() + } +} + +func (s *session) sessionRemoved(clientKey string, sessionKey int64) { + client := fmt.Sprintf("%s/%d", clientKey, sessionKey) + _, err := s.writeMessage(newRemoveClient(client)) + if err != nil { + s.conn.conn.Close() + } +} diff --git a/pkg/remotedialer/session_manager.go b/pkg/remotedialer/session_manager.go new file mode 100644 index 00000000..958026ea --- /dev/null +++ b/pkg/remotedialer/session_manager.go @@ -0,0 +1,137 @@ +package remotedialer + +import ( + "fmt" + "math/rand" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type sessionListener interface { + sessionAdded(clientKey string, sessionKey int64) + sessionRemoved(clientKey string, sessionKey int64) +} + +type sessionManager struct { + sync.Mutex + clients map[string][]*session + peers map[string][]*session + listeners map[sessionListener]bool +} + +func newSessionManager() *sessionManager { + return &sessionManager{ + clients: map[string][]*session{}, + peers: map[string][]*session{}, + listeners: map[sessionListener]bool{}, + } +} + +func toDialer(s *session, prefix string, deadline time.Duration) Dialer { + return func(proto, address string) (net.Conn, error) { + if prefix == "" { + return s.serverConnect(deadline, proto, address) + } + return s.serverConnect(deadline, prefix+"::"+proto, address) + } +} + +func (sm *sessionManager) removeListener(listener sessionListener) { + sm.Lock() + defer sm.Unlock() + + delete(sm.listeners, listener) +} + +func (sm *sessionManager) addListener(listener sessionListener) { + sm.Lock() + defer sm.Unlock() + + sm.listeners[listener] = true + + for k, sessions := range sm.clients { + for _, session := range sessions { + listener.sessionAdded(k, session.sessionKey) + } + } + + for k, sessions := range sm.peers { + for _, session := range sessions { + listener.sessionAdded(k, session.sessionKey) + } + } +} + +func (sm *sessionManager) getDialer(clientKey string, deadline time.Duration) (Dialer, error) { + sm.Lock() + defer sm.Unlock() + + sessions := sm.clients[clientKey] + if len(sessions) > 0 { + return toDialer(sessions[0], "", deadline), nil + } + + for _, sessions := range sm.peers { + for _, session := range sessions { + session.Lock() + keys := session.remoteClientKeys[clientKey] + session.Unlock() + if len(keys) > 0 { + return toDialer(session, clientKey, deadline), nil + } + } + } + + return nil, fmt.Errorf("failed to find session for client %s", clientKey) +} + +func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *session { + sessionKey := rand.Int63() + session := newSession(sessionKey, clientKey, conn) + + sm.Lock() + defer sm.Unlock() + + if peer { + sm.peers[clientKey] = append(sm.peers[clientKey], session) + } else { + sm.clients[clientKey] = append(sm.clients[clientKey], session) + } + + for l := range sm.listeners { + l.sessionAdded(clientKey, session.sessionKey) + } + + return session +} + +func (sm *sessionManager) remove(s *session) { + sm.Lock() + defer sm.Unlock() + + for _, store := range []map[string][]*session{sm.clients, sm.peers} { + var newSessions []*session + + for _, v := range store[s.clientKey] { + if v.sessionKey == s.sessionKey { + continue + } + newSessions = append(newSessions, v) + } + + if len(newSessions) == 0 { + delete(store, s.clientKey) + } else { + store[s.clientKey] = newSessions + } + } + + for l := range sm.listeners { + l.sessionRemoved(s.clientKey, s.sessionKey) + } + + s.Close() +} diff --git a/pkg/remotedialer/session_windows.go b/pkg/remotedialer/session_windows.go new file mode 100644 index 00000000..7d3316cb --- /dev/null +++ b/pkg/remotedialer/session_windows.go @@ -0,0 +1,57 @@ +package remotedialer + +import ( + "context" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" +) + +func (s *session) startPingsWhileWindows(rootCtx context.Context) { + ctx, cancel := context.WithCancel(rootCtx) + s.pingCancel = cancel + s.pingWait.Add(1) + + go func() { + defer s.pingWait.Done() + + t := time.NewTicker(PingWriteInterval) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + s.conn.Lock() + if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil { + logrus.WithError(err).Error("Error writing ping") + } + logrus.Debug("Wrote ping") + s.conn.Unlock() + } + } + }() +} + +func (s *session) serveWhileWindows(ctx context.Context) (int, error) { + if s.client { + s.startPingsWhileWindows(ctx) + } + + for { + msType, reader, err := s.conn.NextReader() + if err != nil { + return 400, err + } + + if msType != websocket.BinaryMessage { + return 400, errWrongMessageType + } + + if err := s.serveMessage(reader); err != nil { + return 500, err + } + } +} diff --git a/pkg/remotedialer/types.go b/pkg/remotedialer/types.go new file mode 100644 index 00000000..f1056cd9 --- /dev/null +++ b/pkg/remotedialer/types.go @@ -0,0 +1,11 @@ +package remotedialer + +import ( + "time" +) + +var ( + PingWaitDuration = time.Duration(10 * time.Second) + PingWriteInterval = time.Duration(5 * time.Second) + MaxRead = 8192 +) diff --git a/pkg/remotedialer/wsconn.go b/pkg/remotedialer/wsconn.go new file mode 100644 index 00000000..3e519e5d --- /dev/null +++ b/pkg/remotedialer/wsconn.go @@ -0,0 +1,47 @@ +package remotedialer + +import ( + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type wsConn struct { + sync.Mutex + conn *websocket.Conn +} + +func newWSConn(conn *websocket.Conn) *wsConn { + w := &wsConn{ + conn: conn, + } + w.setupDeadline() + return w +} + +func (w *wsConn) WriteMessage(messageType int, data []byte) error { + w.Lock() + defer w.Unlock() + w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration)) + return w.conn.WriteMessage(messageType, data) +} + +func (w *wsConn) NextReader() (int, io.Reader, error) { + return w.conn.NextReader() +} + +func (w *wsConn) setupDeadline() { + w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) + w.conn.SetPingHandler(func(string) error { + w.Lock() + w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second)) + w.Unlock() + return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) + }) + w.conn.SetPongHandler(func(string) error { + return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) + }) + +} diff --git a/pkg/resolvehome/home.go b/pkg/resolvehome/home.go new file mode 100644 index 00000000..cf189d12 --- /dev/null +++ b/pkg/resolvehome/home.go @@ -0,0 +1,39 @@ +package resolvehome + +import ( + "os" + "os/user" + "strings" + + "github.com/pkg/errors" +) + +var ( + homes = []string{"$HOME", "${HOME}", "~"} +) + +func Resolve(s string) (string, error) { + for _, home := range homes { + if strings.Contains(s, home) { + homeDir, err := getHomeDir() + if err != nil { + return "", errors.Wrap(err, "determining current user") + } + s = strings.Replace(s, home, homeDir, -1) + } + } + + return s, nil +} + +func getHomeDir() (string, error) { + if os.Getuid() == 0 { + return "/root", nil + } + + u, err := user.Current() + if err != nil { + return "", errors.Wrap(err, "determining current user, try set HOME and USER env vars") + } + return u.HomeDir, nil +}