zfssa-csi-driver/pkg/service/service.go
2024-07-02 09:11:51 -06:00

400 lines
11 KiB
Go

/*
* Copyright (c) 2021, 2023, Oracle and/or its affiliates.
* Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
*/
package service
import (
"errors"
"fmt"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/oracle/zfssa-csi-driver/pkg/utils"
"github.com/oracle/zfssa-csi-driver/pkg/zfssarest"
"golang.org/x/net/context"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
"io/ioutil"
"net"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
// Defaults and well-known paths used by getConfig.
const (
	// DefaultLogLevel is used when LOG_LEVEL is not set.
	DefaultLogLevel = "3"
	// DefaultCertPath is used when ZFSSA_CERT is not set.
	DefaultCertPath = "/mnt/certs/zfssa.crt"
	// DefaultCredPath is used when ZFSSA_CRED is not set.
	DefaultCredPath = "/mnt/zfssa/zfssa.yaml"
	// DefaultConfigPath is the mounted config file that getConfig consults
	// for a fallback ZFSSA_TARGET value.
	DefaultConfigPath = "/mnt/config/config.yaml"
)
// ZFSSADriver holds the state of the ZFSSA CSI driver: its identity, its
// configuration, the volume and snapshot caches, and the three CSI servers.
type ZFSSADriver struct {
	// name is the driver name passed to NewZFSSADriver.
	name string
	// NOTE(review): nodeID is not assigned in this file — confirm where it is set.
	nodeID string
	// version is the driver version passed to NewZFSSADriver.
	version string
	// NOTE(review): endpoint is not assigned in this file; getConfig stores
	// the socket path in config.endpoint, which is what Run uses.
	endpoint string
	// config is filled in and validated by getConfig.
	config config
	// NodeMounter is presumably the mount helper used by the node server — confirm.
	NodeMounter Mounter
	// vCache and sCache hold the volume and snapshot tables; their maps are
	// created by NewZFSSADriver, and Run calls updateVolumeList and
	// updateSnapshotList at startup.
	vCache volumeHashTable
	sCache snapshotHashTable
	// ns, cs and is are the Node, Controller and Identity servers created by
	// NewZFSSADriver and handed to the gRPC server by Run.
	ns *csi.NodeServer
	cs *csi.ControllerServer
	is *csi.IdentityServer
}
// config holds the driver settings that getConfig gathers from environment
// variables, the mounted config file and the credentials file.
type config struct {
	// Appliance is the ZFSSA name or IP address (ZFSSA_TARGET, falling back
	// to the ZFSSA_TARGET key of the mounted config file).
	Appliance string
	// User is the username read from the credentials file.
	User string
	// endpoint is the UNIX socket path: CSI_ENDPOINT with its "unix://"
	// prefix replaced by "/".
	endpoint string
	// HostIp is HOST_IP (default "0.0.0.0").
	HostIp string
	// NodeName is NODE_NAME (required).
	NodeName string
	// PodIp is POD_IP (default "0.0.0.0").
	PodIp string
	// Secure is false only when ZFSSA_INSECURE is "true" (case-insensitive).
	Secure bool
	// logLevel is LOG_LEVEL (default DefaultLogLevel); it must parse as an integer.
	logLevel string
	// Certificate holds the contents of the certificate file; only read when Secure.
	Certificate []byte
	// CertLocation is the certificate path (ZFSSA_CERT); only set when Secure.
	CertLocation string
	// CredLocation is the credentials file path (ZFSSA_CRED).
	CredLocation string
}
// ZfssaCredentials is the structured data in the ZFSSA credentials file
// (the YAML file located by ZFSSA_CRED).
//
// The struct tags use the conventional quoted form; the previous unquoted
// form (`yaml:username`) was not a valid struct tag and was silently ignored.
// Decoding is unchanged because the keys match the lower-cased field names.
type ZfssaCredentials struct {
	Username string `yaml:"username"`
	Password string `yaml:"password"`
}
// accessType identifies how a volume is accessed; the defined values are
// mountAccess and blockAccess.
type accessType int
// nonBlockingGRPCServer wraps the gRPC server used by Run: Start serves it
// in a background goroutine, Wait blocks until a termination signal arrives,
// and Stop or ForceStop shut it down.
type nonBlockingGRPCServer struct {
	// wg is incremented by Start and decremented by Stop/ForceStop; nothing
	// in this file waits on it.
	wg     sync.WaitGroup
	server *grpc.Server
}
// Helpful size constants, in bytes.
const (
	Kib                    int64 = 1024
	Mib                    int64 = Kib * 1024
	Gib                    int64 = Mib * 1024
	Gib100                 int64 = Gib * 100
	Tib                    int64 = Gib * 1024
	Tib100                 int64 = Tib * 100
	DefaultVolumeSizeBytes int64 = 50 * Gib
)

// Volume access types.
//
// These were previously declared with iota at positions 7 and 8 of the size
// constant block, which gave them the values 7 and 8 (not 0 and 1). The
// values are now written out explicitly so that they stay stable. accessType
// appears in the JSON-tagged ZfssaBlockVolume.VolAccessType field. With the
// old layout, adding a constant in front of these would have silently
// renumbered them.
const (
	mountAccess accessType = 7
	blockAccess accessType = 8
)
// UsernamePattern is the regular expression a ZFSSA username must match: a
// letter followed by any number of letters, digits, '_', '-' or '.'.
const UsernamePattern string = `^[a-zA-Z][a-zA-Z0-9_\-\.]*$`

// UsernameLength is the maximum accepted username length, in bytes.
const UsernameLength int = 255
// ZfssaBlockVolume describes a volume; its fields carry JSON tags so it can
// be marshaled to and from JSON.
// NOTE(review): the per-field meanings below are inferred from the names — confirm.
type ZfssaBlockVolume struct {
	VolName       string     `json:"volName"`       // volume name
	VolID         string     `json:"volID"`         // volume identifier
	VolSize       int64      `json:"volSize"`       // size, presumably in bytes — confirm
	VolPath       string     `json:"volPath"`       // volume path
	VolAccessType accessType `json:"volAccessType"` // presumably mountAccess or blockAccess
}
// NewZFSSADriver builds a ZFSSA CSI driver named driverName at the given
// version.
//
// It loads and validates the configuration (see getConfig) and creates the
// volume and snapshot maps. It then initializes logging, the zfssarest
// package and the cluster interface, and finally creates the Identity,
// Controller and Node servers. Any initialization error is returned
// unchanged together with a nil driver.
func NewZFSSADriver(driverName, version string) (*ZFSSADriver, error) {
	zd := &ZFSSADriver{name: driverName, version: version}
	if err := getConfig(zd); err != nil {
		return nil, err
	}

	zd.vCache.vHash = make(map[string]zVolumeInterface)
	zd.sCache.sHash = make(map[string]*zSnapshot)

	cfg := &zd.config
	utils.InitLogs(cfg.logLevel, zd.name, version, cfg.NodeName)
	if err := zfssarest.InitREST(cfg.Appliance, cfg.CertLocation, cfg.Secure); err != nil {
		return nil, err
	}
	if err := InitClusterInterface(); err != nil {
		return nil, err
	}

	zd.is = newZFSSAIdentityServer(zd)
	zd.cs = newZFSSAControllerServer(zd)
	zd.ns = NewZFSSANodeServer(zd)
	return zd, nil
}
// getConfig reads the driver configuration into zd.config and sanity checks
// it. The appliance name also has a fallback in the mounted config file.
// The following environment variables are consulted:
//
//	ZFSSA_CRED      Path to the credentials file (default DefaultCredPath).
//	ZFSSA_TARGET    Name or IP address of the appliance. When unset, the
//	                ZFSSA_TARGET key of DefaultConfigPath is used.
//	NODE_NAME       Name of the node on which the container runs (required).
//	CSI_ENDPOINT    Unix socket the driver listens on; must start with
//	                "unix://" (required).
//	ZFSSA_INSECURE  "true" or "false", case-insensitive (default "false").
//	                "true" means no appliance certificate is required.
//	ZFSSA_CERT      Path to the certificate file (default DefaultCertPath);
//	                only read when ZFSSA_INSECURE is false.
//	HOST_IP         IP address of the node (default "0.0.0.0").
//	POD_IP          IP address of the pod (default "0.0.0.0").
//	LOG_LEVEL       Integer log level (default DefaultLogLevel).
//
// The credentials file must exist and contain a valid username. The
// credentials are not checked against the appliance. Any existing file at
// the endpoint path is removed.
func getConfig(zd *ZFSSADriver) error {
	// Validate the ZFSSA credentials are available.
	credfile := strings.TrimSpace(getEnvFallback("ZFSSA_CRED", DefaultCredPath))
	if len(credfile) == 0 {
		return fmt.Errorf("a ZFSSA credentials file location is required, current value: <%s>",
			credfile)
	}
	zd.config.CredLocation = credfile
	if _, err := os.Stat(credfile); os.IsNotExist(err) {
		return fmt.Errorf("the ZFSSA credentials file is not present at location: <%s>",
			credfile)
	}

	// Only the username is kept in the config; the password is deliberately
	// not stored there.
	var err error
	zd.config.User, err = zd.GetUsernameFromCred()
	if err != nil {
		return fmt.Errorf("Cannot get ZFSSA username: %w", err)
	}

	// ZFSSA_TARGET from the environment takes precedence over the mounted
	// config file; errors reading the config file are ignored here.
	zfssaHost, _ := utils.GetValueFromYAML(DefaultConfigPath, "ZFSSA_TARGET")
	zd.config.Appliance = strings.TrimSpace(getEnvFallback("ZFSSA_TARGET", zfssaHost))
	// NOTE(review): only the "not-set" placeholder is rejected; an empty
	// appliance name is accepted here — confirm whether that is intended.
	if zd.config.Appliance == "not-set" {
		return errors.New("appliance name required")
	}

	zd.config.NodeName = getEnvFallback("NODE_NAME", "")
	if zd.config.NodeName == "" {
		return errors.New("node name required")
	}

	endpoint := getEnvFallback("CSI_ENDPOINT", "")
	if endpoint == "" {
		return errors.New("endpoint is required")
	}
	if !strings.HasPrefix(endpoint, "unix://") {
		return errors.New("endpoint is invalid")
	}
	// "unix:///csi/csi.sock" becomes "//csi/csi.sock" (same absolute path);
	// "unix://csi.sock" becomes "/csi.sock".
	zd.config.endpoint = "/" + strings.TrimPrefix(endpoint, "unix://")
	// Remove a stale socket file so the server can listen on this path.
	if err := os.RemoveAll(zd.config.endpoint); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("failed to remove endpoint path %q: %w", zd.config.endpoint, err)
	}

	switch strings.ToLower(strings.TrimSpace(getEnvFallback("ZFSSA_INSECURE", "False"))) {
	case "true":
		zd.config.Secure = false
	case "false":
		zd.config.Secure = true
	default:
		return errors.New("ZFSSA_INSECURE value is invalid")
	}

	if zd.config.Secure {
		certfile := strings.TrimSpace(getEnvFallback("ZFSSA_CERT", DefaultCertPath))
		if len(certfile) == 0 {
			return errors.New("a certificate is required")
		}
		if _, err := os.Stat(certfile); os.IsNotExist(err) {
			return fmt.Errorf("certificate does not exist: <%s>", certfile)
		}
		zd.config.CertLocation = certfile
		zd.config.Certificate, err = ioutil.ReadFile(certfile)
		if err != nil {
			return fmt.Errorf("failed to read certificate: %w", err)
		}
	}

	zd.config.HostIp = getEnvFallback("HOST_IP", "0.0.0.0")
	zd.config.PodIp = getEnvFallback("POD_IP", "0.0.0.0")
	zd.config.logLevel = getEnvFallback("LOG_LEVEL", DefaultLogLevel)
	if _, err := strconv.Atoi(zd.config.logLevel); err != nil {
		return fmt.Errorf("invalid debug level: <%s>", zd.config.logLevel)
	}
	return nil
}
// sigList is the set of signals Run subscribes to; receiving any of them
// makes Run stop the gRPC server gracefully and remove the socket file.
var sigList = []os.Signal{syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT}
// Retrieves just the username from a credential file (zd.config.CredLocation)
func (zd *ZFSSADriver) GetUsernameFromCred() (string, error) {
yamlData, err := ioutil.ReadFile(zd.config.CredLocation)
if err != nil {
return "", errors.New(fmt.Sprintf("the ZFSSA credentials file <%s> could not be read: <%s>",
zd.config.CredLocation, err))
}
var yamlConfig ZfssaCredentials
err = yaml.Unmarshal(yamlData, &yamlConfig)
if err != nil {
return "", errors.New(fmt.Sprintf("the ZFSSA credentials file <%s> could not be parsed: <%s>",
zd.config.CredLocation, err))
}
if !isUsernameValid(yamlConfig.Username) {
return "", errors.New(fmt.Sprintf("ZFSSA username is invalid: <%s>", yamlConfig.Username))
}
return yamlConfig.Username, nil
}
// Retrieves just the username from a credential file
func (zd *ZFSSADriver) GetPasswordFromCred() (string, error) {
yamlData, err := ioutil.ReadFile(zd.config.CredLocation)
if err != nil {
return "", errors.New(fmt.Sprintf("the ZFSSA credentials file <%s> could not be read: <%s>",
zd.config.CredLocation, err))
}
var yamlConfig ZfssaCredentials
err = yaml.Unmarshal(yamlData, &yamlConfig)
if err != nil {
return "", errors.New(fmt.Sprintf("the ZFSSA credentials file <%s> could not be parsed: <%s>",
zd.config.CredLocation, err))
}
return yamlConfig.Password, nil
}
// Run starts the CSI driver and blocks until it is told to terminate.
//
// It first refreshes the volume and snapshot lists on a best-effort basis
// (errors are ignored). It then serves the Identity, Controller and Node
// gRPC services on the UNIX socket at zd.config.endpoint. When a signal from
// sigList arrives, the server is stopped gracefully and the socket file is
// removed.
func (zd *ZFSSADriver) Run() {
	// Best effort: a failed refresh does not prevent the driver from starting.
	_ = zd.updateVolumeList(nil)
	_ = zd.updateSnapshotList(nil)

	stopSignals := make(chan os.Signal, 1)
	signal.Notify(stopSignals, sigList...)

	grpcServer := &nonBlockingGRPCServer{}
	grpcServer.Start(zd.config.endpoint, *zd.is, *zd.cs, *zd.ns)
	grpcServer.Wait(stopSignals)
	grpcServer.Stop()

	_ = os.RemoveAll(zd.config.endpoint)
}
// Start creates the gRPC server and registers the Identity, Controller and
// Node services on it. It then serves requests on the UNIX socket at
// endpoint from a background goroutine, and returns without waiting for
// that goroutine.
//
// The server is built before the goroutine is launched, so s.server is
// always set by the time Start returns. Stop and ForceStop can then run
// after Start without racing the goroutine, and without dereferencing a nil
// server (for example when net.Listen fails or a signal arrives immediately).
func (s *nonBlockingGRPCServer) Start(endpoint string,
	ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
	s.server = newCSIGRPCServer(ids, cs, ns)
	s.wg.Add(1)
	go s.serve(endpoint, ids, cs, ns)
}

