Reload ZFSSA certificate and reset https client config when failed to verify certificate

This commit is contained in:
Helen Zhang 2024-05-20 14:20:42 -07:00
parent 0c8642ed58
commit d4631059de
2 changed files with 154 additions and 122 deletions

View File

@ -6,11 +6,11 @@
package service package service
import ( import (
"github.com/oracle/zfssa-csi-driver/pkg/utils"
"github.com/oracle/zfssa-csi-driver/pkg/zfssarest"
"errors" "errors"
"fmt" "fmt"
"github.com/container-storage-interface/spec/lib/go/csi" "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/oracle/zfssa-csi-driver/pkg/utils"
"github.com/oracle/zfssa-csi-driver/pkg/zfssarest"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -122,7 +122,7 @@ func NewZFSSADriver(driverName, version string) (*ZFSSADriver, error) {
utils.InitLogs(zd.config.logLevel, zd.name, version, zd.config.NodeName) utils.InitLogs(zd.config.logLevel, zd.name, version, zd.config.NodeName)
err = zfssarest.InitREST(zd.config.Appliance, zd.config.Certificate, zd.config.Secure) err = zfssarest.InitREST(zd.config.Appliance, zd.config.CertLocation, zd.config.Secure)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -204,8 +204,10 @@ func getConfig(zd *ZFSSADriver) error {
} }
switch strings.ToLower(strings.TrimSpace(getEnvFallback("ZFSSA_INSECURE", "False"))) { switch strings.ToLower(strings.TrimSpace(getEnvFallback("ZFSSA_INSECURE", "False"))) {
case "true": zd.config.Secure = false case "true":
case "false": zd.config.Secure = true zd.config.Secure = false
case "false":
zd.config.Secure = true
default: default:
return errors.New("ZFSSA_INSECURE value is invalid") return errors.New("ZFSSA_INSECURE value is invalid")
} }
@ -219,6 +221,7 @@ func getConfig(zd *ZFSSADriver) error {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return errors.New("certificate does not exits") return errors.New("certificate does not exits")
} }
zd.config.CertLocation = certfile
zd.config.Certificate, err = ioutil.ReadFile(certfile) zd.config.Certificate, err = ioutil.ReadFile(certfile)
if err != nil { if err != nil {
return errors.New("failed to read certificate") return errors.New("failed to read certificate")

View File

@ -15,6 +15,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@ -93,38 +94,58 @@ var httpClient = &http.Client{Transport: &httpTransport}
var zServicesURL string var zServicesURL string
var zName string var zName string
var tokens tokenList var tokens tokenList
var zfssaCertLocation string
// Initializes the ZFSSA REST API interface // Initializes the ZFSSA REST API interface
// func InitREST(name string, certLocation string, secure bool) error {
func InitREST(name string, certs []byte, secure bool) error {
if secure {
// set TLSv1.2 for the minimum version of supporting TLS
httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
// Get the SystemCertPool, continue with an empty pool on error
httpTransport.TLSClientConfig.RootCAs, _ = x509.SystemCertPool()
if httpTransport.TLSClientConfig.RootCAs == nil {
httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
}
if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(certs); !ok {
return errors.New("failed to append the certificate")
}
}
httpTransport.TLSClientConfig.InsecureSkipVerify = !secure httpTransport.TLSClientConfig.InsecureSkipVerify = !secure
httpTransport.MaxConnsPerHost = 16 httpTransport.MaxConnsPerHost = 16
httpTransport.MaxIdleConnsPerHost = 16 httpTransport.MaxIdleConnsPerHost = 16
httpTransport.IdleConnTimeout = 30 * time.Second httpTransport.IdleConnTimeout = 30 * time.Second
tokens.list = make(map[string]*Token) zfssaCertLocation = certLocation
err := resetHttpTlsClient(nil)
if err != nil {
return err
}
zServicesURL = fmt.Sprintf(zServices, name) zServicesURL = fmt.Sprintf(zServices, name)
zName = name zName = name
return nil return nil
} }
func resetHttpTlsClient(ctx context.Context) error {
if httpTransport.TLSClientConfig.InsecureSkipVerify {
utils.GetLogREST(ctx, 2).Println("resetHttpTransport skipped")
return nil
}
// set TLSv1.2 for the minimum version of supporting TLS
httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
// Get the SystemCertPool, continue with an empty pool on error
utils.GetLogREST(ctx, 2).Println("loading RootCAs")
httpTransport.TLSClientConfig.RootCAs, _ = x509.SystemCertPool()
if httpTransport.TLSClientConfig.RootCAs == nil {
httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
}
certs, err := ioutil.ReadFile(zfssaCertLocation)
if err != nil {
return errors.New("failed to read ZFSSA certificate")
}
if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(certs); !ok {
return errors.New("failed to append the certificate")
}
tokens.list = make(map[string]*Token)
utils.GetLogREST(ctx, 5).Println("resetHttpTransport done")
return nil
}
// Looks up a token context based on the user name passed in. If one doesn't exist // Looks up a token context based on the user name passed in. If one doesn't exist
// yet, it is created. // yet, it is created.
func LookUpToken(ctx context.Context, user, password string) *Token { func LookUpToken(ctx context.Context, user, password string) *Token {
@ -235,6 +256,9 @@ func createZfssaSession(ctx context.Context, token *Token) (string, string, erro
if err != nil { if err != nil {
utils.GetLogREST(ctx, 2).Println("Token creation failed in Do", utils.GetLogREST(ctx, 2).Println("Token creation failed in Do",
"url", zServicesURL, "error", err.Error()) "url", zServicesURL, "error", err.Error())
if strings.Contains(err.Error(), "failed to verify certificate") {
resetHttpTlsClient(ctx)
}
return "", "", grpcStatus.Error(codes.Internal, "Failure creating token") return "", "", grpcStatus.Error(codes.Internal, "Failure creating token")
} }
@ -294,6 +318,13 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
if err != nil { if err != nil {
utils.GetLogREST(ctx, 2).Println("client.do call failed", utils.GetLogREST(ctx, 2).Println("client.do call failed",
"method", method, "url", url, "error", err.Error()) "method", method, "url", url, "error", err.Error())
if strings.Contains(err.Error(), "failed to verify certificate") {
utils.GetLogREST(ctx, 2).Println("mark token as invalid")
token.state = zfssaTokenInvalid
resetHttpTlsClient(ctx)
return nil, http.StatusUnauthorized, err
}
return nil, 0, grpcStatus.Error(codes.Unknown, "client.do call failed") return nil, 0, grpcStatus.Error(codes.Unknown, "client.do call failed")
} }
@ -390,13 +421,11 @@ func GetServices(ctx context.Context, token *Token) (*[]Service, error) {
// Host: zfs-storage.example.com // Host: zfs-storage.example.com
// X-Auth-User: admin // X-Auth-User: admin
// X-Auth-Key: password // X-Auth-Key: password
//
func (l *services) UnmarshalJSON(b []byte) error { func (l *services) UnmarshalJSON(b []byte) error {
return zfssaUnmarshalList(b, &l.List) return zfssaUnmarshalList(b, &l.List)
} }
// Unmarshalling of a List sent by the ZFSSA // Unmarshalling of a List sent by the ZFSSA
//
func zfssaUnmarshalList(b []byte, l interface{}) error { func zfssaUnmarshalList(b []byte, l interface{}) error {
// 'b' starts and ends like this: // 'b' starts and ends like this: