1
0
mirror of https://github.com/rancher/rke.git synced 2025-08-01 07:08:38 +00:00

Add support for SSH certificate authentication

This commit is contained in:
Bernard Wagner 2019-01-28 15:54:00 +02:00 committed by Alena Prokharchyk
parent 74426ae713
commit d8758c551f
2 changed files with 57 additions and 4 deletions

View File

@ -21,6 +21,7 @@ type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error
type dialer struct {
signer ssh.Signer
sshKeyString string
sshCertString string
sshAddress string
username string
netConn string
@ -51,6 +52,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port),
username: h.BastionHost.User,
sshKeyString: h.BastionHost.SSHKey,
sshCertString: h.BastionHost.SSHCert,
netConn: "tcp",
useSSHAgentAuth: h.SSHAgentAuth,
}
@ -60,6 +62,13 @@ func newDialer(h *Host, kind string) (*dialer, error) {
if err != nil {
return nil, err
}
if bastionDialer.sshCertString == "" && len(h.BastionHost.SSHCertPath) > 0 {
bastionDialer.sshCertString, err = certificatePath(h.BastionHost.SSHCertPath)
if err != nil {
return nil, err
}
}
}
}
@ -68,6 +77,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
username: h.User,
dockerSocket: h.DockerSocket,
sshKeyString: h.SSHKey,
sshCertString: h.SSHCert,
netConn: "unix",
useSSHAgentAuth: h.SSHAgentAuth,
bastionDialer: bastionDialer,
@ -80,6 +90,12 @@ func newDialer(h *Host, kind string) (*dialer, error) {
return nil, err
}
if dialer.sshCertString == "" && len(h.SSHCertPath) > 0 {
dialer.sshCertString, err = certificatePath(h.SSHCertPath)
if err != nil {
return nil, err
}
}
}
switch kind {
@ -152,7 +168,7 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) {
}
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
if err != nil {
return nil, fmt.Errorf("Error configuring SSH: %v", err)
}
@ -182,7 +198,7 @@ func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error)
}
func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.useSSHAgentAuth)
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.sshCertString, d.bastionDialer.useSSHAgentAuth)
if err != nil {
return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err)
}
@ -194,7 +210,7 @@ func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
if err != nil {
return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err)
}
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
if err != nil {
return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err)
}
@ -211,6 +227,7 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port),
username: bastionHost.User,
sshKeyString: bastionHost.SSHKey,
sshCertString: bastionHost.SSHCert,
netConn: "tcp",
useSSHAgentAuth: bastionHost.SSHAgentAuth,
}
@ -222,6 +239,15 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
return nil, err
}
}
if bastionDialer.sshCertString == "" && len(bastionHost.SSHCertPath) > 0 {
var err error
bastionDialer.sshCertString, err = certificatePath(bastionHost.SSHCertPath)
if err != nil {
return nil, err
}
}
return func(rt http.RoundTripper) http.RoundTripper {
if ht, ok := rt.(*http.Transport); ok {

View File

@ -87,7 +87,7 @@ func parsePrivateKey(keyBuff string) (ssh.Signer, error) {
return ssh.ParsePrivateKey([]byte(keyBuff))
}
func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
func getSSHConfig(username, sshPrivateKeyString string, sshCertificateString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
config := &ssh.ClientConfig{
User: username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
@ -112,6 +112,22 @@ func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh
if err != nil {
return config, err
}
if len(sshCertificateString) > 0 {
key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshCertificateString))
if err != nil {
return config, fmt.Errorf("Unable to parse SSH certificate: %v", err)
}
if _, ok := key.(*ssh.Certificate); !ok {
return config, fmt.Errorf("Unable to cast public key to SSH Certificate")
}
signer, err = ssh.NewCertSigner(key.(*ssh.Certificate), signer)
if err != nil {
return config, err
}
}
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
return config, nil
@ -128,6 +144,17 @@ func privateKeyPath(sshKeyPath string) (string, error) {
return string(buff), nil
}
func certificatePath(sshCertPath string) (string, error) {
if sshCertPath[:2] == "~/" {
sshCertPath = filepath.Join(userHome(), sshCertPath[2:])
}
buff, err := ioutil.ReadFile(sshCertPath)
if err != nil {
return "", fmt.Errorf("Error while reading SSH certificate file: %v", err)
}
return string(buff), nil
}
func userHome() string {
if home := os.Getenv("HOME"); home != "" {
return home