// Wait blocks until one of the termination signals (SIGTERM, SIGHUP, SIGINT,
// SIGQUIT) is received on ch, or until ch is closed. Other signals are
// logged and ignored.
func (s *nonBlockingGRPCServer) Wait(ch chan os.Signal) {
	for sig := range ch {
		switch sig {
		case syscall.SIGTERM,
			syscall.SIGHUP,
			syscall.SIGINT,
			syscall.SIGQUIT:
			utils.GetLogCSID(nil, 5).Println("Termination signal received", "signal", sig)
			return
		default:
			utils.GetLogCSID(nil, 5).Println("Signal received", "signal", sig)
			continue
		}
	}
}

// Stop gracefully stops the gRPC server. GracefulStop refuses new RPCs and
// waits for pending ones to finish. Must be called after Start.
func (s *nonBlockingGRPCServer) Stop() {
	s.server.GracefulStop()
	s.wg.Add(-1)
}

// ForceStop stops the gRPC server immediately, without waiting for pending
// RPCs. Must be called after Start.
func (s *nonBlockingGRPCServer) ForceStop() {
	s.server.Stop()
	s.wg.Add(-1)
}

// newCSIGRPCServer builds a gRPC server that uses interceptorGRPC and has
// the Identity, Controller and Node services registered.
func newCSIGRPCServer(ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) *grpc.Server {
	server := grpc.NewServer(grpc.UnaryInterceptor(interceptorGRPC))
	csi.RegisterIdentityServer(server, ids)
	csi.RegisterControllerServer(server, cs)
	csi.RegisterNodeServer(server, ns)
	return server
}

