mirror of
https://github.com/rancher/rke.git
synced 2025-08-11 11:42:44 +00:00
Add support for SSH certificate authentication
This commit is contained in:
parent
74426ae713
commit
d8758c551f
@ -21,6 +21,7 @@ type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error
|
|||||||
type dialer struct {
|
type dialer struct {
|
||||||
signer ssh.Signer
|
signer ssh.Signer
|
||||||
sshKeyString string
|
sshKeyString string
|
||||||
|
sshCertString string
|
||||||
sshAddress string
|
sshAddress string
|
||||||
username string
|
username string
|
||||||
netConn string
|
netConn string
|
||||||
@ -51,6 +52,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
|||||||
sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port),
|
sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port),
|
||||||
username: h.BastionHost.User,
|
username: h.BastionHost.User,
|
||||||
sshKeyString: h.BastionHost.SSHKey,
|
sshKeyString: h.BastionHost.SSHKey,
|
||||||
|
sshCertString: h.BastionHost.SSHCert,
|
||||||
netConn: "tcp",
|
netConn: "tcp",
|
||||||
useSSHAgentAuth: h.SSHAgentAuth,
|
useSSHAgentAuth: h.SSHAgentAuth,
|
||||||
}
|
}
|
||||||
@ -60,6 +62,13 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bastionDialer.sshCertString == "" && len(h.BastionHost.SSHCertPath) > 0 {
|
||||||
|
bastionDialer.sshCertString, err = certificatePath(h.BastionHost.SSHCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,6 +77,7 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
|||||||
username: h.User,
|
username: h.User,
|
||||||
dockerSocket: h.DockerSocket,
|
dockerSocket: h.DockerSocket,
|
||||||
sshKeyString: h.SSHKey,
|
sshKeyString: h.SSHKey,
|
||||||
|
sshCertString: h.SSHCert,
|
||||||
netConn: "unix",
|
netConn: "unix",
|
||||||
useSSHAgentAuth: h.SSHAgentAuth,
|
useSSHAgentAuth: h.SSHAgentAuth,
|
||||||
bastionDialer: bastionDialer,
|
bastionDialer: bastionDialer,
|
||||||
@ -80,6 +90,12 @@ func newDialer(h *Host, kind string) (*dialer, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dialer.sshCertString == "" && len(h.SSHCertPath) > 0 {
|
||||||
|
dialer.sshCertString, err = certificatePath(h.SSHCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch kind {
|
switch kind {
|
||||||
@ -152,7 +168,7 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
|
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
|
||||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
|
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||||
}
|
}
|
||||||
@ -182,7 +198,7 @@ func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
|
func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
|
||||||
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.useSSHAgentAuth)
|
bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.sshCertString, d.bastionDialer.useSSHAgentAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err)
|
return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err)
|
||||||
}
|
}
|
||||||
@ -194,7 +210,7 @@ func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err)
|
return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err)
|
||||||
}
|
}
|
||||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth)
|
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshCertString, d.useSSHAgentAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err)
|
return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err)
|
||||||
}
|
}
|
||||||
@ -211,6 +227,7 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
|
|||||||
sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port),
|
sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port),
|
||||||
username: bastionHost.User,
|
username: bastionHost.User,
|
||||||
sshKeyString: bastionHost.SSHKey,
|
sshKeyString: bastionHost.SSHKey,
|
||||||
|
sshCertString: bastionHost.SSHCert,
|
||||||
netConn: "tcp",
|
netConn: "tcp",
|
||||||
useSSHAgentAuth: bastionHost.SSHAgentAuth,
|
useSSHAgentAuth: bastionHost.SSHAgentAuth,
|
||||||
}
|
}
|
||||||
@ -222,6 +239,15 @@ func BastionHostWrapTransport(bastionHost v3.BastionHost) (k8s.WrapTransport, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if bastionDialer.sshCertString == "" && len(bastionHost.SSHCertPath) > 0 {
|
||||||
|
var err error
|
||||||
|
bastionDialer.sshCertString, err = certificatePath(bastionHost.SSHCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return func(rt http.RoundTripper) http.RoundTripper {
|
return func(rt http.RoundTripper) http.RoundTripper {
|
||||||
if ht, ok := rt.(*http.Transport); ok {
|
if ht, ok := rt.(*http.Transport); ok {
|
||||||
|
@ -87,7 +87,7 @@ func parsePrivateKey(keyBuff string) (ssh.Signer, error) {
|
|||||||
return ssh.ParsePrivateKey([]byte(keyBuff))
|
return ssh.ParsePrivateKey([]byte(keyBuff))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
|
func getSSHConfig(username, sshPrivateKeyString string, sshCertificateString string, useAgentAuth bool) (*ssh.ClientConfig, error) {
|
||||||
config := &ssh.ClientConfig{
|
config := &ssh.ClientConfig{
|
||||||
User: username,
|
User: username,
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
@ -112,6 +112,22 @@ func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(sshCertificateString) > 0 {
|
||||||
|
key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshCertificateString))
|
||||||
|
if err != nil {
|
||||||
|
return config, fmt.Errorf("Unable to parse SSH certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := key.(*ssh.Certificate); !ok {
|
||||||
|
return config, fmt.Errorf("Unable to cast public key to SSH Certificate")
|
||||||
|
}
|
||||||
|
signer, err = ssh.NewCertSigner(key.(*ssh.Certificate), signer)
|
||||||
|
if err != nil {
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
@ -128,6 +144,17 @@ func privateKeyPath(sshKeyPath string) (string, error) {
|
|||||||
return string(buff), nil
|
return string(buff), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func certificatePath(sshCertPath string) (string, error) {
|
||||||
|
if sshCertPath[:2] == "~/" {
|
||||||
|
sshCertPath = filepath.Join(userHome(), sshCertPath[2:])
|
||||||
|
}
|
||||||
|
buff, err := ioutil.ReadFile(sshCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Error while reading SSH certificate file: %v", err)
|
||||||
|
}
|
||||||
|
return string(buff), nil
|
||||||
|
}
|
||||||
|
|
||||||
func userHome() string {
|
func userHome() string {
|
||||||
if home := os.Getenv("HOME"); home != "" {
|
if home := os.Getenv("HOME"); home != "" {
|
||||||
return home
|
return home
|
||||||
|
Loading…
Reference in New Issue
Block a user