diff --git a/hosts/dialer.go b/hosts/dialer.go index ae986005..def6254f 100644 --- a/hosts/dialer.go +++ b/hosts/dialer.go @@ -21,6 +21,7 @@ type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error type dialer struct { signer ssh.Signer sshKeyString string + sshCertString string sshAddress string username string netConn string @@ -51,6 +52,7 @@ func newDialer(h *Host, kind string) (*dialer, error) { sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port), username: h.BastionHost.User, sshKeyString: h.BastionHost.SSHKey, + sshCertString: h.BastionHost.SSHCert, netConn: "tcp", useSSHAgentAuth: h.SSHAgentAuth, } @@ -60,6 +62,13 @@ func newDialer(h *Host, kind string) (*dialer, error) { if err != nil { return nil, err } + + if bastionDialer.sshCertString == "" && len(h.BastionHost.SSHCertPath) > 0 { + bastionDialer.sshCertString, err = certificatePath(h.BastionHost.SSHCertPath) + if err != nil { + return nil, err + } + } } } @@ -68,6 +77,7 @@ func newDialer(h *Host, kind string) (*dialer, error) { username: h.User, dockerSocket: h.DockerSocket, sshKeyString: h.SSHKey, + sshCertString: h.SSHCert, netConn: "unix", useSSHAgentAuth: h.SSHAgentAuth, bastionDialer: bastionDialer, @@ -80,6 +90,12 @@ func newDialer(h *Host, kind string) (*dialer, error) { return nil, err } + if dialer.sshCertString == "" && len(h.SSHCertPath) > 0 { + dialer.sshCertString, err = certificatePath(h.SSHCertPath) + if err != nil { + return nil, err + } + } } switch kind { @@ -152,7 +168,7 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) { } func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) { - cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth) + cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth) if err != nil { return nil, fmt.Errorf("Error configuring SSH: %v", err) } @@ -182,7 +198,7 @@ func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) } func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) { - bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.useSSHAgentAuth) + bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.sshCertString, d.bastionDialer.useSSHAgentAuth) if err != nil { return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err) } @@ -194,7 +210,7 @@ func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) { if err != nil { return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err) } - cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth) + cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth) if err != nil { return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err) } @@ -211,6 +227,7 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port), username: bastionHost.User, sshKeyString: bastionHost.SSHKey, + sshCertString: bastionHost.SSHCert, netConn: "tcp", useSSHAgentAuth: bastionHost.SSHAgentAuth, } @@ -222,6 +239,15 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er return nil, err } + } + + if bastionDialer.sshCertString == "" && len(bastionHost.SSHCertPath) > 0 { + var err error + bastionDialer.sshCertString, err = certificatePath(bastionHost.SSHCertPath) + if err != nil { + return nil, err + } + } return func(rt http.RoundTripper) http.RoundTripper { if ht, ok := rt.(*http.Transport); ok { diff --git a/hosts/tunnel.go b/hosts/tunnel.go index ec52a33f..660d169c 100644 --- a/hosts/tunnel.go +++ b/hosts/tunnel.go @@ -87,7 +87,7 @@ func parsePrivateKey(keyBuff string) (ssh.Signer, error) { return ssh.ParsePrivateKey([]byte(keyBuff)) } -func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh.ClientConfig, error) { +func getSSHConfig(username, sshPrivateKeyString string, sshCertificateString string, useAgentAuth bool) (*ssh.ClientConfig, error) { config := &ssh.ClientConfig{ User: username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -112,6 +112,22 @@ func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh if err != nil { return config, err } + + if len(sshCertificateString) > 0 { + key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshCertificateString)) + if err != nil { + return config, fmt.Errorf("Unable to parse SSH certificate: %v", err) + } + + if _, ok := key.(*ssh.Certificate); !ok { + return config, fmt.Errorf("Unable to cast public key to SSH Certificate") + } + signer, err = ssh.NewCertSigner(key.(*ssh.Certificate), signer) + if err != nil { + return config, err + } + } + config.Auth = append(config.Auth, ssh.PublicKeys(signer)) return config, nil @@ -128,6 +144,17 @@ func privateKeyPath(sshKeyPath string) (string, error) { return string(buff), nil } +func certificatePath(sshCertPath string) (string, error) { + if sshCertPath[:2] == "~/" { + sshCertPath = filepath.Join(userHome(), sshCertPath[2:]) + } + buff, err := ioutil.ReadFile(sshCertPath) + if err != nil { + return "", fmt.Errorf("Error while reading SSH certificate file: %v", err) + } + return string(buff), nil +} + func userHome() string { if home := os.Getenv("HOME"); home != "" { return home