mirror of
https://github.com/rancher/rke.git
synced 2025-08-01 07:08:38 +00:00
Add support for SSH certificate authentication
This commit is contained in:
parent
74426ae713
commit
d8758c551f
@ -21,6 +21,7 @@ type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error
|
||||
type dialer struct {
|
||||
signer ssh.Signer
|
||||
sshKeyString string
|
||||
sshCertString string
|
||||
sshAddress string
|
||||
username string
|
||||
netConn string
|
||||
@ -51,6 +52,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
||||
sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port),
|
||||
username: h.BastionHost.User,
|
||||
sshKeyString: h.BastionHost.SSHKey,
|
||||
sshCertString: h.BastionHost.SSHCert,
|
||||
netConn: "tcp",
|
||||
useSSHAgentAuth: h.SSHAgentAuth,
|
||||
}
|
||||
@ -60,6 +62,13 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bastionDialer.sshCertString == "" && len(h.BastionHost.SSHCertPath) > 0 {
|
||||
bastionDialer.sshCertString, err = certificatePath(h.BastionHost.SSHCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,6 +77,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
||||
username: h.User,
|
||||
dockerSocket: h.DockerSocket,
|
||||
sshKeyString: h.SSHKey,
|
||||
sshCertString: h.SSHCert,
|
||||
netConn: "unix",
|
||||
useSSHAgentAuth: h.SSHAgentAuth,
|
||||
bastionDialer: bastionDialer,
|
||||
@ -80,6 +90,12 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dialer.sshCertString == "" && len(h.SSHCertPath) > 0 {
|
||||
dialer.sshCertString, err = certificatePath(h.SSHCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
@ -152,7 +168,7 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
|
||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
|
||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||
}
|
||||
@ -182,7 +198,7 @@ func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error)
|
||||
}
|
||||
|
||||
func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
|
||||
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.useSSHAgentAuth)
|
||||
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.sshCertString, d.bastionDialer.useSSHAgentAuth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err)
|
||||
}
|
||||
@ -194,7 +210,7 @@ func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err)
|
||||
}
|
||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
|
||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err)
|
||||
}
|
||||
@ -211,6 +227,7 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
|
||||
sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port),
|
||||
username: bastionHost.User,
|
||||
sshKeyString: bastionHost.SSHKey,
|
||||
sshCertString: bastionHost.SSHCert,
|
||||
netConn: "tcp",
|
||||
useSSHAgentAuth: bastionHost.SSHAgentAuth,
|
||||
}
|
||||
@ -222,6 +239,15 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if bastionDialer.sshCertString == "" && len(bastionHost.SSHCertPath) > 0 {
|
||||
var err error
|
||||
bastionDialer.sshCertString, err = certificatePath(bastionHost.SSHCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
return func(rt http.RoundTripper) http.RoundTripper {
|
||||
if ht, ok := rt.(*http.Transport); ok {
|
||||
|
@ -87,7 +87,7 @@ func parsePrivateKey(keyBuff string) (ssh.Signer, error) {
|
||||
return ssh.ParsePrivateKey([]byte(keyBuff))
|
||||
}
|
||||
|
||||
func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
|
||||
func getSSHConfig(username, sshPrivateKeyString string, sshCertificateString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
@ -112,6 +112,22 @@ func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
|
||||
if len(sshCertificateString) > 0 {
|
||||
key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshCertificateString))
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("Unable to parse SSH certificate: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := key.(*ssh.Certificate); !ok {
|
||||
return config, fmt.Errorf("Unable to cast public key to SSH Certificate")
|
||||
}
|
||||
signer, err = ssh.NewCertSigner(key.(*ssh.Certificate), signer)
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
}
|
||||
|
||||
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
||||
|
||||
return config, nil
|
||||
@ -128,6 +144,17 @@ func privateKeyPath(sshKeyPath string) (string, error) {
|
||||
return string(buff), nil
|
||||
}
|
||||
|
||||
func certificatePath(sshCertPath string) (string, error) {
|
||||
if sshCertPath[:2] == "~/" {
|
||||
sshCertPath = filepath.Join(userHome(), sshCertPath[2:])
|
||||
}
|
||||
buff, err := ioutil.ReadFile(sshCertPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error while reading SSH certificate file: %v", err)
|
||||
}
|
||||
return string(buff), nil
|
||||
}
|
||||
|
||||
func userHome() string {
|
||||
if home := os.Getenv("HOME"); home != "" {
|
||||
return home
|
||||
|
Loading…
Reference in New Issue
Block a user