mirror of
				https://github.com/k3s-io/kubernetes.git
				synced 2025-11-03 23:40:03 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			239 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			239 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) 2016 VMware, Inc. All Rights Reserved.
 | 
						|
//
 | 
						|
// This product is licensed to you under the Apache License, Version 2.0 (the "License").
 | 
						|
// You may not use this product except in compliance with the License.
 | 
						|
//
 | 
						|
// This product may include a number of subcomponents with separate copyright notices and
 | 
						|
// license terms. Your use of these subcomponents is subject to the terms and conditions
 | 
						|
// of the subcomponent's license, as noted in the LICENSE file.
 | 
						|
 | 
						|
package lightwave
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/tls"
 | 
						|
	"crypto/x509"
 | 
						|
	"encoding/json"
 | 
						|
	"encoding/pem"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"log"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
)
 | 
						|
 | 
						|
// tokenScope is the space-separated scope list requested by default when no
// OIDCClientOptions.TokenScope is supplied to NewOIDCClient.
const tokenScope = "openid offline_access"
 | 
						|
 | 
						|
type OIDCClient struct {
 | 
						|
	httpClient *http.Client
 | 
						|
	logger     *log.Logger
 | 
						|
 | 
						|
	Endpoint string
 | 
						|
	Options  *OIDCClientOptions
 | 
						|
}
 | 
						|
 | 
						|
// OIDCClientOptions configures an OIDCClient. A nil or zero value is accepted
// by NewOIDCClient, which fills in the defaults.
type OIDCClientOptions struct {
	// IgnoreCertificate disables TLS certificate verification for requests to
	// the OIDC server. Defaults to false.
	IgnoreCertificate bool

	// RootCAs, if non-nil, is the pool of root certificates used to verify the
	// server. Defaults to nil, in which case TLS uses the host's root CA set.
	RootCAs *x509.CertPool

	// TokenScope is the scope value sent with password-grant token requests.
	// When empty, NewOIDCClient substitutes the default scope.
	TokenScope string
}
 | 
						|
 | 
						|
