diff --git a/pkg/aggregation/listener.go b/pkg/aggregation/listener.go new file mode 100644 index 0000000..b0a3d74 --- /dev/null +++ b/pkg/aggregation/listener.go @@ -0,0 +1,70 @@ +package aggregation + +import ( + "context" + "io" + "net" + "sync" +) + +type addr string + +func (a addr) String() string { + return string(a) +} +func (a addr) Network() string { + return "tcp" +} + +type Listener struct { + sync.RWMutex + + address addr + connections chan net.Conn + closed bool +} + +func NewListener(address string) *Listener { + return &Listener{ + address: addr(address), + connections: make(chan net.Conn, 5), + } +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, ok := <-l.connections + if !ok { + return nil, io.ErrClosedPipe + } + return conn, nil +} + +func (l *Listener) Close() error { + l.Lock() + defer l.Unlock() + if !l.closed { + close(l.connections) + l.closed = true + } + return nil +} + +func (l *Listener) Dial(ctx context.Context, network, address string) (net.Conn, error) { + left, right := net.Pipe() + l.RLock() + defer l.RUnlock() + if l.closed { + return nil, io.ErrClosedPipe + } + + select { + case l.connections <- right: + return left, nil + case <-ctx.Done(): + return nil, io.ErrClosedPipe + } +} + +func (l *Listener) Addr() net.Addr { + return l.address +} diff --git a/pkg/aggregation/server.go b/pkg/aggregation/server.go new file mode 100644 index 0000000..a02e12d --- /dev/null +++ b/pkg/aggregation/server.go @@ -0,0 +1,82 @@ +package aggregation + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/rancher/remotedialer" + "github.com/rancher/steve/pkg/auth" + "github.com/sirupsen/logrus" +) + +func ListenAndServe(ctx context.Context, url string, caCert []byte, token string, handler http.Handler) { + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + } + + if caCert != nil && len(caCert) == 0 { + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } else if len(caCert) > 0 { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(caCert) + dialer.TLSClientConfig = &tls.Config{ + RootCAs: pool, + } + } + + handler = auth.ToMiddleware(auth.AuthenticatorFunc(auth.Impersonation))(handler) + + headers := http.Header{} + headers.Add("Authorization", "Bearer "+token) + + for { + err := serve(ctx, dialer, url, headers, handler) + if err != nil { + logrus.Errorf("Failed to dial steve aggregation server: %v", err) + } + select { + case <-ctx.Done(): + return + case <-time.After(5 * time.Second): + } + } +} + +func serve(ctx context.Context, dialer websocket.Dialer, url string, headers http.Header, handler http.Handler) error { + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.Replace(url, "https://", "wss://", 1) + conn, _, err := dialer.DialContext(ctx, url, headers) + if err != nil { + return err + } + defer conn.Close() + + listener := NewListener("steve") + server := http.Server{ + Handler: handler, + BaseContext: func(_ net.Listener) context.Context { + return ctx + }, + } + go server.Serve(listener) + defer server.Shutdown(context.Background()) + + session := remotedialer.NewClientSessionWithDialer(allowAll, conn, listener.Dial) + defer session.Close() + + _, err = session.Serve(ctx) + return err +} + +func allowAll(_, _ string) bool { + return true +} diff --git a/pkg/aggregation/watch.go b/pkg/aggregation/watch.go new file mode 100644 index 0000000..5f42400 --- /dev/null +++ b/pkg/aggregation/watch.go @@ -0,0 +1,93 @@ +package aggregation + +import ( + "bytes" + "context" + "net/http" + + v1 "github.com/rancher/wrangler/pkg/generated/controllers/core/v1" + "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" +) + +func Watch(ctx context.Context, controller v1.SecretController, secretNamespace, secretName string, httpHandler http.Handler) { + if secretNamespace == "" || secretName == "" { + return + } + h := handler{ + ctx: ctx, + handler: httpHandler, + namespace: secretNamespace, + name: secretName, + } + controller.OnChange(ctx, "aggregation-controller", h.OnSecret) +} + +type handler struct { + handler http.Handler + namespace, name string + + url string + caCert []byte + token string + ctx context.Context + cancel func() +} + +func (h *handler) OnSecret(key string, secret *corev1.Secret) (*corev1.Secret, error) { + if secret == nil { + return nil, nil + } + + if secret.Namespace != h.namespace || + secret.Name != h.name { + return secret, nil + } + + url, caCert, token, restart, err := h.shouldRestart(secret) + if err != nil { + return secret, err + } + if !restart { + return secret, nil + } + + if h.cancel != nil { + logrus.Info("Restarting steve aggregation client") + h.cancel() + } else { + logrus.Info("Starting steve aggregation client") + } + + ctx, cancel := context.WithCancel(h.ctx) + go ListenAndServe(ctx, url, caCert, token, h.handler) + + h.url = url + h.caCert = caCert + h.token = token + h.cancel = cancel + + return secret, nil +} + +func (h *handler) shouldRestart(secret *corev1.Secret) (string, []byte, string, bool, error) { + url := string(secret.Data["url"]) + if url == "" { + return "", nil, "", false, nil + } + + token := string(secret.Data["token"]) + if token == "" { + return "", nil, "", false, nil + } + + caCert := secret.Data["ca.crt"] + + if h.url != url || + h.token != token || + bytes.Equal(h.caCert, caCert) { + return url, caCert, token, true, nil + } + + return "", nil, "", false, nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 5e2191c..db491be 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/rancher/apiserver/pkg/types" "github.com/rancher/dynamiclistener/server" "github.com/rancher/steve/pkg/accesscontrol" + "github.com/rancher/steve/pkg/aggregation" "github.com/rancher/steve/pkg/auth" "github.com/rancher/steve/pkg/client" "github.com/rancher/steve/pkg/clustercache" @@ -41,16 +42,21 @@ type Server struct { needControllerStart bool next http.Handler router router.RouterFunc + + aggregationSecretNamespace string + aggregationSecretName string } type Options struct { // Controllers If the controllers are passed in the caller must also start the controllers - Controllers *Controllers - ClientFactory *client.Factory - AccessSetLookup accesscontrol.AccessSetLookup - AuthMiddleware auth.Middleware - Next http.Handler - Router router.RouterFunc + Controllers *Controllers + ClientFactory *client.Factory + AccessSetLookup accesscontrol.AccessSetLookup + AuthMiddleware auth.Middleware + Next http.Handler + Router router.RouterFunc + AggregationSecretNamespace string + AggregationSecretName string } func New(ctx context.Context, restConfig *rest.Config, opts *Options) (*Server, error) { @@ -59,13 +65,15 @@ func New(ctx context.Context, restConfig *rest.Config, opts *Options) (*Server, } server := &Server{ - RESTConfig: restConfig, - ClientFactory: opts.ClientFactory, - AccessSetLookup: opts.AccessSetLookup, - authMiddleware: opts.AuthMiddleware, - controllers: opts.Controllers, - next: opts.Next, - router: opts.Router, + RESTConfig: restConfig, + ClientFactory: opts.ClientFactory, + AccessSetLookup: opts.AccessSetLookup, + authMiddleware: opts.AuthMiddleware, + controllers: opts.Controllers, + next: opts.Next, + router: opts.Router, + aggregationSecretNamespace: opts.AggregationSecretNamespace, + aggregationSecretName: opts.AggregationSecretName, } if err := setup(ctx, server); err != nil { @@ -156,6 +164,9 @@ func setup(ctx context.Context, server *Server) error { return err } + aggregation.Watch(ctx, server.controllers.Core.Secret(), server.aggregationSecretNamespace, + server.aggregationSecretName, handler) + server.APIServer = apiServer server.Handler = handler server.SchemaFactory = sf