// serve listens on the UNIX socket at endpoint and runs s.server until it is
// stopped. Listen and Serve failures are logged rather than returned.
func (s *nonBlockingGRPCServer) serve(endpoint string,
	ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
	if s.server == nil {
		// Only reached when serve is invoked without Start.
		s.server = newCSIGRPCServer(ids, cs, ns)
	}
	listener, err := net.Listen("unix", endpoint)
	if err != nil {
		utils.GetLogCSID(nil, 2).Println("Failed to listen", "error", err)
		return
	}
	utils.GetLogCSID(nil, 5).Println("Listening for connections", "address", endpoint)
	if err := s.server.Serve(listener); err != nil {
		utils.GetLogCSID(nil, 2).Println("Serve returned with error", "error", err)
	}
}
// interceptorGRPC is a unary server interceptor. It gives each request its
// own context from utils.GetNewContext, then logs when the request is
// submitted and when it completes. The completion entry includes how long
// the handler took and the error it returned.
func interceptorGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	reqCtx := utils.GetNewContext(ctx)
	method := info.FullMethod

	utils.GetLogCSID(reqCtx, 4).Println("Request submitted", "method:", method)
	began := time.Now()
	rsp, err := handler(reqCtx, req)
	utils.GetLogCSID(reqCtx, 4).Println("Request completed", "method:", method,
		"duration:", time.Since(began), "error", err)

	return rsp, err
}
// getEnvFallback returns the value of the environment variable key, or
// fallback when the variable is not set. A variable that is set to the empty
// string yields "", not fallback.
func getEnvFallback(key, fallback string) string {
	value, found := os.LookupEnv(key)
	if !found {
		return fallback
	}
	return value
}
// validate username
func isUsernameValid(username string) bool {
if len(username) == 0 || len(username) > UsernameLength {
return false
}
var validUsername = regexp.MustCompile(UsernamePattern).MatchString
return validUsername(username)
}