func NewOIDCClient(endpoint string, options *OIDCClientOptions, logger *log.Logger) (c *OIDCClient) {
 | 
						|
	if logger == nil {
 | 
						|
		logger = log.New(ioutil.Discard, "", log.LstdFlags)
 | 
						|
	}
 | 
						|
 | 
						|
	options = buildOptions(options)
 | 
						|
	tr := &http.Transport{
 | 
						|
		TLSClientConfig: &tls.Config{
 | 
						|
			InsecureSkipVerify: options.IgnoreCertificate,
 | 
						|
			RootCAs:            options.RootCAs},
 | 
						|
	}
 | 
						|
 | 
						|
	c = &OIDCClient{
 | 
						|
		httpClient: &http.Client{Transport: tr},
 | 
						|
		logger:     logger,
 | 
						|
 | 
						|
		Endpoint: strings.TrimRight(endpoint, "/"),
 | 
						|
		Options:  options,
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func buildOptions(options *OIDCClientOptions) (result *OIDCClientOptions) {
 | 
						|
	result = &OIDCClientOptions{
 | 
						|
		TokenScope: tokenScope,
 | 
						|
	}
 | 
						|
 | 
						|
	if options == nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	result.IgnoreCertificate = options.IgnoreCertificate
 | 
						|
 | 
						|
	if options.RootCAs != nil {
 | 
						|
		result.RootCAs = options.RootCAs
 | 
						|
	}
 | 
						|
 | 
						|
	if options.TokenScope != "" {
 | 
						|
		result.TokenScope = options.TokenScope
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (client *OIDCClient) buildUrl(path string) (url string) {
 | 
						|
	return fmt.Sprintf("%s%s", client.Endpoint, path)
 | 
						|
}
 | 
						|
 | 
						|
// Cert download helper
 | 
						|
 | 
						|
// certDownloadPath is the server path, relative to OIDCClient.Endpoint, from
// which GetRootCerts downloads the root certificates.
const certDownloadPath = "/afd/vecs/ssl"

// lightWaveCert is one element of the JSON array served at certDownloadPath.
type lightWaveCert struct {
	// Value is the PEM-encoded certificate, carried in the "encoded" field.
	Value string `json:"encoded"`
}
 | 
						|
 | 
						|
func (client *OIDCClient) GetRootCerts() (certList []*x509.Certificate, err error) {
 | 
						|
	// turn TLS verification off for
 | 
						|
	originalTr := client.httpClient.Transport
 | 
						|
	defer client.setTransport(originalTr)
 | 
						|
 | 
						|
	tr := &http.Transport{
 | 
						|
		TLSClientConfig: &tls.Config{
 | 
						|
			InsecureSkipVerify: true,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	client.setTransport(tr)
 | 
						|
 | 
						|
	// get the certs
 | 
						|
	resp, err := client.httpClient.Get(client.buildUrl(certDownloadPath))
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	defer resp.Body.Close()
 | 
						|
	if resp.StatusCode != 200 {
 | 
						|
		err = fmt.Errorf("Unexpected error retrieving auth server certs: %v %s", resp.StatusCode, resp.Status)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// parse the certs
 | 
						|
	certsData := &[]lightWaveCert{}
 | 
						|
	err = json.NewDecoder(resp.Body).Decode(certsData)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	certList = make([]*x509.Certificate, len(*certsData))
 | 
						|
	for idx, cert := range *certsData {
 | 
						|
		block, _ := pem.Decode([]byte(cert.Value))
 | 
						|
		if block == nil {
 | 
						|
			err = fmt.Errorf("Unexpected response format: %v", certsData)
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		decodedCert, err := x509.ParseCertificate(block.Bytes)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		certList[idx] = decodedCert
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (client *OIDCClient) setTransport(tr http.RoundTripper) {
 | 
						|
	client.httpClient.Transport = tr
 | 
						|
}
 | 
						|
 | 
						|
// Token request helpers
 | 
						|
 | 
						|
const (
	// tokenPath is the server path, relative to OIDCClient.Endpoint, to which
	// token requests are POSTed.
	tokenPath = "/openidconnect/token"

	// passwordGrantFormatString takes, in order, the username, password and
	// scope. Each value must already be form-encoded, because the result is
	// sent as an application/x-www-form-urlencoded body.
	passwordGrantFormatString = "grant_type=password&username=%s&password=%s&scope=%s"

	// refreshTokenGrantFormatString takes the refresh token, which must
	// already be form-encoded.
	refreshTokenGrantFormatString = "grant_type=refresh_token&refresh_token=%s"
)
 | 
						|
 | 
						|
// OIDCTokenResponse is the JSON body returned by a successful token request.
type OIDCTokenResponse struct {
	// AccessToken is the "access_token" field of the response.
	AccessToken string `json:"access_token"`

	// ExpiresIn is the access-token lifetime reported by the server in
	// "expires_in" (seconds, per OAuth 2.0).
	ExpiresIn int `json:"expires_in"`

	// RefreshToken may be absent; when present it can be passed to
	// GetTokenByRefreshTokenGrant. Omitted from JSON output when empty.
	RefreshToken string `json:"refresh_token,omitempty"`

	// IdToken is the OpenID Connect "id_token" field.
	IdToken string `json:"id_token"`

	// TokenType is the "token_type" field reported by the server.
	TokenType string `json:"token_type"`
}
 | 
						|
 | 
						|
func (client *OIDCClient) GetTokenByPasswordGrant(username string, password string) (tokens *OIDCTokenResponse, err error) {
 | 
						|
	username = url.QueryEscape(username)
 | 
						|
	password = url.QueryEscape(password)
 | 
						|
	body := fmt.Sprintf(passwordGrantFormatString, username, password, client.Options.TokenScope)
 | 
						|
	return client.getToken(body)
 | 
						|
}
 | 
						|
 | 
						|
func (client *OIDCClient) GetTokenByRefreshTokenGrant(refreshToken string) (tokens *OIDCTokenResponse, err error) {
 | 
						|
	body := fmt.Sprintf(refreshTokenGrantFormatString, refreshToken)
 | 
						|
	return client.getToken(body)
 | 
						|
}
 | 
						|
 | 
						|
func (client *OIDCClient) getToken(body string) (tokens *OIDCTokenResponse, err error) {
 | 
						|
	request, err := http.NewRequest("POST", client.buildUrl(tokenPath), strings.NewReader(body))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
 | 
						|
 | 
						|
	resp, err := client.httpClient.Do(request)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	defer resp.Body.Close()
 | 
						|
 | 
						|
	err = client.checkResponse(resp)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	tokens = &OIDCTokenResponse{}
 | 
						|
	err = json.NewDecoder(resp.Body).Decode(tokens)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// OIDCError is the error body an OIDC server returns for a failed request,
// decoded from its "error" and "error_description" JSON fields.
type OIDCError struct {
	// Code is the "error" field of the response.
	Code string `json:"error"`

	// Message is the "error_description" field of the response.
	Message string `json:"error_description"`
}
 | 
						|
 | 
						|
func (e OIDCError) Error() string {
 | 
						|
	return fmt.Sprintf("%v: %v", e.Code, e.Message)
 | 
						|
}
 | 
						|
 | 
						|
func (client *OIDCClient) checkResponse(response *http.Response) (err error) {
 | 
						|
	if response.StatusCode/100 == 2 {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	respBody, readErr := ioutil.ReadAll(response.Body)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf(
 | 
						|
			"Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
 | 
						|
	}
 | 
						|
 | 
						|
	var oidcErr OIDCError
 | 
						|
	err = json.Unmarshal(respBody, &oidcErr)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf(
 | 
						|
			"Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
 | 
						|
	}
 | 
						|
 | 
						|
	return oidcErr
 | 
						|
}
 |