Reload ZFSSA certificate and reset https client config when failed to verify certificate

This commit is contained in:
Helen Zhang 2024-05-20 14:20:42 -07:00
parent 0c8642ed58
commit d4631059de
2 changed files with 154 additions and 122 deletions

View File

@ -6,11 +6,11 @@
package service
import (
"github.com/oracle/zfssa-csi-driver/pkg/utils"
"github.com/oracle/zfssa-csi-driver/pkg/zfssarest"
"errors"
"fmt"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/oracle/zfssa-csi-driver/pkg/utils"
"github.com/oracle/zfssa-csi-driver/pkg/zfssarest"
"golang.org/x/net/context"
"google.golang.org/grpc"
"gopkg.in/yaml.v2"
@ -28,9 +28,9 @@ import (
const (
// Default Log Level
DefaultLogLevel = "3"
DefaultCertPath = "/mnt/certs/zfssa.crt"
DefaultCredPath = "/mnt/zfssa/zfssa.yaml"
DefaultLogLevel = "3"
DefaultCertPath = "/mnt/certs/zfssa.crt"
DefaultCredPath = "/mnt/zfssa/zfssa.yaml"
DefaultConfigPath = "/mnt/config/config.yaml"
)
@ -49,15 +49,15 @@ type ZFSSADriver struct {
}
type config struct {
Appliance string
User string
endpoint string
HostIp string
NodeName string
PodIp string
Secure bool
logLevel string
Certificate []byte
Appliance string
User string
endpoint string
HostIp string
NodeName string
PodIp string
Secure bool
logLevel string
Certificate []byte
CertLocation string
CredLocation string
}
@ -73,8 +73,8 @@ type accessType int
// NonBlocking server
type nonBlockingGRPCServer struct {
wg sync.WaitGroup
server *grpc.Server
wg sync.WaitGroup
server *grpc.Server
}
const (
@ -86,7 +86,7 @@ const (
Tib int64 = Gib * 1024
Tib100 int64 = Tib * 100
DefaultVolumeSizeBytes int64 = 50 * Gib
DefaultVolumeSizeBytes int64 = 50 * Gib
mountAccess accessType = iota
blockAccess
@ -94,7 +94,7 @@ const (
const (
UsernamePattern string = `^[a-zA-Z][a-zA-Z0-9_\-\.]*$`
UsernameLength int = 255
UsernameLength int = 255
)
type ZfssaBlockVolume struct {
@ -122,7 +122,7 @@ func NewZFSSADriver(driverName, version string) (*ZFSSADriver, error) {
utils.InitLogs(zd.config.logLevel, zd.name, version, zd.config.NodeName)
err = zfssarest.InitREST(zd.config.Appliance, zd.config.Certificate, zd.config.Secure)
err = zfssarest.InitREST(zd.config.Appliance, zd.config.CertLocation, zd.config.Secure)
if err != nil {
return nil, err
}
@ -175,8 +175,8 @@ func getConfig(zd *ZFSSADriver) error {
return errors.New(fmt.Sprintf("Cannot get ZFSSA username: %s", err))
}
// Get ZFSSA_TARGET from the mounted config file if available
zfssaHost, _ := utils.GetValueFromYAML(DefaultConfigPath,"ZFSSA_TARGET")
// Get ZFSSA_TARGET from the mounted config file if available
zfssaHost, _ := utils.GetValueFromYAML(DefaultConfigPath, "ZFSSA_TARGET")
appliance := getEnvFallback("ZFSSA_TARGET", zfssaHost)
zd.config.Appliance = strings.TrimSpace(appliance)
if zd.config.Appliance == "not-set" {
@ -189,7 +189,7 @@ func getConfig(zd *ZFSSADriver) error {
}
zd.config.endpoint = getEnvFallback("CSI_ENDPOINT", "")
if zd.config.endpoint == "" {
if zd.config.endpoint == "" {
return errors.New("endpoint is required")
} else {
if !strings.HasPrefix(zd.config.endpoint, "unix://") {
@ -204,8 +204,10 @@ func getConfig(zd *ZFSSADriver) error {
}
switch strings.ToLower(strings.TrimSpace(getEnvFallback("ZFSSA_INSECURE", "False"))) {
case "true": zd.config.Secure = false
case "false": zd.config.Secure = true
case "true":
zd.config.Secure = false
case "false":
zd.config.Secure = true
default:
return errors.New("ZFSSA_INSECURE value is invalid")
}
@ -219,6 +221,7 @@ func getConfig(zd *ZFSSADriver) error {
if os.IsNotExist(err) {
return errors.New("certificate does not exits")
}
zd.config.CertLocation = certfile
zd.config.Certificate, err = ioutil.ReadFile(certfile)
if err != nil {
return errors.New("failed to read certificate")
@ -238,7 +241,7 @@ func getConfig(zd *ZFSSADriver) error {
// Starts the CSI driver. This includes registering the different servers (Identity, Controller and Node) with
// the CSI framework and starting listening on the UNIX socket.
var sigList = []os.Signal {
var sigList = []os.Signal{
syscall.SIGTERM,
syscall.SIGHUP,
syscall.SIGINT,
@ -344,7 +347,7 @@ func (s *nonBlockingGRPCServer) serve(endpoint string,
return
}
opts := []grpc.ServerOption{ grpc.UnaryInterceptor(interceptorGRPC) }
opts := []grpc.ServerOption{grpc.UnaryInterceptor(interceptorGRPC)}
server := grpc.NewServer(opts...)
s.server = server

View File

@ -15,6 +15,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"sync"
"time"
@ -25,35 +26,35 @@ import (
// Use RESTapi v2 as it returns scriptable and consistent values
const (
zAppliance string = "https://%s:215"
zServices = zAppliance + "/api/access/v2"
zStorage = zAppliance + "/api/storage/v2"
zSan = zAppliance + "/api/san/v2"
zPools = zStorage + "/pools"
zPool = zPools + "/%s"
zAllProjects = zStorage + "/projects"
zProjects = zPool + "/projects"
zProject = zProjects + "/%s"
zAllFilesystems = zStorage + "/filesystems"
zFilesystems = zProject + "/filesystems"
zFilesystem = zFilesystems + "/%s"
zAllLUNs = zStorage + "/luns"
zLUNs = zProject + "/luns"
zLUN = zLUNs + "/%s"
zAllSnapshots = zStorage + "/snapshots"
zSnapshots = zProject + "/snapshots"
zSnapshot = zSnapshots + "/%s"
zFilesystemSnapshots = zFilesystem + "/snapshots"
zFilesystemSnapshot = zFilesystemSnapshots + "/%s"
zCloneFilesystemSnapshot = zFilesystemSnapshot + "/clone"
zLUNSnapshots = zLUN + "/snapshots"
zLUNSnapshot = zLUNSnapshots + "/%s"
zFilesystemDependents = zFilesystemSnapshot + "/dependents"
zLUNDependents = zLUNSnapshot + "/dependents"
zTargetGroups = zSan + "/%s/target-groups"
zTargetGroup = zTargetGroups + "/%s"
zProperties = zAppliance + "/api/storage/v2/schema"
zProperty = zProperties + "/%s"
zAppliance string = "https://%s:215"
zServices = zAppliance + "/api/access/v2"
zStorage = zAppliance + "/api/storage/v2"
zSan = zAppliance + "/api/san/v2"
zPools = zStorage + "/pools"
zPool = zPools + "/%s"
zAllProjects = zStorage + "/projects"
zProjects = zPool + "/projects"
zProject = zProjects + "/%s"
zAllFilesystems = zStorage + "/filesystems"
zFilesystems = zProject + "/filesystems"
zFilesystem = zFilesystems + "/%s"
zAllLUNs = zStorage + "/luns"
zLUNs = zProject + "/luns"
zLUN = zLUNs + "/%s"
zAllSnapshots = zStorage + "/snapshots"
zSnapshots = zProject + "/snapshots"
zSnapshot = zSnapshots + "/%s"
zFilesystemSnapshots = zFilesystem + "/snapshots"
zFilesystemSnapshot = zFilesystemSnapshots + "/%s"
zCloneFilesystemSnapshot = zFilesystemSnapshot + "/clone"
zLUNSnapshots = zLUN + "/snapshots"
zLUNSnapshot = zLUNSnapshots + "/%s"
zFilesystemDependents = zFilesystemSnapshot + "/dependents"
zLUNDependents = zLUNSnapshot + "/dependents"
zTargetGroups = zSan + "/%s/target-groups"
zTargetGroup = zTargetGroups + "/%s"
zProperties = zAppliance + "/api/storage/v2/schema"
zProperty = zProperties + "/%s"
)
const (
@ -63,68 +64,88 @@ const (
)
type Token struct {
Name string
cv *sync.Cond
mtx sync.Mutex
user string
password string
state int
xAuthSession string
xAuthName string
Name string
cv *sync.Cond
mtx sync.Mutex
user string
password string
state int
xAuthSession string
xAuthName string
}
type tokenList struct {
mtx sync.Mutex
list map[string]*Token
mtx sync.Mutex
list map[string]*Token
}
type faultInfo struct {
Message string `json:"message"`
Code int `json:"code"`
Name string `json:"Name"`
Message string `json:"message"`
Code int `json:"code"`
Name string `json:"Name"`
}
type faultResponse struct {
Fault faultInfo `json:"fault"`
}
var httpTransport = http.Transport{TLSClientConfig: &tls.Config{}}
var httpClient = &http.Client{Transport: &httpTransport}
var zServicesURL string
var zName string
var tokens tokenList
var httpTransport = http.Transport{TLSClientConfig: &tls.Config{}}
var httpClient = &http.Client{Transport: &httpTransport}
var zServicesURL string
var zName string
var tokens tokenList
var zfssaCertLocation string
// Initializes the ZFSSA REST API interface
//
func InitREST(name string, certs []byte, secure bool) error {
if secure {
// set TLSv1.2 for the minimum version of supporting TLS
httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
// Get the SystemCertPool, continue with an empty pool on error
httpTransport.TLSClientConfig.RootCAs, _ = x509.SystemCertPool()
if httpTransport.TLSClientConfig.RootCAs == nil {
httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
}
if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(certs); !ok {
return errors.New("failed to append the certificate")
}
}
func InitREST(name string, certLocation string, secure bool) error {
httpTransport.TLSClientConfig.InsecureSkipVerify = !secure
httpTransport.MaxConnsPerHost = 16
httpTransport.MaxIdleConnsPerHost = 16
httpTransport.IdleConnTimeout = 30 * time.Second
tokens.list = make(map[string]*Token)
zfssaCertLocation = certLocation
err := resetHttpTlsClient(nil)
if err != nil {
return err
}
zServicesURL = fmt.Sprintf(zServices, name)
zName = name
return nil
}
func resetHttpTlsClient(ctx context.Context) error {
if httpTransport.TLSClientConfig.InsecureSkipVerify {
utils.GetLogREST(ctx, 2).Println("resetHttpTransport skipped")
return nil
}
// set TLSv1.2 for the minimum version of supporting TLS
httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
// Get the SystemCertPool, continue with an empty pool on error
utils.GetLogREST(ctx, 2).Println("loading RootCAs")
httpTransport.TLSClientConfig.RootCAs, _ = x509.SystemCertPool()
if httpTransport.TLSClientConfig.RootCAs == nil {
httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
}
certs, err := ioutil.ReadFile(zfssaCertLocation)
if err != nil {
return errors.New("failed to read ZFSSA certificate")
}
if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(certs); !ok {
return errors.New("failed to append the certificate")
}
tokens.list = make(map[string]*Token)
utils.GetLogREST(ctx, 5).Println("resetHttpTransport done")
return nil
}
// Looks up a token context based on the user name passed in. If one doesn't exist
// yet, it is created.
func LookUpToken(ctx context.Context, user, password string) *Token {
@ -160,14 +181,14 @@ func LookUpToken(ctx context.Context, user, password string) *Token {
//
// The possible return values are:
//
// Code Message X-Auth-Session
// Code Message X-Auth-Session
//
// nil Valid
// codes.Internal "Failure getting token" ""
// codes.Internal "Failure creating token" ""
// nil Valid
// codes.Internal "Failure getting token" ""
// codes.Internal "Failure creating token" ""
//
// In case of failure, the message logged will provide more information
// as to where the problem occurred.
// In case of failure, the message logged will provide more information
// as to where the problem occurred.
func getToken(ctx context.Context, token *Token, previous *string) (string, error) {
token.mtx.Lock()
@ -223,7 +244,7 @@ func createZfssaSession(ctx context.Context, token *Token) (string, string, erro
httpReq, err := http.NewRequest("POST", zServicesURL, bytes.NewBuffer(nil))
if err != nil {
utils.GetLogREST(ctx,2).Println("Could not build a request to create a token",
utils.GetLogREST(ctx, 2).Println("Could not build a request to create a token",
"method", "POST", "url", zServicesURL, "error", err.Error())
return "", "", grpcStatus.Error(codes.Internal, "Failure creating token")
}
@ -233,15 +254,18 @@ func createZfssaSession(ctx context.Context, token *Token) (string, string, erro
httpRsp, err := httpClient.Do(httpReq)
if err != nil {
utils.GetLogREST(ctx,2).Println("Token creation failed in Do",
utils.GetLogREST(ctx, 2).Println("Token creation failed in Do",
"url", zServicesURL, "error", err.Error())
if strings.Contains(err.Error(), "failed to verify certificate") {
resetHttpTlsClient(ctx)
}
return "", "", grpcStatus.Error(codes.Internal, "Failure creating token")
}
defer httpRsp.Body.Close()
if httpRsp.StatusCode != http.StatusCreated {
utils.GetLogREST(ctx,2).Println("Token creation failed in ZFSSA",
utils.GetLogREST(ctx, 2).Println("Token creation failed in ZFSSA",
"url", zServicesURL, "StatusCode", httpRsp.StatusCode)
return "", "", grpcStatus.Error(codes.Internal, "Failure creating token")
}
@ -264,7 +288,7 @@ func MakeRequest(ctx context.Context, token *Token, method, url string, reqbody
func makeRequest(ctx context.Context, token *Token, method, url string, reqbody interface{}, status int,
rspbody interface{}) (interface{}, int, error) {
utils.GetLogREST(ctx,5).Println("MakeRequest to ZFSSA",
utils.GetLogREST(ctx, 5).Println("MakeRequest to ZFSSA",
"method", method, "url", url, "body", reqbody)
xAuthSession, err := getToken(ctx, token, nil)
@ -274,14 +298,14 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
reqjson, err := json.Marshal(reqbody)
if err != nil {
utils.GetLogREST(ctx,2).Println("json.Marshal call failed",
utils.GetLogREST(ctx, 2).Println("json.Marshal call failed",
"method", method, "url", url, "body", reqbody, "error", err.Error())
return nil, 0, grpcStatus.Error(codes.Unknown, "json.Marshal call failed")
}
reqhttp, err := http.NewRequest(method, url, bytes.NewBuffer(reqjson))
if err != nil {
utils.GetLogREST(ctx,2).Println("http.NewRequest call failed",
utils.GetLogREST(ctx, 2).Println("http.NewRequest call failed",
"method", method, "url", url, "body", reqbody, "error", err.Error())
return nil, 0, grpcStatus.Error(codes.Unknown, "http.NewRequest call failed")
}
@ -292,8 +316,15 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
rsphttp, err := httpClient.Do(reqhttp)
if err != nil {
utils.GetLogREST(ctx,2).Println("client.do call failed",
utils.GetLogREST(ctx, 2).Println("client.do call failed",
"method", method, "url", url, "error", err.Error())
if strings.Contains(err.Error(), "failed to verify certificate") {
utils.GetLogREST(ctx, 2).Println("mark token as invalid")
token.state = zfssaTokenInvalid
resetHttpTlsClient(ctx)
return nil, http.StatusUnauthorized, err
}
return nil, 0, grpcStatus.Error(codes.Unknown, "client.do call failed")
}
@ -306,23 +337,23 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
// read json http response
rspjson, err := ioutil.ReadAll(rsphttp.Body)
if err != nil {
utils.GetLogREST(ctx,2).Println("ioutil.ReadAll call failed",
utils.GetLogREST(ctx, 2).Println("ioutil.ReadAll call failed",
"method", method, "url", url, "code", rsphttp.StatusCode,
"status", rsphttp.Status, "error", err.Error())
return nil, rsphttp.StatusCode, grpcStatus.Error(codes.Unknown,"ioutil.ReadAll call failed")
return nil, rsphttp.StatusCode, grpcStatus.Error(codes.Unknown, "ioutil.ReadAll call failed")
}
if rsphttp.StatusCode == status {
if rspbody != nil {
err = json.Unmarshal(rspjson, rspbody)
if err != nil {
utils.GetLogREST(ctx,2).Println("json.Unmarshal call failed",
utils.GetLogREST(ctx, 2).Println("json.Unmarshal call failed",
"\nmethod", method, "\nurl", url, "\ncode", rsphttp.StatusCode,
"\nstatus", rsphttp.Status, "\nbody", rspjson, "\nerror", err)
return nil, rsphttp.StatusCode, grpcStatus.Error(codes.Unknown, "json.Unmarshal call failed")
}
}
utils.GetLogREST(ctx,5).Println("Successful response from ZFSSA",
utils.GetLogREST(ctx, 5).Println("Successful response from ZFSSA",
"method", method, "url", url, "result", rsphttp.StatusCode)
return rspbody, rsphttp.StatusCode, nil
}
@ -335,7 +366,7 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
}
// status code was not what the user expected, attempt to unpack
utils.GetLogREST(ctx,2).Println("MakeRequest to ZFSSA resulted in an unexpected status",
utils.GetLogREST(ctx, 2).Println("MakeRequest to ZFSSA resulted in an unexpected status",
"method", method, "url", url, "expected", status, "code", rsphttp.StatusCode,
"status", rsphttp.Status)
@ -343,7 +374,7 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
err = json.Unmarshal(rspjson, failure)
var responseString string
if err != nil {
utils.GetLogREST(ctx,2).Println("Failure from ZFSSA could not be un-marshalled",
utils.GetLogREST(ctx, 2).Println("Failure from ZFSSA could not be un-marshalled",
"method", method, "url", url, "code", rsphttp.StatusCode,
"status", rsphttp.Status, "body", rspjson)
responseString = string(rspjson)
@ -353,7 +384,7 @@ func makeRequest(ctx context.Context, token *Token, method, url string, reqbody
switch rsphttp.StatusCode {
case http.StatusNotFound:
err = grpcStatus.Errorf(codes.NotFound, "Resource not found on target appliance: %s", responseString)
err = grpcStatus.Errorf(codes.NotFound, "Resource not found on target appliance: %s", responseString)
default:
err = grpcStatus.Errorf(codes.Unknown, "Unknown Error Occurred on target appliance: %s", responseString)
}
@ -366,9 +397,9 @@ type services struct {
}
type Service struct {
Version string `json:"version"`
Name string `json:"name"`
URI string `json:"uri"`
Version string `json:"version"`
Name string `json:"name"`
URI string `json:"uri"`
}
func GetServices(ctx context.Context, token *Token) (*[]Service, error) {
@ -386,22 +417,20 @@ func GetServices(ctx context.Context, token *Token) (*[]Service, error) {
// Unmarshalling of a "List" structure. This structure is the ZFSSA response to
// the http request:
//
// GET /api/access/v1 HTTP/1.1
// Host: zfs-storage.example.com
// X-Auth-User: admin
// X-Auth-Key: password
//
// GET /api/access/v1 HTTP/1.1
// Host: zfs-storage.example.com
// X-Auth-User: admin
// X-Auth-Key: password
func (l *services) UnmarshalJSON(b []byte) error {
return zfssaUnmarshalList(b, &l.List)
}
// Unmarshalling of a List sent by the ZFSSA
//
func zfssaUnmarshalList(b []byte, l interface{}) error {
// 'b' starts and ends like this:
// {List:[{...},...,{...}]}
b = b[0:len(b) - 1]
b = b[0 : len(b)-1]
for i := 1; i < len(b); i++ {
if b[i] == '[' {
b = b[i:]