mirror of
https://github.com/rancher/norman.git
synced 2025-09-17 15:49:53 +00:00
Move packages from rancher/rancher to norman
This commit is contained in:
33
pkg/k8scheck/wait.go
Normal file
33
pkg/k8scheck/wait.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package k8scheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"k8s.io/client-go/kubernetes"
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Wait(ctx context.Context, config rest.Config) error {
|
||||||
|
client, err := kubernetes.NewForConfig(&config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, err := client.Discovery().ServerVersion()
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
logrus.Infof("Waiting for server to become available: %v", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("startup canceled")
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
67
pkg/kwrapper/etcd/etcd.go
Normal file
67
pkg/kwrapper/etcd/etcd.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// +build !no_etcd
|
||||||
|
|
||||||
|
package etcd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/etcd/etcdmain"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RunETCD(ctx context.Context) ([]string, error) {
|
||||||
|
endpoint := "http://localhost:2379"
|
||||||
|
go runEtcd(ctx, []string{"--data-dir=./etcd"})
|
||||||
|
|
||||||
|
if err := checkEtcd(endpoint); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "waiting on etcd")
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{endpoint}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkEtcd(endpoint string) error {
|
||||||
|
ht := &http.Transport{}
|
||||||
|
client := http.Client{
|
||||||
|
Transport: ht,
|
||||||
|
}
|
||||||
|
defer ht.CloseIdleConnections()
|
||||||
|
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
resp, err := client.Get(endpoint + "/health")
|
||||||
|
if err != nil {
|
||||||
|
if i > 1 {
|
||||||
|
logrus.Infof("Waiting on etcd startup: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
io.Copy(ioutil.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if i > 1 {
|
||||||
|
logrus.Infof("Waiting on etcd startup: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEtcd(ctx context.Context, args []string) {
|
||||||
|
os.Args = args
|
||||||
|
logrus.Info("Running ", strings.Join(args, " "))
|
||||||
|
etcdmain.Main()
|
||||||
|
logrus.Errorf("etcd exited")
|
||||||
|
}
|
11
pkg/kwrapper/etcd/etcd_none.go
Normal file
11
pkg/kwrapper/etcd/etcd_none.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build no_etcd
|
||||||
|
|
||||||
|
package etcd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RunETCD(ctx context.Context) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
51
pkg/kwrapper/k8s/config.go
Normal file
51
pkg/kwrapper/k8s/config.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/tools/clientcmd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Auto(ctx context.Context) (bool, context.Context, *rest.Config, error) {
|
||||||
|
return GetConfig(ctx, "auto", os.Getenv("KUBECONFIG"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConfig(ctx context.Context, k8sMode string, kubeConfig string) (bool, context.Context, *rest.Config, error) {
|
||||||
|
var (
|
||||||
|
cfg *rest.Config
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
switch k8sMode {
|
||||||
|
case "auto":
|
||||||
|
return getAuto(ctx, kubeConfig)
|
||||||
|
case "embedded":
|
||||||
|
return getEmbedded(ctx)
|
||||||
|
case "external":
|
||||||
|
cfg, err = getExternal(kubeConfig)
|
||||||
|
default:
|
||||||
|
return false, nil, nil, fmt.Errorf("invalid k8s-mode %s", k8sMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ctx, cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAuto(ctx context.Context, kubeConfig string) (bool, context.Context, *rest.Config, error) {
|
||||||
|
if kubeConfig != "" {
|
||||||
|
cfg, err := getExternal(kubeConfig)
|
||||||
|
return false, ctx, cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config, err := rest.InClusterConfig(); err == nil {
|
||||||
|
return false, ctx, config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return getEmbedded(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExternal(kubeConfig string) (*rest.Config, error) {
|
||||||
|
return clientcmd.BuildConfigFromFlags("", kubeConfig)
|
||||||
|
}
|
37
pkg/kwrapper/k8s/config_k3s.go
Normal file
37
pkg/kwrapper/k8s/config_k3s.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
"github.com/rancher/norman/pkg/resolvehome"
|
||||||
|
"k8s.io/kubernetes/pkg/wrapper/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewK3sConfig(ctx context.Context, dataDir string, authorizer remotedialer.Authorizer) (context.Context, *server.ServerConfig, http.Handler, error) {
|
||||||
|
dataDir, err := resolvehome.Resolve(dataDir)
|
||||||
|
if err != nil {
|
||||||
|
return ctx, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listenIP := net.ParseIP("127.0.0.1")
|
||||||
|
_, clusterIPNet, _ := net.ParseCIDR("10.42.0.0/16")
|
||||||
|
_, serviceIPNet, _ := net.ParseCIDR("10.43.0.0/16")
|
||||||
|
|
||||||
|
sc := &server.ServerConfig{
|
||||||
|
AdvertiseIP: &listenIP,
|
||||||
|
AdvertisePort: 6444,
|
||||||
|
PublicHostname: "localhost",
|
||||||
|
ListenAddr: listenIP,
|
||||||
|
ListenPort: 6443,
|
||||||
|
ClusterIPRange: *clusterIPNet,
|
||||||
|
ServiceIPRange: *serviceIPNet,
|
||||||
|
UseTokenCA: true,
|
||||||
|
DataDir: dataDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = SetK3sConfig(ctx, sc)
|
||||||
|
return ctx, sc, newTunnel(authorizer), nil
|
||||||
|
}
|
14
pkg/kwrapper/k8s/embedded_none.go
Normal file
14
pkg/kwrapper/k8s/embedded_none.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
// +build !k3s
|
||||||
|
|
||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getEmbedded(ctx context.Context) (bool, context.Context, *rest.Config, error) {
|
||||||
|
return false, ctx, nil, fmt.Errorf("embedded support is not compiled in, rebuild with -tags k8s")
|
||||||
|
}
|
43
pkg/kwrapper/k8s/k3s.go
Normal file
43
pkg/kwrapper/k8s/k3s.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
// +build k3s
|
||||||
|
|
||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rancher/norman/pkg/kwrapper/etcd"
|
||||||
|
"k8s.io/client-go/rest"
|
||||||
|
"k8s.io/client-go/tools/clientcmd"
|
||||||
|
"k8s.io/kubernetes/pkg/wrapper/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getEmbedded(ctx context.Context) (bool, context.Context, *rest.Config, error) {
|
||||||
|
sc, ok := ctx.Value(serverConfig).(*server.ServerConfig)
|
||||||
|
if !ok {
|
||||||
|
ctx, sc, _, err = NewK3sConfig(ctx, "./k3s", nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, ctx, nil, err
|
||||||
|
}
|
||||||
|
sc.NoScheduler = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.ETCDEndpoints) == 0 {
|
||||||
|
etcdEndpoints, err := etcd.RunETCD(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return ctx, nil, nil, err
|
||||||
|
}
|
||||||
|
sc.ETCDEndpoints = etcdEndpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.Server(ctx, sc)
|
||||||
|
if err != nil {
|
||||||
|
return false, ctx, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Setenv("KUBECONFIG", sc.KubeConfig)
|
||||||
|
restConfig, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
|
||||||
|
&clientcmd.ClientConfigLoadingRules{ExplicitPath: sc.KubeConfig}, &clientcmd.ConfigOverrides{}).ClientConfig()
|
||||||
|
|
||||||
|
return true, ctx, restConfig, err
|
||||||
|
}
|
13
pkg/kwrapper/k8s/k3s_context.go
Normal file
13
pkg/kwrapper/k8s/k3s_context.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverConfig configKey
|
||||||
|
|
||||||
|
type configKey struct{}
|
||||||
|
|
||||||
|
func SetK3sConfig(ctx context.Context, conf interface{}) context.Context {
|
||||||
|
return context.WithValue(ctx, serverConfig, conf)
|
||||||
|
}
|
16
pkg/kwrapper/k8s/tunnel.go
Normal file
16
pkg/kwrapper/k8s/tunnel.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTunnel(authorizer remotedialer.Authorizer) http.Handler {
|
||||||
|
if authorizer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
server := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter)
|
||||||
|
setupK3s(server)
|
||||||
|
return server
|
||||||
|
}
|
26
pkg/kwrapper/k8s/tunnel_k3s.go
Normal file
26
pkg/kwrapper/k8s/tunnel_k3s.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// +build k3s
|
||||||
|
|
||||||
|
package k8s
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rancher/norman/pkg/kv"
|
||||||
|
"github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||||
|
"k8s.io/kubernetes/cmd/kube-apiserver/app"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupK3s(tunnelServer *remotedialer.Server) {
|
||||||
|
app.DefaultProxyDialerFn = utilnet.DialFunc(func(_ context.Context, network, address string) (net.Conn, error) {
|
||||||
|
_, port, _ := net.SplitHostPort(address)
|
||||||
|
addr := "127.0.0.1"
|
||||||
|
if port != "" {
|
||||||
|
addr += ":" + port
|
||||||
|
}
|
||||||
|
nodeName, _ := kv.Split(address, ":")
|
||||||
|
return tunnelServer.Dial(nodeName, 15*time.Second, "tcp", addr)
|
||||||
|
})
|
||||||
|
}
|
8
pkg/kwrapper/k8s/tunnel_none.go
Normal file
8
pkg/kwrapper/k8s/tunnel_none.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// +build !k3s
|
||||||
|
|
||||||
|
package k8s
|
||||||
|
|
||||||
|
import "github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
|
||||||
|
func setupK3s(tunnelServer *remotedialer.Server) {
|
||||||
|
}
|
39
pkg/kwrapper/kubectl/main.go
Normal file
39
pkg/kwrapper/kubectl/main.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package kubectl
|
||||||
|
|
||||||
|
import (
|
||||||
|
goflag "flag"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/docker/docker/pkg/reexec"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
utilflag "k8s.io/apiserver/pkg/util/flag"
|
||||||
|
"k8s.io/apiserver/pkg/util/logs"
|
||||||
|
"k8s.io/kubernetes/pkg/kubectl/cmd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
reexec.Register("kubectl", Main)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Main() {
|
||||||
|
rand.Seed(time.Now().UTC().UnixNano())
|
||||||
|
|
||||||
|
command := cmd.NewDefaultKubectlCommand()
|
||||||
|
|
||||||
|
// TODO: once we switch everything over to Cobra commands, we can go back to calling
|
||||||
|
// utilflag.InitFlags() (by removing its pflag.Parse() call). For now, we have to set the
|
||||||
|
// normalize func and add the go flag set by hand.
|
||||||
|
pflag.CommandLine.SetNormalizeFunc(utilflag.WordSepNormalizeFunc)
|
||||||
|
pflag.CommandLine.AddGoFlagSet(goflag.CommandLine)
|
||||||
|
// utilflag.InitFlags()
|
||||||
|
logs.InitLogs()
|
||||||
|
defer logs.FlushLogs()
|
||||||
|
|
||||||
|
if err := command.Execute(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
47
pkg/remotedialer/client.go
Normal file
47
pkg/remotedialer/client.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectAuthorizer func(proto, address string) bool
|
||||||
|
|
||||||
|
func ClientConnect(wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, onConnect func(context.Context) error) {
|
||||||
|
if err := connectToProxy(wsURL, headers, auth, dialer, onConnect); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to connect to proxy")
|
||||||
|
time.Sleep(time.Duration(5) * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectToProxy(proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, onConnect func(context.Context) error) error {
|
||||||
|
logrus.WithField("url", proxyURL).Info("Connecting to proxy")
|
||||||
|
|
||||||
|
if dialer == nil {
|
||||||
|
dialer = &websocket.Dialer{}
|
||||||
|
}
|
||||||
|
ws, _, err := dialer.Dial(proxyURL, headers)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to connect to proxy")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if onConnect != nil {
|
||||||
|
if err := onConnect(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session := newClientSession(auth, ws)
|
||||||
|
_, err = session.serve()
|
||||||
|
session.Close()
|
||||||
|
return err
|
||||||
|
}
|
34
pkg/remotedialer/client/main.go
Normal file
34
pkg/remotedialer/client/main.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr string
|
||||||
|
id string
|
||||||
|
debug bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&addr, "connect", "ws://localhost:8123/connect", "Address to connect to")
|
||||||
|
flag.StringVar(&id, "id", "foo", "Client ID")
|
||||||
|
flag.BoolVar(&debug, "debug", true, "Debug logging")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := http.Header{
|
||||||
|
"X-Tunnel-ID": []string{id},
|
||||||
|
}
|
||||||
|
|
||||||
|
remotedialer.ClientConnect(addr, headers, nil, func(string, string) bool { return true }, nil)
|
||||||
|
}
|
58
pkg/remotedialer/client_dialer.go
Normal file
58
pkg/remotedialer/client_dialer.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func clientDial(dialer Dialer, conn *connection, message *message) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
netConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if dialer == nil {
|
||||||
|
netConn, err = net.DialTimeout(message.proto, message.address, time.Duration(message.deadline)*time.Millisecond)
|
||||||
|
} else {
|
||||||
|
netConn, err = dialer(message.proto, message.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
conn.tunnelClose(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer netConn.Close()
|
||||||
|
|
||||||
|
pipe(conn, netConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pipe(client *connection, server net.Conn) {
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
close := func(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
client.doTunnelClose(err)
|
||||||
|
server.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, err := io.Copy(server, client)
|
||||||
|
close(err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.Copy(client, server)
|
||||||
|
err = close(err)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Write tunnel error after no more I/O is happening, just incase messages get out of order
|
||||||
|
client.writeErr(err)
|
||||||
|
}
|
89
pkg/remotedialer/client_windows.go
Normal file
89
pkg/remotedialer/client_windows.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/rancher/rancher/pkg/rkenodeconfigclient"
|
||||||
|
"github.com/rancher/rancher/pkg/rkeworker"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ClientConnectWhileWindows(ctx context.Context, wsURL string, headers http.Header, dialer *websocket.Dialer, auth ConnectAuthorizer, blockingOnConnect func(context.Context) error) int64 {
|
||||||
|
if err := connectToProxyWhileWindows(ctx, wsURL, headers, auth, dialer, blockingOnConnect); err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
|
||||||
|
switch err {
|
||||||
|
case websocket.ErrBadHandshake:
|
||||||
|
return 403
|
||||||
|
case rkeworker.ErrHyperKubePSScriptAgentRetry:
|
||||||
|
logrus.Warn("This connection try to touch proxy again: ", errMsg)
|
||||||
|
return 302
|
||||||
|
default:
|
||||||
|
if e, ok := err.(*rkenodeconfigclient.ErrNodeOrClusterNotFound); ok {
|
||||||
|
logrus.Warn("Can't connect to the registered " + e.ErrorOccursType() + ", terminating gracefully")
|
||||||
|
return 503
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Error("Failed to connect to proxy: ", errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 500
|
||||||
|
}
|
||||||
|
|
||||||
|
return 200
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectToProxyWhileWindows(rootContext context.Context, proxyURL string, headers http.Header, auth ConnectAuthorizer, dialer *websocket.Dialer, blockingOnConnect func(context.Context) error) error {
|
||||||
|
if dialer == nil {
|
||||||
|
dialer = &websocket.Dialer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(rootContext)
|
||||||
|
|
||||||
|
if blockingOnConnect != nil {
|
||||||
|
eg.Go(func() error {
|
||||||
|
return blockingOnConnect(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
eg.Go(func() error {
|
||||||
|
reconnectCount := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
err := func() error {
|
||||||
|
ws, _, err := dialer.Dial(proxyURL, headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
session := newClientSession(auth, ws)
|
||||||
|
_, err = session.serveWhileWindows(ctx)
|
||||||
|
session.Close()
|
||||||
|
return err
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
if reconnectCount < 10 {
|
||||||
|
errMsg := err.Error()
|
||||||
|
if strings.HasSuffix(errMsg, "An existing connection was forcibly closed by the remote host.") ||
|
||||||
|
strings.HasSuffix(errMsg, "An established connection was aborted by the software in your host machine.") ||
|
||||||
|
strings.HasSuffix(errMsg, "A socket operation was attempted to an unreachable network.") {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
reconnectCount += 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return eg.Wait()
|
||||||
|
}
|
188
pkg/remotedialer/connection.go
Normal file
188
pkg/remotedialer/connection.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connection struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel func()
|
||||||
|
err error
|
||||||
|
writeDeadline time.Time
|
||||||
|
buf chan []byte
|
||||||
|
readBuf []byte
|
||||||
|
addr addr
|
||||||
|
session *session
|
||||||
|
connID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnection(connID int64, session *session, proto, address string) *connection {
|
||||||
|
c := &connection{
|
||||||
|
addr: addr{
|
||||||
|
proto: proto,
|
||||||
|
address: address,
|
||||||
|
},
|
||||||
|
connID: connID,
|
||||||
|
session: session,
|
||||||
|
buf: make(chan []byte, 1024),
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) tunnelClose(err error) {
|
||||||
|
c.writeErr(err)
|
||||||
|
c.doTunnelClose(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) doTunnelClose(err error) {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
if c.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.err = err
|
||||||
|
if c.err == nil {
|
||||||
|
c.err = io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) tunnelWriter() io.Writer {
|
||||||
|
return chanWriter{conn: c, C: c.buf}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) Close() error {
|
||||||
|
c.session.closeConnection(c.connID, io.EOF)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) copyData(b []byte) int {
|
||||||
|
n := copy(b, c.readBuf)
|
||||||
|
c.readBuf = c.readBuf[n:]
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) Read(b []byte) (int, error) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n := c.copyData(b)
|
||||||
|
if n > 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
next, ok := <-c.buf
|
||||||
|
if !ok {
|
||||||
|
err := io.EOF
|
||||||
|
c.Lock()
|
||||||
|
if c.err != nil {
|
||||||
|
err = c.err
|
||||||
|
}
|
||||||
|
c.Unlock()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readBuf = next
|
||||||
|
n = c.copyData(b)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) Write(b []byte) (int, error) {
|
||||||
|
c.Lock()
|
||||||
|
if c.err != nil {
|
||||||
|
defer c.Unlock()
|
||||||
|
return 0, c.err
|
||||||
|
}
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
|
deadline := int64(0)
|
||||||
|
if !c.writeDeadline.IsZero() {
|
||||||
|
deadline = c.writeDeadline.Sub(time.Now()).Nanoseconds() / 1000000
|
||||||
|
}
|
||||||
|
return c.session.writeMessage(newMessage(c.connID, deadline, b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) writeErr(err error) {
|
||||||
|
if err != nil {
|
||||||
|
c.session.writeMessage(newErrorMessage(c.connID, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) LocalAddr() net.Addr {
|
||||||
|
return c.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) RemoteAddr() net.Addr {
|
||||||
|
return c.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) SetDeadline(t time.Time) error {
|
||||||
|
if err := c.SetReadDeadline(t); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) SetReadDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) SetWriteDeadline(t time.Time) error {
|
||||||
|
c.writeDeadline = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type addr struct {
|
||||||
|
proto string
|
||||||
|
address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a addr) Network() string {
|
||||||
|
return a.proto
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a addr) String() string {
|
||||||
|
return a.address
|
||||||
|
}
|
||||||
|
|
||||||
|
type chanWriter struct {
|
||||||
|
conn *connection
|
||||||
|
C chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chanWriter) Write(buf []byte) (int, error) {
|
||||||
|
c.conn.Lock()
|
||||||
|
defer c.conn.Unlock()
|
||||||
|
|
||||||
|
if c.conn.err != nil {
|
||||||
|
return 0, c.conn.err
|
||||||
|
}
|
||||||
|
|
||||||
|
newBuf := make([]byte, len(buf))
|
||||||
|
copy(newBuf, buf)
|
||||||
|
buf = newBuf
|
||||||
|
|
||||||
|
select {
|
||||||
|
// must copy the buffer
|
||||||
|
case c.C <- buf:
|
||||||
|
return len(buf), nil
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case c.C <- buf:
|
||||||
|
return len(buf), nil
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
return 0, errors.New("backed up reader")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
28
pkg/remotedialer/dialer.go
Normal file
28
pkg/remotedialer/dialer.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dialer func(network, address string) (net.Conn, error)
|
||||||
|
|
||||||
|
func (s *Server) HasSession(clientKey string) bool {
|
||||||
|
_, err := s.sessions.getDialer(clientKey, 0)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Dial(clientKey string, deadline time.Duration, proto, address string) (net.Conn, error) {
|
||||||
|
d, err := s.sessions.getDialer(clientKey, deadline)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d(proto, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Dialer(clientKey string, deadline time.Duration) Dialer {
|
||||||
|
return func(proto, address string) (net.Conn, error) {
|
||||||
|
return s.Dial(clientKey, deadline, proto, address)
|
||||||
|
}
|
||||||
|
}
|
29
pkg/remotedialer/dummy/main.go
Normal file
29
pkg/remotedialer/dummy/main.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
counter int64
|
||||||
|
listen string
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&listen, "listen", ":8125", "Listen address")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
fmt.Println("listening ", listen)
|
||||||
|
http.ListenAndServe(listen, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
next := atomic.AddInt64(&counter, 1)
|
||||||
|
fmt.Println("request", next)
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(rand.Intn(300)) * time.Millisecond)
|
||||||
|
rw.Write([]byte("HI"))
|
||||||
|
}))
|
||||||
|
}
|
220
pkg/remotedialer/message.go
Normal file
220
pkg/remotedialer/message.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Data messageType = iota + 1
|
||||||
|
Connect
|
||||||
|
Error
|
||||||
|
AddClient
|
||||||
|
RemoveClient
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
idCounter int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
r := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||||||
|
idCounter = r.Int63()
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageType int64
|
||||||
|
|
||||||
|
type message struct {
|
||||||
|
id int64
|
||||||
|
err error
|
||||||
|
connID int64
|
||||||
|
deadline int64
|
||||||
|
messageType messageType
|
||||||
|
bytes []byte
|
||||||
|
body io.Reader
|
||||||
|
proto string
|
||||||
|
address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextid() int64 {
|
||||||
|
return atomic.AddInt64(&idCounter, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessage(connID int64, deadline int64, bytes []byte) *message {
|
||||||
|
return &message{
|
||||||
|
id: nextid(),
|
||||||
|
connID: connID,
|
||||||
|
deadline: deadline,
|
||||||
|
messageType: Data,
|
||||||
|
bytes: bytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnect(connID int64, deadline time.Duration, proto, address string) *message {
|
||||||
|
return &message{
|
||||||
|
id: nextid(),
|
||||||
|
connID: connID,
|
||||||
|
deadline: deadline.Nanoseconds() / 1000000,
|
||||||
|
messageType: Connect,
|
||||||
|
bytes: []byte(fmt.Sprintf("%s/%s", proto, address)),
|
||||||
|
proto: proto,
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newErrorMessage(connID int64, err error) *message {
|
||||||
|
return &message{
|
||||||
|
id: nextid(),
|
||||||
|
err: err,
|
||||||
|
connID: connID,
|
||||||
|
messageType: Error,
|
||||||
|
bytes: []byte(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAddClient(client string) *message {
|
||||||
|
return &message{
|
||||||
|
id: nextid(),
|
||||||
|
messageType: AddClient,
|
||||||
|
address: client,
|
||||||
|
bytes: []byte(client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRemoveClient(client string) *message {
|
||||||
|
return &message{
|
||||||
|
id: nextid(),
|
||||||
|
messageType: RemoveClient,
|
||||||
|
address: client,
|
||||||
|
bytes: []byte(client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerMessage(reader io.Reader) (*message, error) {
|
||||||
|
buf := bufio.NewReader(reader)
|
||||||
|
|
||||||
|
id, err := binary.ReadVarint(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
connID, err := binary.ReadVarint(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mType, err := binary.ReadVarint(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &message{
|
||||||
|
id: id,
|
||||||
|
messageType: messageType(mType),
|
||||||
|
connID: connID,
|
||||||
|
body: buf,
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.messageType == Data || m.messageType == Connect {
|
||||||
|
deadline, err := binary.ReadVarint(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.deadline = deadline
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.messageType == Connect {
|
||||||
|
bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(string(bytes), "/", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, fmt.Errorf("failed to parse connect address")
|
||||||
|
}
|
||||||
|
m.proto = parts[0]
|
||||||
|
m.address = parts[1]
|
||||||
|
m.bytes = bytes
|
||||||
|
} else if m.messageType == AddClient || m.messageType == RemoveClient {
|
||||||
|
bytes, err := ioutil.ReadAll(io.LimitReader(buf, 100))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.address = string(bytes)
|
||||||
|
m.bytes = bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) Err() error {
|
||||||
|
if m.err != nil {
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
bytes, err := ioutil.ReadAll(io.LimitReader(m.body, 100))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
str := string(bytes)
|
||||||
|
if str == "EOF" {
|
||||||
|
m.err = io.EOF
|
||||||
|
} else {
|
||||||
|
m.err = errors.New(str)
|
||||||
|
}
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) Bytes() []byte {
|
||||||
|
return append(m.header(), m.bytes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) header() []byte {
|
||||||
|
buf := make([]byte, 24)
|
||||||
|
offset := 0
|
||||||
|
offset += binary.PutVarint(buf[offset:], m.id)
|
||||||
|
offset += binary.PutVarint(buf[offset:], m.connID)
|
||||||
|
offset += binary.PutVarint(buf[offset:], int64(m.messageType))
|
||||||
|
if m.messageType == Data || m.messageType == Connect {
|
||||||
|
offset += binary.PutVarint(buf[offset:], m.deadline)
|
||||||
|
}
|
||||||
|
return buf[:offset]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) Read(p []byte) (int, error) {
|
||||||
|
return m.body.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) WriteTo(wsConn *wsConn) (int, error) {
|
||||||
|
err := wsConn.WriteMessage(websocket.BinaryMessage, m.Bytes())
|
||||||
|
return len(m.bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *message) String() string {
|
||||||
|
switch m.messageType {
|
||||||
|
case Data:
|
||||||
|
if m.body == nil {
|
||||||
|
return fmt.Sprintf("%d DATA [%d]: %d bytes: %s", m.id, m.connID, len(m.bytes), string(m.bytes))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d DATA [%d]: buffered", m.id, m.connID)
|
||||||
|
case Error:
|
||||||
|
return fmt.Sprintf("%d ERROR [%d]: %s", m.id, m.connID, m.Err())
|
||||||
|
case Connect:
|
||||||
|
return fmt.Sprintf("%d CONNECT [%d]: %s/%s deadline %d", m.id, m.connID, m.proto, m.address, m.deadline)
|
||||||
|
case AddClient:
|
||||||
|
return fmt.Sprintf("%d ADDCLIENT [%s]", m.id, m.address)
|
||||||
|
case RemoveClient:
|
||||||
|
return fmt.Sprintf("%d REMOVECLIENT [%s]", m.id, m.address)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType)
|
||||||
|
}
|
120
pkg/remotedialer/peer.go
Normal file
120
pkg/remotedialer/peer.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
Token = "X-API-Tunnel-Token"
|
||||||
|
ID = "X-API-Tunnel-ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) AddPeer(url, id, token string) {
|
||||||
|
if s.PeerID == "" || s.PeerToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
peer := peer{
|
||||||
|
url: url,
|
||||||
|
id: id,
|
||||||
|
token: token,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Adding peer %s, %s", url, id)
|
||||||
|
|
||||||
|
s.peerLock.Lock()
|
||||||
|
defer s.peerLock.Unlock()
|
||||||
|
|
||||||
|
if p, ok := s.peers[id]; ok {
|
||||||
|
if p.equals(peer) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.peers[id] = peer
|
||||||
|
go peer.start(ctx, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) RemovePeer(id string) {
|
||||||
|
s.peerLock.Lock()
|
||||||
|
defer s.peerLock.Unlock()
|
||||||
|
|
||||||
|
if p, ok := s.peers[id]; ok {
|
||||||
|
logrus.Infof("Removing peer %s", id)
|
||||||
|
p.cancel()
|
||||||
|
}
|
||||||
|
delete(s.peers, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
type peer struct {
|
||||||
|
url, id, token string
|
||||||
|
cancel func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p peer) equals(other peer) bool {
|
||||||
|
return p.url == other.url &&
|
||||||
|
p.id == other.id &&
|
||||||
|
p.token == other.token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *peer) start(ctx context.Context, s *Server) {
|
||||||
|
headers := http.Header{
|
||||||
|
ID: {s.PeerID},
|
||||||
|
Token: {s.PeerToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &websocket.Dialer{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
outer:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break outer
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
ws, _, err := dialer.Dial(p.url, headers)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("Failed to connect to peer %s [local ID=%s]: %v", p.url, s.PeerID, err)
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
session := newClientSession(func(string, string) bool { return true }, ws)
|
||||||
|
session.dialer = func(network, address string) (net.Conn, error) {
|
||||||
|
parts := strings.SplitN(network, "::", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, fmt.Errorf("invalid clientKey/proto: %s", network)
|
||||||
|
}
|
||||||
|
return s.Dial(parts[0], 15*time.Second, parts[1], address)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sessions.addListener(session)
|
||||||
|
_, err = session.serve()
|
||||||
|
s.sessions.removeListener(session)
|
||||||
|
session.Close()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("Failed to serve peer connection %s: %v", p.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.Close()
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
97
pkg/remotedialer/server.go
Normal file
97
pkg/remotedialer/server.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errFailedAuth = errors.New("failed authentication")
|
||||||
|
errWrongMessageType = errors.New("wrong websocket message type")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Authorizer func(req *http.Request) (clientKey string, authed bool, err error)
|
||||||
|
type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error)
|
||||||
|
|
||||||
|
func DefaultErrorWriter(rw http.ResponseWriter, req *http.Request, code int, err error) {
|
||||||
|
rw.Write([]byte(err.Error()))
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
PeerID string
|
||||||
|
PeerToken string
|
||||||
|
authorizer Authorizer
|
||||||
|
errorWriter ErrorWriter
|
||||||
|
sessions *sessionManager
|
||||||
|
peers map[string]peer
|
||||||
|
peerLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(auth Authorizer, errorWriter ErrorWriter) *Server {
|
||||||
|
return &Server{
|
||||||
|
peers: map[string]peer{},
|
||||||
|
authorizer: auth,
|
||||||
|
errorWriter: errorWriter,
|
||||||
|
sessions: newSessionManager(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
clientKey, authed, peer, err := s.auth(req)
|
||||||
|
if err != nil {
|
||||||
|
s.errorWriter(rw, req, 400, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !authed {
|
||||||
|
s.errorWriter(rw, req, 401, errFailedAuth)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Handling backend connection request [%s]", clientKey)
|
||||||
|
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
HandshakeTimeout: 5 * time.Second,
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Error: s.errorWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
wsConn, err := upgrader.Upgrade(rw, req, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.errorWriter(rw, req, 400, errors.Wrapf(err, "Error during upgrade for host [%v]", clientKey))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := s.sessions.add(clientKey, wsConn, peer)
|
||||||
|
defer s.sessions.remove(session)
|
||||||
|
|
||||||
|
// Don't need to associate req.Context() to the session, it will cancel otherwise
|
||||||
|
code, err := session.serve()
|
||||||
|
if err != nil {
|
||||||
|
// Hijacked so we can't write to the client
|
||||||
|
logrus.Infof("error in remotedialer server [%d]: %v", code, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) auth(req *http.Request) (clientKey string, authed, peer bool, err error) {
|
||||||
|
id := req.Header.Get(ID)
|
||||||
|
token := req.Header.Get(Token)
|
||||||
|
if id != "" && token != "" {
|
||||||
|
// peer authentication
|
||||||
|
s.peerLock.Lock()
|
||||||
|
p, ok := s.peers[id]
|
||||||
|
s.peerLock.Unlock()
|
||||||
|
|
||||||
|
if ok && p.token == token {
|
||||||
|
return id, true, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
id, authed, err = s.authorizer(req)
|
||||||
|
return id, authed, false, err
|
||||||
|
}
|
125
pkg/remotedialer/server/main.go
Normal file
125
pkg/remotedialer/server/main.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/rancher/norman/pkg/remotedialer"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clients = map[string]*http.Client{}
|
||||||
|
l sync.Mutex
|
||||||
|
counter int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func authorizer(req *http.Request) (string, bool, error) {
|
||||||
|
id := req.Header.Get("x-tunnel-id")
|
||||||
|
return id, id != "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Client(server *remotedialer.Server, rw http.ResponseWriter, req *http.Request) {
|
||||||
|
timeout := req.URL.Query().Get("timeout")
|
||||||
|
if timeout == "" {
|
||||||
|
timeout = "15"
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(req)
|
||||||
|
clientKey := vars["id"]
|
||||||
|
url := fmt.Sprintf("%s://%s%s", vars["scheme"], vars["host"], vars["path"])
|
||||||
|
client := getClient(server, clientKey, timeout)
|
||||||
|
|
||||||
|
id := atomic.AddInt64(&counter, 1)
|
||||||
|
logrus.Infof("[%03d] REQ t=%s %s", id, timeout, url)
|
||||||
|
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("[%03d] REQ ERR t=%s %s: %v", id, timeout, url, err)
|
||||||
|
remotedialer.DefaultErrorWriter(rw, req, 500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
logrus.Infof("[%03d] REQ OK t=%s %s", id, timeout, url)
|
||||||
|
rw.WriteHeader(resp.StatusCode)
|
||||||
|
io.Copy(rw, resp.Body)
|
||||||
|
logrus.Infof("[%03d] REQ DONE t=%s %s", id, timeout, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClient(server *remotedialer.Server, clientKey, timeout string) *http.Client {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s/%s", clientKey, timeout)
|
||||||
|
client := clients[key]
|
||||||
|
if client != nil {
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := server.Dialer(clientKey, 15*time.Second)
|
||||||
|
client = &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Dial: dialer,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if timeout != "" {
|
||||||
|
t, err := strconv.Atoi(timeout)
|
||||||
|
if err == nil {
|
||||||
|
client.Timeout = time.Duration(t) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clients[key] = client
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var (
|
||||||
|
addr string
|
||||||
|
peerID string
|
||||||
|
peerToken string
|
||||||
|
peers string
|
||||||
|
debug bool
|
||||||
|
)
|
||||||
|
flag.StringVar(&addr, "listen", ":8123", "Listen address")
|
||||||
|
flag.StringVar(&peerID, "id", "", "Peer ID")
|
||||||
|
flag.StringVar(&peerToken, "token", "", "Peer Token")
|
||||||
|
flag.StringVar(&peers, "peers", "", "Peers format id:token:url,id:token:url")
|
||||||
|
flag.BoolVar(&debug, "debug", false, "Enable debug logging")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
remotedialer.PrintTunnelData = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := remotedialer.New(authorizer, remotedialer.DefaultErrorWriter)
|
||||||
|
handler.PeerToken = peerToken
|
||||||
|
handler.PeerID = peerID
|
||||||
|
|
||||||
|
for _, peer := range strings.Split(peers, ",") {
|
||||||
|
parts := strings.SplitN(strings.TrimSpace(peer), ":", 3)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handler.AddPeer(parts[2], parts[0], parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.Handle("/connect", handler)
|
||||||
|
router.HandleFunc("/client/{id}/{scheme}/{host}{path:.*}", func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
Client(handler, rw, req)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("Listening on ", addr)
|
||||||
|
http.ListenAndServe(addr, router)
|
||||||
|
}
|
303
pkg/remotedialer/session.go
Normal file
303
pkg/remotedialer/session.go
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type session struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
nextConnID int64
|
||||||
|
clientKey string
|
||||||
|
sessionKey int64
|
||||||
|
conn *wsConn
|
||||||
|
conns map[int64]*connection
|
||||||
|
remoteClientKeys map[string]map[int]bool
|
||||||
|
auth ConnectAuthorizer
|
||||||
|
pingCancel context.CancelFunc
|
||||||
|
pingWait sync.WaitGroup
|
||||||
|
dialer Dialer
|
||||||
|
client bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintTunnelData No tunnel logging by default
|
||||||
|
var PrintTunnelData bool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if os.Getenv("CATTLE_TUNNEL_DATA_DEBUG") == "true" {
|
||||||
|
PrintTunnelData = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientSession(auth ConnectAuthorizer, conn *websocket.Conn) *session {
|
||||||
|
return &session{
|
||||||
|
clientKey: "client",
|
||||||
|
conn: newWSConn(conn),
|
||||||
|
conns: map[int64]*connection{},
|
||||||
|
auth: auth,
|
||||||
|
client: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSession(sessionKey int64, clientKey string, conn *websocket.Conn) *session {
|
||||||
|
return &session{
|
||||||
|
nextConnID: 1,
|
||||||
|
clientKey: clientKey,
|
||||||
|
sessionKey: sessionKey,
|
||||||
|
conn: newWSConn(conn),
|
||||||
|
conns: map[int64]*connection{},
|
||||||
|
remoteClientKeys: map[string]map[int]bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) startPings() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s.pingCancel = cancel
|
||||||
|
s.pingWait.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer s.pingWait.Done()
|
||||||
|
|
||||||
|
t := time.NewTicker(PingWriteInterval)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
s.conn.Lock()
|
||||||
|
if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil {
|
||||||
|
logrus.WithError(err).Error("Error writing ping")
|
||||||
|
}
|
||||||
|
logrus.Debug("Wrote ping")
|
||||||
|
s.conn.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) stopPings() {
|
||||||
|
if s.pingCancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.pingCancel()
|
||||||
|
s.pingWait.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) serve() (int, error) {
|
||||||
|
if s.client {
|
||||||
|
s.startPings()
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
msType, reader, err := s.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return 400, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if msType != websocket.BinaryMessage {
|
||||||
|
return 400, errWrongMessageType
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.serveMessage(reader); err != nil {
|
||||||
|
return 500, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) serveMessage(reader io.Reader) error {
|
||||||
|
message, err := newServerMessage(reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debug("REQUEST ", message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.messageType == Connect {
|
||||||
|
if s.auth == nil || !s.auth(message.proto, message.address) {
|
||||||
|
return errors.New("connect not allowed")
|
||||||
|
}
|
||||||
|
s.clientConnect(message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
if message.messageType == AddClient && s.remoteClientKeys != nil {
|
||||||
|
err := s.addRemoteClient(message.address)
|
||||||
|
s.Unlock()
|
||||||
|
return err
|
||||||
|
} else if message.messageType == RemoveClient {
|
||||||
|
err := s.removeRemoteClient(message.address)
|
||||||
|
s.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn := s.conns[message.connID]
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
if message.messageType == Data {
|
||||||
|
err := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, message.connID)
|
||||||
|
newErrorMessage(message.connID, err).WriteTo(s.conn)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch message.messageType {
|
||||||
|
case Data:
|
||||||
|
if _, err := io.Copy(conn.tunnelWriter(), message); err != nil {
|
||||||
|
s.closeConnection(message.connID, err)
|
||||||
|
}
|
||||||
|
case Error:
|
||||||
|
s.closeConnection(message.connID, message.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAddress(address string) (string, int, error) {
|
||||||
|
parts := strings.SplitN(address, "/", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", 0, errors.New("not / separated")
|
||||||
|
}
|
||||||
|
v, err := strconv.Atoi(parts[1])
|
||||||
|
return parts[0], v, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) addRemoteClient(address string) error {
|
||||||
|
clientKey, sessionKey, err := parseAddress(address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid remote session %s: %v", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := s.remoteClientKeys[clientKey]
|
||||||
|
if keys == nil {
|
||||||
|
keys = map[int]bool{}
|
||||||
|
s.remoteClientKeys[clientKey] = keys
|
||||||
|
}
|
||||||
|
keys[int(sessionKey)] = true
|
||||||
|
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) removeRemoteClient(address string) error {
|
||||||
|
clientKey, sessionKey, err := parseAddress(address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid remote session %s: %v", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := s.remoteClientKeys[clientKey]
|
||||||
|
delete(keys, int(sessionKey))
|
||||||
|
if len(keys) == 0 {
|
||||||
|
delete(s.remoteClientKeys, clientKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) closeConnection(connID int64, err error) {
|
||||||
|
s.Lock()
|
||||||
|
conn := s.conns[connID]
|
||||||
|
delete(s.conns, connID)
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
|
||||||
|
}
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
if conn != nil {
|
||||||
|
conn.tunnelClose(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) clientConnect(message *message) {
|
||||||
|
conn := newConnection(message.connID, s, message.proto, message.address)
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
s.conns[message.connID] = conn
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
|
||||||
|
}
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
go clientDial(s.dialer, conn, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) serverConnect(deadline time.Duration, proto, address string) (net.Conn, error) {
|
||||||
|
connID := atomic.AddInt64(&s.nextConnID, 1)
|
||||||
|
conn := newConnection(connID, s, proto, address)
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
s.conns[connID] = conn
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debugf("CONNECTIONS %d %d", s.sessionKey, len(s.conns))
|
||||||
|
}
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
_, err := s.writeMessage(newConnect(connID, deadline, proto, address))
|
||||||
|
if err != nil {
|
||||||
|
s.closeConnection(connID, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) writeMessage(message *message) (int, error) {
|
||||||
|
if PrintTunnelData {
|
||||||
|
logrus.Debug("WRITE ", message)
|
||||||
|
}
|
||||||
|
return message.WriteTo(s.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) Close() {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
s.stopPings()
|
||||||
|
|
||||||
|
for _, connection := range s.conns {
|
||||||
|
connection.tunnelClose(errors.New("tunnel disconnect"))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.conns = map[int64]*connection{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sessionAdded(clientKey string, sessionKey int64) {
|
||||||
|
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
|
||||||
|
_, err := s.writeMessage(newAddClient(client))
|
||||||
|
if err != nil {
|
||||||
|
s.conn.conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sessionRemoved(clientKey string, sessionKey int64) {
|
||||||
|
client := fmt.Sprintf("%s/%d", clientKey, sessionKey)
|
||||||
|
_, err := s.writeMessage(newRemoveClient(client))
|
||||||
|
if err != nil {
|
||||||
|
s.conn.conn.Close()
|
||||||
|
}
|
||||||
|
}
|
137
pkg/remotedialer/session_manager.go
Normal file
137
pkg/remotedialer/session_manager.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionListener interface {
|
||||||
|
sessionAdded(clientKey string, sessionKey int64)
|
||||||
|
sessionRemoved(clientKey string, sessionKey int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionManager struct {
|
||||||
|
sync.Mutex
|
||||||
|
clients map[string][]*session
|
||||||
|
peers map[string][]*session
|
||||||
|
listeners map[sessionListener]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSessionManager() *sessionManager {
|
||||||
|
return &sessionManager{
|
||||||
|
clients: map[string][]*session{},
|
||||||
|
peers: map[string][]*session{},
|
||||||
|
listeners: map[sessionListener]bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toDialer(s *session, prefix string, deadline time.Duration) Dialer {
|
||||||
|
return func(proto, address string) (net.Conn, error) {
|
||||||
|
if prefix == "" {
|
||||||
|
return s.serverConnect(deadline, proto, address)
|
||||||
|
}
|
||||||
|
return s.serverConnect(deadline, prefix+"::"+proto, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *sessionManager) removeListener(listener sessionListener) {
|
||||||
|
sm.Lock()
|
||||||
|
defer sm.Unlock()
|
||||||
|
|
||||||
|
delete(sm.listeners, listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *sessionManager) addListener(listener sessionListener) {
|
||||||
|
sm.Lock()
|
||||||
|
defer sm.Unlock()
|
||||||
|
|
||||||
|
sm.listeners[listener] = true
|
||||||
|
|
||||||
|
for k, sessions := range sm.clients {
|
||||||
|
for _, session := range sessions {
|
||||||
|
listener.sessionAdded(k, session.sessionKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, sessions := range sm.peers {
|
||||||
|
for _, session := range sessions {
|
||||||
|
listener.sessionAdded(k, session.sessionKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *sessionManager) getDialer(clientKey string, deadline time.Duration) (Dialer, error) {
|
||||||
|
sm.Lock()
|
||||||
|
defer sm.Unlock()
|
||||||
|
|
||||||
|
sessions := sm.clients[clientKey]
|
||||||
|
if len(sessions) > 0 {
|
||||||
|
return toDialer(sessions[0], "", deadline), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sessions := range sm.peers {
|
||||||
|
for _, session := range sessions {
|
||||||
|
session.Lock()
|
||||||
|
keys := session.remoteClientKeys[clientKey]
|
||||||
|
session.Unlock()
|
||||||
|
if len(keys) > 0 {
|
||||||
|
return toDialer(session, clientKey, deadline), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to find session for client %s", clientKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *sessionManager) add(clientKey string, conn *websocket.Conn, peer bool) *session {
|
||||||
|
sessionKey := rand.Int63()
|
||||||
|
session := newSession(sessionKey, clientKey, conn)
|
||||||
|
|
||||||
|
sm.Lock()
|
||||||
|
defer sm.Unlock()
|
||||||
|
|
||||||
|
if peer {
|
||||||
|
sm.peers[clientKey] = append(sm.peers[clientKey], session)
|
||||||
|
} else {
|
||||||
|
sm.clients[clientKey] = append(sm.clients[clientKey], session)
|
||||||
|
}
|
||||||
|
|
||||||
|
for l := range sm.listeners {
|
||||||
|
l.sessionAdded(clientKey, session.sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *sessionManager) remove(s *session) {
|
||||||
|
sm.Lock()
|
||||||
|
defer sm.Unlock()
|
||||||
|
|
||||||
|
for _, store := range []map[string][]*session{sm.clients, sm.peers} {
|
||||||
|
var newSessions []*session
|
||||||
|
|
||||||
|
for _, v := range store[s.clientKey] {
|
||||||
|
if v.sessionKey == s.sessionKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newSessions = append(newSessions, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newSessions) == 0 {
|
||||||
|
delete(store, s.clientKey)
|
||||||
|
} else {
|
||||||
|
store[s.clientKey] = newSessions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for l := range sm.listeners {
|
||||||
|
l.sessionRemoved(s.clientKey, s.sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Close()
|
||||||
|
}
|
57
pkg/remotedialer/session_windows.go
Normal file
57
pkg/remotedialer/session_windows.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *session) startPingsWhileWindows(rootCtx context.Context) {
|
||||||
|
ctx, cancel := context.WithCancel(rootCtx)
|
||||||
|
s.pingCancel = cancel
|
||||||
|
s.pingWait.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer s.pingWait.Done()
|
||||||
|
|
||||||
|
t := time.NewTicker(PingWriteInterval)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
s.conn.Lock()
|
||||||
|
if err := s.conn.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(time.Second)); err != nil {
|
||||||
|
logrus.WithError(err).Error("Error writing ping")
|
||||||
|
}
|
||||||
|
logrus.Debug("Wrote ping")
|
||||||
|
s.conn.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) serveWhileWindows(ctx context.Context) (int, error) {
|
||||||
|
if s.client {
|
||||||
|
s.startPingsWhileWindows(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
msType, reader, err := s.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return 400, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if msType != websocket.BinaryMessage {
|
||||||
|
return 400, errWrongMessageType
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.serveMessage(reader); err != nil {
|
||||||
|
return 500, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
11
pkg/remotedialer/types.go
Normal file
11
pkg/remotedialer/types.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
PingWaitDuration = time.Duration(10 * time.Second)
|
||||||
|
PingWriteInterval = time.Duration(5 * time.Second)
|
||||||
|
MaxRead = 8192
|
||||||
|
)
|
47
pkg/remotedialer/wsconn.go
Normal file
47
pkg/remotedialer/wsconn.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package remotedialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wsConn struct {
|
||||||
|
sync.Mutex
|
||||||
|
conn *websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWSConn(conn *websocket.Conn) *wsConn {
|
||||||
|
w := &wsConn{
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
w.setupDeadline()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsConn) WriteMessage(messageType int, data []byte) error {
|
||||||
|
w.Lock()
|
||||||
|
defer w.Unlock()
|
||||||
|
w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration))
|
||||||
|
return w.conn.WriteMessage(messageType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsConn) NextReader() (int, io.Reader, error) {
|
||||||
|
return w.conn.NextReader()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsConn) setupDeadline() {
|
||||||
|
w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
|
||||||
|
w.conn.SetPingHandler(func(string) error {
|
||||||
|
w.Lock()
|
||||||
|
w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(time.Second))
|
||||||
|
w.Unlock()
|
||||||
|
return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
|
||||||
|
})
|
||||||
|
w.conn.SetPongHandler(func(string) error {
|
||||||
|
return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration))
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
39
pkg/resolvehome/home.go
Normal file
39
pkg/resolvehome/home.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package resolvehome
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
homes = []string{"$HOME", "${HOME}", "~"}
|
||||||
|
)
|
||||||
|
|
||||||
|
func Resolve(s string) (string, error) {
|
||||||
|
for _, home := range homes {
|
||||||
|
if strings.Contains(s, home) {
|
||||||
|
homeDir, err := getHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "determining current user")
|
||||||
|
}
|
||||||
|
s = strings.Replace(s, home, homeDir, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHomeDir() (string, error) {
|
||||||
|
if os.Getuid() == 0 {
|
||||||
|
return "/root", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "determining current user, try set HOME and USER env vars")
|
||||||
|
}
|
||||||
|
return u.HomeDir, nil
|
||||||
|
}
|
Reference in New Issue
Block a user