Merge pull request #97069 from pohly/embedded-csi-mock-driver

e2e storage: embedded csi mock driver
Kubernetes Prow Robot 2021-03-03 12:11:50 -08:00 committed by GitHub
commit 076bd6c401
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 3542 additions and 205 deletions

go.mod

@ -52,6 +52,7 @@ require (
github.com/gogo/protobuf v1.3.2
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
github.com/golang/mock v1.4.4
github.com/golang/protobuf v1.4.3
github.com/google/cadvisor v0.38.8
github.com/google/go-cmp v0.5.2
github.com/google/gofuzz v1.1.0


@ -19,15 +19,16 @@ package storage
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"sync/atomic"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1 "k8s.io/api/core/v1"
storagev1 "k8s.io/api/storage/v1"
storagev1alpha1 "k8s.io/api/storage/v1alpha1"
@ -75,13 +76,6 @@ const (
// How long to wait for kubelet to unstage a volume after a pod is deleted
csiUnstageWaitTimeout = 1 * time.Minute
// Name of the CSI driver pod (it's in a StatefulSet with a stable name)
driverPodName = "csi-mockplugin-0"
// Name of the CSI driver container
driverContainerName = "mock"
// Prefix of the mock driver grpc log
grpcCallPrefix = "gRPCCall:"
)
// csiCall represents an expected call from Kubernetes to CSI mock driver and
@ -113,7 +107,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
// just disables resizing on the driver; it overrides the enableResizing flag for the CSI mock driver
disableResizingOnDriver bool
enableSnapshot bool
javascriptHooks map[string]string
hooks *drivers.Hooks
tokenRequests []storagev1.TokenRequest
requiresRepublish *bool
fsGroupPolicy *storagev1.FSGroupPolicy
@ -127,7 +121,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
pvcs []*v1.PersistentVolumeClaim
sc map[string]*storagev1.StorageClass
vsc map[string]*unstructured.Unstructured
driver storageframework.TestDriver
driver drivers.MockCSITestDriver
provisioner string
tp testParameters
}
@ -155,12 +149,29 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
EnableResizing: tp.enableResizing,
EnableNodeExpansion: tp.enableNodeExpansion,
EnableSnapshot: tp.enableSnapshot,
JavascriptHooks: tp.javascriptHooks,
TokenRequests: tp.tokenRequests,
RequiresRepublish: tp.requiresRepublish,
FSGroupPolicy: tp.fsGroupPolicy,
}
// At the moment, only tests which need hooks are
// using the embedded CSI mock driver. The rest run
// the driver inside the cluster, although they could
// be changed to use embedding merely by setting
// driverOpts.Embedded to true.
//
// Not enabling it for all tests minimizes
// the risk that the introduction of embedding breaks
// some existing tests and avoids a dependency
// on port forwarding, which is important if some of
// these tests are supposed to become part of
// conformance testing (port forwarding isn't
// currently required).
if tp.hooks != nil {
driverOpts.Embedded = true
driverOpts.Hooks = *tp.hooks
}
// this just disables resizing on the driver, keeping resizing on the SC enabled.
if tp.disableResizingOnDriver {
driverOpts.EnableResizing = false
@ -188,10 +199,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
createPod := func(ephemeral bool) (class *storagev1.StorageClass, claim *v1.PersistentVolumeClaim, pod *v1.Pod) {
ginkgo.By("Creating pod")
var sc *storagev1.StorageClass
if dDriver, ok := m.driver.(storageframework.DynamicPVTestDriver); ok {
sc = dDriver.GetDynamicProvisionStorageClass(m.config, "")
}
sc := m.driver.GetDynamicProvisionStorageClass(m.config, "")
scTest := testsuites.StorageClassTest{
Name: m.driver.GetDriverInfo().Name,
Timeouts: f.Timeouts,
@ -237,10 +245,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
createPodWithFSGroup := func(fsGroup *int64) (*storagev1.StorageClass, *v1.PersistentVolumeClaim, *v1.Pod) {
ginkgo.By("Creating pod with fsGroup")
nodeSelection := m.config.ClientNodeSelection
var sc *storagev1.StorageClass
if dDriver, ok := m.driver.(storageframework.DynamicPVTestDriver); ok {
sc = dDriver.GetDynamicProvisionStorageClass(m.config, "")
}
sc := m.driver.GetDynamicProvisionStorageClass(m.config, "")
scTest := testsuites.StorageClassTest{
Name: m.driver.GetDriverInfo().Name,
Provisioner: sc.Provisioner,
@ -514,7 +519,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
framework.ExpectNoError(err, "while deleting")
ginkgo.By("Checking CSI driver logs")
err = checkPodLogs(m.cs, m.config.DriverNamespace.Name, driverPodName, driverContainerName, pod, test.expectPodInfo, test.expectEphemeral, csiInlineVolumesEnabled, false, 1)
err = checkPodLogs(m.driver.GetCalls, pod, test.expectPodInfo, test.expectEphemeral, csiInlineVolumesEnabled, false, 1)
framework.ExpectNoError(err)
})
}
@ -727,19 +732,19 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
})
ginkgo.Context("CSI NodeStage error cases [Slow]", func() {
// Global variable in all scripts (called before each test)
globalScript := `counter=0; console.log("globals loaded", OK, INVALIDARGUMENT)`
trackedCalls := []string{
"NodeStageVolume",
"NodeUnstageVolume",
}
tests := []struct {
name string
expectPodRunning bool
expectedCalls []csiCall
nodeStageScript string
nodeUnstageScript string
name string
expectPodRunning bool
expectedCalls []csiCall
// Called for each NodeStageVolume call, with the counter incremented atomically before
// the invocation (i.e. the first value will be 1).
nodeStageHook func(counter int64) error
}{
{
// This is already tested elsewhere, adding simple good case here to test the test framework.
@ -749,7 +754,6 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
{expectedMethod: "NodeStageVolume", expectedError: codes.OK, deletePod: true},
{expectedMethod: "NodeUnstageVolume", expectedError: codes.OK},
},
nodeStageScript: `OK;`,
},
{
// Kubelet should repeat NodeStage as long as the pod exists
@ -762,7 +766,12 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
{expectedMethod: "NodeUnstageVolume", expectedError: codes.OK},
},
// Fail first 3 NodeStage requests, 4th succeeds
nodeStageScript: `console.log("Counter:", ++counter); if (counter < 4) { INVALIDARGUMENT; } else { OK; }`,
nodeStageHook: func(counter int64) error {
if counter < 4 {
return status.Error(codes.InvalidArgument, "fake error")
}
return nil
},
},
{
// Kubelet should repeat NodeStage as long as the pod exists
@ -775,7 +784,12 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
{expectedMethod: "NodeUnstageVolume", expectedError: codes.OK},
},
// Fail first 3 NodeStage requests, 4th succeeds
nodeStageScript: `console.log("Counter:", ++counter); if (counter < 4) { DEADLINEEXCEEDED; } else { OK; }`,
nodeStageHook: func(counter int64) error {
if counter < 4 {
return status.Error(codes.DeadlineExceeded, "fake error")
}
return nil
},
},
{
// After NodeUnstage with ephemeral error, the driver may continue staging the volume.
@ -789,7 +803,9 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
{expectedMethod: "NodeStageVolume", expectedError: codes.DeadlineExceeded, deletePod: true},
{expectedMethod: "NodeUnstageVolume", expectedError: codes.OK},
},
nodeStageScript: `DEADLINEEXCEEDED;`,
nodeStageHook: func(counter int64) error {
return status.Error(codes.DeadlineExceeded, "fake error")
},
},
{
// After NodeUnstage with final error, kubelet can be sure the volume is not staged.
@ -801,21 +817,23 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
// This matches all repeated NodeStage calls with InvalidArgument error (due to exp. backoff).
{expectedMethod: "NodeStageVolume", expectedError: codes.InvalidArgument, deletePod: true},
},
nodeStageScript: `INVALIDARGUMENT;`,
nodeStageHook: func(counter int64) error {
return status.Error(codes.InvalidArgument, "fake error")
},
},
}
for _, t := range tests {
test := t
ginkgo.It(test.name, func() {
scripts := map[string]string{
"globals": globalScript,
"nodeStageVolumeStart": test.nodeStageScript,
"nodeUnstageVolumeStart": test.nodeUnstageScript,
var hooks *drivers.Hooks
if test.nodeStageHook != nil {
hooks = createPreHook("NodeStageVolume", test.nodeStageHook)
}
init(testParameters{
disableAttach: true,
registerDriver: true,
javascriptHooks: scripts,
disableAttach: true,
registerDriver: true,
hooks: hooks,
})
defer cleanup()
@ -836,7 +854,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
framework.Failf("timed out waiting for the CSI call that indicates that the pod can be deleted: %v", test.expectedCalls)
}
time.Sleep(1 * time.Second)
_, index, err := compareCSICalls(trackedCalls, test.expectedCalls, m.cs, m.config.DriverNamespace.Name, driverPodName, driverContainerName)
_, index, err := compareCSICalls(trackedCalls, test.expectedCalls, m.driver.GetCalls)
framework.ExpectNoError(err, "while waiting for initial CSI calls")
if index == 0 {
// No CSI call received yet
@ -860,7 +878,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
ginkgo.By("Waiting for all remaining expected CSI calls")
err = wait.Poll(time.Second, csiUnstageWaitTimeout, func() (done bool, err error) {
_, index, err := compareCSICalls(trackedCalls, test.expectedCalls, m.cs, m.config.DriverNamespace.Name, driverPodName, driverContainerName)
_, index, err := compareCSICalls(trackedCalls, test.expectedCalls, m.driver.GetCalls)
if err != nil {
return true, err
}
@ -909,9 +927,9 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
createVolume := "CreateVolume"
deleteVolume := "DeleteVolume"
// publishVolume := "NodePublishVolume"
unpublishVolume := "NodeUnpublishVolume"
// unpublishVolume := "NodeUnpublishVolume"
// stageVolume := "NodeStageVolume"
unstageVolume := "NodeUnstageVolume"
// unstageVolume := "NodeUnstageVolume"
// These calls are assumed to occur in this order for
// each test run. NodeStageVolume and
@ -921,12 +939,17 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
// (https://github.com/kubernetes/kubernetes/issues/90250).
// Therefore they are temporarily commented out until
// that issue is resolved.
//
// NodeUnpublishVolume and NodeUnstageVolume are racing
// with DeleteVolume, so we cannot assume a deterministic
// order and have to ignore them
// (https://github.com/kubernetes/kubernetes/issues/94108).
deterministicCalls := []string{
createVolume,
// stageVolume,
// publishVolume,
unpublishVolume,
unstageVolume,
// unpublishVolume,
// unstageVolume,
deleteVolume,
}
@ -946,11 +969,12 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
}
if test.resourceExhausted {
params.javascriptHooks = map[string]string{
"globals": `counter=0; console.log("globals loaded", OK, INVALIDARGUMENT)`,
// Every second call returns RESOURCEEXHAUSTED, starting with the first one.
"createVolumeStart": `console.log("Counter:", ++counter); if (counter % 2) { RESOURCEEXHAUSTED; } else { OK; }`,
}
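// Fail every other call, starting with the first one (the counter is 1-based).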
params.hooks = createPreHook("CreateVolume", func(counter int64) error {
if counter%2 != 0 {
return status.Error(codes.ResourceExhausted, "fake error")
}
return nil
})
}
init(params)
@ -1006,9 +1030,9 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
expected = append(expected, normal...)
}
var calls []mockCSICall
var calls []drivers.MockCSICall
err = wait.PollImmediateUntil(time.Second, func() (done bool, err error) {
c, index, err := compareCSICalls(deterministicCalls, expected, m.cs, m.config.DriverNamespace.Name, driverPodName, driverContainerName)
c, index, err := compareCSICalls(deterministicCalls, expected, m.driver.GetCalls)
if err != nil {
return true, fmt.Errorf("error waiting for expected CSI calls: %s", err)
}
@ -1221,31 +1245,32 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
})
ginkgo.Context("CSI Volume Snapshots [Feature:VolumeSnapshotDataSource]", func() {
// Global variable in all scripts (called before each test)
globalScript := `counter=0; console.log("globals loaded", OK, DEADLINEEXCEEDED)`
tests := []struct {
name string
createVolumeScript string
createSnapshotScript string
name string
createSnapshotHook func(counter int64) error
}{
{
name: "volumesnapshotcontent and pvc in Bound state with deletion timestamp set should not get deleted while snapshot finalizer exists",
createVolumeScript: `OK`,
createSnapshotScript: `console.log("Counter:", ++counter); if (counter < 8) { DEADLINEEXCEEDED; } else { OK; }`,
name: "volumesnapshotcontent and pvc in Bound state with deletion timestamp set should not get deleted while snapshot finalizer exists",
createSnapshotHook: func(counter int64) error {
if counter < 8 {
return status.Error(codes.DeadlineExceeded, "fake error")
}
return nil
},
},
}
for _, test := range tests {
test := test
ginkgo.It(test.name, func() {
scripts := map[string]string{
"globals": globalScript,
"createVolumeStart": test.createVolumeScript,
"createSnapshotStart": test.createSnapshotScript,
var hooks *drivers.Hooks
if test.createSnapshotHook != nil {
hooks = createPreHook("CreateSnapshot", test.createSnapshotHook)
}
init(testParameters{
disableAttach: true,
registerDriver: true,
enableSnapshot: true,
javascriptHooks: scripts,
disableAttach: true,
registerDriver: true,
enableSnapshot: true,
hooks: hooks,
})
sDriver, ok := m.driver.(storageframework.SnapshottableTestDriver)
if !ok {
@ -1256,10 +1281,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
defer cancel()
defer cleanup()
var sc *storagev1.StorageClass
if dDriver, ok := m.driver.(storageframework.DynamicPVTestDriver); ok {
sc = dDriver.GetDynamicProvisionStorageClass(m.config, "")
}
sc := m.driver.GetDynamicProvisionStorageClass(m.config, "")
ginkgo.By("Creating storage class")
class, err := m.cs.StorageV1().StorageClasses().Create(context.TODO(), sc, metav1.CreateOptions{})
framework.ExpectNoError(err, "Failed to create class: %v", err)
@ -1402,7 +1424,7 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
framework.ExpectNoError(err, "while deleting")
ginkgo.By("Checking CSI driver logs")
err = checkPodLogs(m.cs, m.config.DriverNamespace.Name, driverPodName, driverContainerName, pod, false, false, false, test.deployCSIDriverObject && csiServiceAccountTokenEnabled, numNodePublishVolume)
err = checkPodLogs(m.driver.GetCalls, pod, false, false, false, test.deployCSIDriverObject && csiServiceAccountTokenEnabled, numNodePublishVolume)
framework.ExpectNoError(err)
})
}
@ -1507,33 +1529,30 @@ var _ = utils.SIGDescribe("CSI mock volume", func() {
annotations interface{}
)
// Global variable in all scripts (called before each test)
globalScript := `counter=0; console.log("globals loaded", OK, DEADLINEEXCEEDED)`
tests := []struct {
name string
createVolumeScript string
createSnapshotScript string
name string
createSnapshotHook func(counter int64) error
}{
{
// A volume snapshot should be created using secrets successfully, even if the first few attempts fail.
name: "volume snapshot create/delete with secrets",
createVolumeScript: `OK`,
name: "volume snapshot create/delete with secrets",
// Fail the first seven calls to create the snapshot and succeed on the eighth call.
createSnapshotScript: `console.log("Counter:", ++counter); if (counter < 8) { DEADLINEEXCEEDED; } else { OK; }`,
createSnapshotHook: func(counter int64) error {
if counter < 8 {
return status.Error(codes.DeadlineExceeded, "fake error")
}
return nil
},
},
}
for _, test := range tests {
ginkgo.It(test.name, func() {
scripts := map[string]string{
"globals": globalScript,
"createVolumeStart": test.createVolumeScript,
"createSnapshotStart": test.createSnapshotScript,
}
hooks := createPreHook("CreateSnapshot", test.createSnapshotHook)
init(testParameters{
disableAttach: true,
registerDriver: true,
enableSnapshot: true,
javascriptHooks: scripts,
disableAttach: true,
registerDriver: true,
enableSnapshot: true,
hooks: hooks,
})
sDriver, ok := m.driver.(storageframework.SnapshottableTestDriver)
@ -1895,24 +1914,9 @@ func startBusyBoxPodWithVolumeSource(cs clientset.Interface, volumeSource v1.Vol
return cs.CoreV1().Pods(ns).Create(context.TODO(), pod, metav1.CreateOptions{})
}
// Dummy structure that parses just volume_attributes and error code out of logged CSI call
type mockCSICall struct {
json string // full log entry
Method string
Request struct {
VolumeContext map[string]string `json:"volume_context"`
}
FullError struct {
Code codes.Code `json:"code"`
Message string `json:"message"`
}
Error string
}
// checkPodLogs tests that NodePublish was called with the expected volume_context and
// that (for ephemeral inline volumes) there is a matching NodeUnpublish.
func checkPodLogs(cs clientset.Interface, namespace, driverPodName, driverContainerName string, pod *v1.Pod, expectPodInfo, ephemeralVolume, csiInlineVolumesEnabled, csiServiceAccountTokenEnabled bool, expectedNumNodePublish int) error {
func checkPodLogs(getCalls func() ([]drivers.MockCSICall, error), pod *v1.Pod, expectPodInfo, ephemeralVolume, csiInlineVolumesEnabled, csiServiceAccountTokenEnabled bool, expectedNumNodePublish int) error {
expectedAttributes := map[string]string{}
if expectPodInfo {
expectedAttributes["csi.storage.k8s.io/pod.name"] = pod.Name
@ -1934,10 +1938,11 @@ func checkPodLogs(cs clientset.Interface, namespace, driverPodName, driverContai
foundAttributes := sets.NewString()
numNodePublishVolume := 0
numNodeUnpublishVolume := 0
calls, err := parseMockLogs(cs, namespace, driverPodName, driverContainerName)
calls, err := getCalls()
if err != nil {
return err
}
for _, call := range calls {
switch call.Method {
case "NodePublishVolume":
@ -1970,39 +1975,6 @@ func checkPodLogs(cs clientset.Interface, namespace, driverPodName, driverContai
return nil
}
func parseMockLogs(cs clientset.Interface, namespace, driverPodName, driverContainerName string) ([]mockCSICall, error) {
// Load logs of driver pod
log, err := e2epod.GetPodLogs(cs, namespace, driverPodName, driverContainerName)
if err != nil {
return nil, fmt.Errorf("could not load CSI driver logs: %s", err)
}
logLines := strings.Split(log, "\n")
var calls []mockCSICall
for _, line := range logLines {
index := strings.Index(line, grpcCallPrefix)
if index == -1 {
continue
}
line = line[index+len(grpcCallPrefix):]
call := mockCSICall{
json: string(line),
}
err := json.Unmarshal([]byte(line), &call)
if err != nil {
framework.Logf("Could not parse CSI driver log line %q: %s", line, err)
continue
}
// Trim gRPC service name, i.e. "/csi.v1.Identity/Probe" -> "Probe"
methodParts := strings.Split(call.Method, "/")
call.Method = methodParts[len(methodParts)-1]
calls = append(calls, call)
}
return calls, nil
}
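// With the embedded driver there is no need to scrape pod logs at all:
// GetCalls can serve recorded calls from memory. A minimal recorder that
// satisfies the func() ([]drivers.MockCSICall, error) contract might look
// like this (hypothetical sketch, not necessarily the drivers implementation):
//
//	type callRecorder struct {
//		mutex sync.Mutex
//		calls []drivers.MockCSICall
//	}
//
//	// record appends one call; it would be invoked from the gRPC interceptor.
//	func (r *callRecorder) record(call drivers.MockCSICall) {
//		r.mutex.Lock()
//		defer r.mutex.Unlock()
//		r.calls = append(r.calls, call)
//	}
//
//	// GetCalls returns a copy of all calls recorded so far.
//	func (r *callRecorder) GetCalls() ([]drivers.MockCSICall, error) {
//		r.mutex.Lock()
//		defer r.mutex.Unlock()
//		return append([]drivers.MockCSICall(nil), r.calls...), nil
//	}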
// compareCSICalls compares expectedCalls with logs of the mock driver.
// It returns index of the first expectedCall that was *not* received
// yet or error when calls do not match.
@ -2011,8 +1983,8 @@ func parseMockLogs(cs clientset.Interface, namespace, driverPodName, driverConta
//
// Only permanent errors are returned. Other errors are logged and no
// calls are returned. The caller is expected to retry.
func compareCSICalls(trackedCalls []string, expectedCallSequence []csiCall, cs clientset.Interface, namespace, driverPodName, driverContainerName string) ([]mockCSICall, int, error) {
allCalls, err := parseMockLogs(cs, namespace, driverPodName, driverContainerName)
func compareCSICalls(trackedCalls []string, expectedCallSequence []csiCall, getCalls func() ([]drivers.MockCSICall, error)) ([]drivers.MockCSICall, int, error) {
allCalls, err := getCalls()
if err != nil {
framework.Logf("intermittent (?) log retrieval error, proceeding without output: %v", err)
return nil, 0, nil
@ -2020,8 +1992,8 @@ func compareCSICalls(trackedCalls []string, expectedCallSequence []csiCall, cs c
// Remove all repeated and ignored calls
tracked := sets.NewString(trackedCalls...)
var calls []mockCSICall
var last mockCSICall
var calls []drivers.MockCSICall
var last drivers.MockCSICall
for _, c := range allCalls {
if !tracked.Has(c.Method) {
continue
@ -2145,3 +2117,20 @@ func checkDeleteSnapshotSecrets(cs clientset.Interface, annotations interface{})
return err
}
// createPreHook creates hooks which count invocations of a certain method (identified by
// a substring in the full gRPC method name) and pass the 1-based count to the callback.
func createPreHook(method string, callback func(counter int64) error) *drivers.Hooks {
var counter int64
return &drivers.Hooks{
Pre: func(ctx context.Context, fullMethod string, request interface{}) (reply interface{}, err error) {
if strings.Contains(fullMethod, method) {
return nil, callback(atomic.AddInt64(&counter, 1))
}
return nil, nil
},
}
}
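// Example (illustrative only; mirrors how the tests above use it): fail the
// first two CreateVolume calls with ResourceExhausted, then succeed.
//
//	hooks := createPreHook("CreateVolume", func(counter int64) error {
//		if counter <= 2 {
//			return status.Error(codes.ResourceExhausted, "fake error")
//		}
//		return nil
//	})
//	init(testParameters{registerDriver: true, hooks: hooks})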


@ -0,0 +1,318 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
//go:generate mockgen -package=driver -destination=driver.mock.go github.com/container-storage-interface/spec/lib/go/csi IdentityServer,ControllerServer,NodeServer
package driver
import (
"context"
"encoding/json"
"errors"
"net"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/klog/v2"
"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
)
var (
// ErrNoCredentials is the error when a secret is enabled but not passed in the request.
ErrNoCredentials = errors.New("secret must be provided")
// ErrAuthFailed is the error when the secret is incorrect.
ErrAuthFailed = errors.New("authentication failed")
)
// CSIDriverServers is a unified driver component with both Controller and Node
// services.
type CSIDriverServers struct {
Controller csi.ControllerServer
Identity csi.IdentityServer
Node csi.NodeServer
}
// This is the key name in all the CSI secret objects.
const secretField = "secretKey"
// CSICreds is a driver specific secret type. Drivers can have a key-val pair of
// secrets. This mock driver has a single string secret with secretField as the
// key.
type CSICreds struct {
CreateVolumeSecret string
DeleteVolumeSecret string
ControllerPublishVolumeSecret string
ControllerUnpublishVolumeSecret string
NodeStageVolumeSecret string
NodePublishVolumeSecret string
CreateSnapshotSecret string
DeleteSnapshotSecret string
ControllerValidateVolumeCapabilitiesSecret string
}
type CSIDriver struct {
listener net.Listener
server *grpc.Server
servers *CSIDriverServers
wg sync.WaitGroup
running bool
lock sync.Mutex
creds *CSICreds
logGRPC LogGRPC
}
type LogGRPC func(method string, request, reply interface{}, err error)
func NewCSIDriver(servers *CSIDriverServers) *CSIDriver {
return &CSIDriver{
servers: servers,
}
}
func (c *CSIDriver) goServe(started chan<- bool) {
goServe(c.server, &c.wg, c.listener, started)
}
func (c *CSIDriver) Address() string {
return c.listener.Addr().String()
}
// Start runs a gRPC server with all enabled services. If an interceptor
// is given, then it will be used. Otherwise, an interceptor which
// handles simple credential checks and logs gRPC calls in JSON format
// will be used.
func (c *CSIDriver) Start(l net.Listener, interceptor grpc.UnaryServerInterceptor) error {
c.lock.Lock()
defer c.lock.Unlock()
// Set listener
c.listener = l
// Create a new grpc server
if interceptor == nil {
interceptor = c.callInterceptor
}
c.server = grpc.NewServer(grpc.UnaryInterceptor(interceptor))
// Register Mock servers
if c.servers.Controller != nil {
csi.RegisterControllerServer(c.server, c.servers.Controller)
}
if c.servers.Identity != nil {
csi.RegisterIdentityServer(c.server, c.servers.Identity)
}
if c.servers.Node != nil {
csi.RegisterNodeServer(c.server, c.servers.Node)
}
// Start listening for requests
waitForServer := make(chan bool)
c.goServe(waitForServer)
<-waitForServer
c.running = true
return nil
}
func (c *CSIDriver) Stop() {
stop(&c.lock, &c.wg, c.server, c.running)
}
func (c *CSIDriver) Close() {
c.server.Stop()
}
func (c *CSIDriver) IsRunning() bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.running
}
// SetDefaultCreds sets the default secrets for CSI creds.
func (c *CSIDriver) SetDefaultCreds() {
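// Note: setDefaultCreds writes through the pointer, so this assumes
// c.creds was initialized to a non-nil *CSICreds beforehand.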
setDefaultCreds(c.creds)
}
// goServe starts a grpc server.
func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) {
wg.Add(1)
go func() {
defer wg.Done()
started <- true
err := server.Serve(listener)
if err != nil {
panic(err.Error())
}
}()
}
// stop stops a grpc server.
func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) {
lock.Lock()
defer lock.Unlock()
if !running {
return
}
server.Stop()
wg.Wait()
}
// setDefaultCreds sets the default credentials, given a CSICreds instance.
func setDefaultCreds(creds *CSICreds) {
*creds = CSICreds{
CreateVolumeSecret: "secretval1",
DeleteVolumeSecret: "secretval2",
ControllerPublishVolumeSecret: "secretval3",
ControllerUnpublishVolumeSecret: "secretval4",
NodeStageVolumeSecret: "secretval5",
NodePublishVolumeSecret: "secretval6",
CreateSnapshotSecret: "secretval7",
DeleteSnapshotSecret: "secretval8",
ControllerValidateVolumeCapabilitiesSecret: "secretval9",
}
}
func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
err := authInterceptor(c.creds, req)
if err != nil {
logGRPC(info.FullMethod, req, nil, err)
return nil, err
}
rsp, err := handler(ctx, req)
logGRPC(info.FullMethod, req, rsp, err)
if c.logGRPC != nil {
c.logGRPC(info.FullMethod, req, rsp, err)
}
return rsp, err
}
func authInterceptor(creds *CSICreds, req interface{}) error {
if creds != nil {
authenticated, authErr := isAuthenticated(req, creds)
if !authenticated {
if authErr == ErrNoCredentials {
return status.Error(codes.InvalidArgument, authErr.Error())
}
if authErr == ErrAuthFailed {
return status.Error(codes.Unauthenticated, authErr.Error())
}
}
}
return nil
}
func logGRPC(method string, request, reply interface{}, err error) {
// Log JSON with the request and response for easier parsing
logMessage := struct {
Method string
Request interface{}
Response interface{}
// Error as string, for backward compatibility.
// "" on no error.
Error string
// Full error dump, to be able to parse out full gRPC error code and message separately in a test.
FullError error
}{
Method: method,
Request: request,
Response: reply,
FullError: err,
}
if err != nil {
logMessage.Error = err.Error()
}
msg, _ := json.Marshal(logMessage)
klog.V(3).Infof("gRPCCall: %s\n", msg)
}
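// For illustration, a Probe call would be logged roughly as the following
// single line (fields depend on the actual request/response):
//
//	gRPCCall: {"Method":"/csi.v1.Identity/Probe","Request":{},"Response":{},"Error":"","FullError":null}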
func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) {
switch r := req.(type) {
case *csi.CreateVolumeRequest:
return authenticateCreateVolume(r, creds)
case *csi.DeleteVolumeRequest:
return authenticateDeleteVolume(r, creds)
case *csi.ControllerPublishVolumeRequest:
return authenticateControllerPublishVolume(r, creds)
case *csi.ControllerUnpublishVolumeRequest:
return authenticateControllerUnpublishVolume(r, creds)
case *csi.NodeStageVolumeRequest:
return authenticateNodeStageVolume(r, creds)
case *csi.NodePublishVolumeRequest:
return authenticateNodePublishVolume(r, creds)
case *csi.CreateSnapshotRequest:
return authenticateCreateSnapshot(r, creds)
case *csi.DeleteSnapshotRequest:
return authenticateDeleteSnapshot(r, creds)
case *csi.ValidateVolumeCapabilitiesRequest:
return authenticateControllerValidateVolumeCapabilities(r, creds)
default:
return true, nil
}
}
func authenticateCreateVolume(req *csi.CreateVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.CreateVolumeSecret)
}
func authenticateDeleteVolume(req *csi.DeleteVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.DeleteVolumeSecret)
}
func authenticateControllerPublishVolume(req *csi.ControllerPublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerPublishVolumeSecret)
}
func authenticateControllerUnpublishVolume(req *csi.ControllerUnpublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerUnpublishVolumeSecret)
}
func authenticateNodeStageVolume(req *csi.NodeStageVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.NodeStageVolumeSecret)
}
func authenticateNodePublishVolume(req *csi.NodePublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.NodePublishVolumeSecret)
}
func authenticateCreateSnapshot(req *csi.CreateSnapshotRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.CreateSnapshotSecret)
}
func authenticateDeleteSnapshot(req *csi.DeleteSnapshotRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.DeleteSnapshotSecret)
}
func authenticateControllerValidateVolumeCapabilities(req *csi.ValidateVolumeCapabilitiesRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerValidateVolumeCapabilitiesSecret)
}
func credsCheck(secrets map[string]string, secretVal string) (bool, error) {
if len(secrets) == 0 {
return false, ErrNoCredentials
}
if secrets[secretField] != secretVal {
return false, ErrAuthFailed
}
return true, nil
}
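// Example (illustrative): with the defaults from setDefaultCreds in effect, a
// CreateVolume request authenticates only when it carries the expected value
// under the fixed secret key:
//
//	ok, err := credsCheck(map[string]string{"secretKey": "secretval1"}, "secretval1")
//	// ok == true, err == nil; an empty secrets map yields ErrNoCredentials,
//	// a wrong value yields ErrAuthFailed.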


@ -0,0 +1,407 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package driver is a generated GoMock package, with required copyright
// header added manually.
package driver
import (
context "context"
reflect "reflect"
csi "github.com/container-storage-interface/spec/lib/go/csi"
gomock "github.com/golang/mock/gomock"
)
// MockIdentityServer is a mock of IdentityServer interface
type MockIdentityServer struct {
ctrl *gomock.Controller
recorder *MockIdentityServerMockRecorder
}
// MockIdentityServerMockRecorder is the mock recorder for MockIdentityServer
type MockIdentityServerMockRecorder struct {
mock *MockIdentityServer
}
// NewMockIdentityServer creates a new mock instance
func NewMockIdentityServer(ctrl *gomock.Controller) *MockIdentityServer {
mock := &MockIdentityServer{ctrl: ctrl}
mock.recorder = &MockIdentityServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockIdentityServer) EXPECT() *MockIdentityServerMockRecorder {
return m.recorder
}
// GetPluginCapabilities mocks base method
func (m *MockIdentityServer) GetPluginCapabilities(arg0 context.Context, arg1 *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "GetPluginCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.GetPluginCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginCapabilities indicates an expected call of GetPluginCapabilities
func (mr *MockIdentityServerMockRecorder) GetPluginCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginCapabilities", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginCapabilities), arg0, arg1)
}
// GetPluginInfo mocks base method
func (m *MockIdentityServer) GetPluginInfo(arg0 context.Context, arg1 *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
ret := m.ctrl.Call(m, "GetPluginInfo", arg0, arg1)
ret0, _ := ret[0].(*csi.GetPluginInfoResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginInfo indicates an expected call of GetPluginInfo
func (mr *MockIdentityServerMockRecorder) GetPluginInfo(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInfo", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginInfo), arg0, arg1)
}
// Probe mocks base method
func (m *MockIdentityServer) Probe(arg0 context.Context, arg1 *csi.ProbeRequest) (*csi.ProbeResponse, error) {
ret := m.ctrl.Call(m, "Probe", arg0, arg1)
ret0, _ := ret[0].(*csi.ProbeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Probe indicates an expected call of Probe
func (mr *MockIdentityServerMockRecorder) Probe(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockIdentityServer)(nil).Probe), arg0, arg1)
}
// MockControllerServer is a mock of ControllerServer interface
type MockControllerServer struct {
ctrl *gomock.Controller
recorder *MockControllerServerMockRecorder
}
// MockControllerServerMockRecorder is the mock recorder for MockControllerServer
type MockControllerServerMockRecorder struct {
mock *MockControllerServer
}
// NewMockControllerServer creates a new mock instance
func NewMockControllerServer(ctrl *gomock.Controller) *MockControllerServer {
mock := &MockControllerServer{ctrl: ctrl}
mock.recorder = &MockControllerServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockControllerServer) EXPECT() *MockControllerServerMockRecorder {
return m.recorder
}
// ControllerExpandVolume mocks base method
func (m *MockControllerServer) ControllerExpandVolume(arg0 context.Context, arg1 *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerExpandVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerExpandVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerExpandVolume indicates an expected call of ControllerExpandVolume
func (mr *MockControllerServerMockRecorder) ControllerExpandVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerExpandVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerExpandVolume), arg0, arg1)
}
// ControllerGetCapabilities mocks base method
func (m *MockControllerServer) ControllerGetCapabilities(arg0 context.Context, arg1 *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "ControllerGetCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerGetCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerGetCapabilities indicates an expected call of ControllerGetCapabilities
func (mr *MockControllerServerMockRecorder) ControllerGetCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetCapabilities), arg0, arg1)
}
// ControllerPublishVolume mocks base method
func (m *MockControllerServer) ControllerPublishVolume(arg0 context.Context, arg1 *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerPublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerPublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerPublishVolume indicates an expected call of ControllerPublishVolume
func (mr *MockControllerServerMockRecorder) ControllerPublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerPublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerPublishVolume), arg0, arg1)
}
// ControllerUnpublishVolume mocks base method
func (m *MockControllerServer) ControllerUnpublishVolume(arg0 context.Context, arg1 *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerUnpublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerUnpublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerUnpublishVolume indicates an expected call of ControllerUnpublishVolume
func (mr *MockControllerServerMockRecorder) ControllerUnpublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerUnpublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerUnpublishVolume), arg0, arg1)
}
// CreateSnapshot mocks base method
func (m *MockControllerServer) CreateSnapshot(arg0 context.Context, arg1 *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
ret := m.ctrl.Call(m, "CreateSnapshot", arg0, arg1)
ret0, _ := ret[0].(*csi.CreateSnapshotResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSnapshot indicates an expected call of CreateSnapshot
func (mr *MockControllerServerMockRecorder) CreateSnapshot(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSnapshot", reflect.TypeOf((*MockControllerServer)(nil).CreateSnapshot), arg0, arg1)
}
// CreateVolume mocks base method
func (m *MockControllerServer) CreateVolume(arg0 context.Context, arg1 *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
ret := m.ctrl.Call(m, "CreateVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.CreateVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateVolume indicates an expected call of CreateVolume
func (mr *MockControllerServerMockRecorder) CreateVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVolume", reflect.TypeOf((*MockControllerServer)(nil).CreateVolume), arg0, arg1)
}
// DeleteSnapshot mocks base method
func (m *MockControllerServer) DeleteSnapshot(arg0 context.Context, arg1 *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
ret := m.ctrl.Call(m, "DeleteSnapshot", arg0, arg1)
ret0, _ := ret[0].(*csi.DeleteSnapshotResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteSnapshot indicates an expected call of DeleteSnapshot
func (mr *MockControllerServerMockRecorder) DeleteSnapshot(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSnapshot", reflect.TypeOf((*MockControllerServer)(nil).DeleteSnapshot), arg0, arg1)
}
// DeleteVolume mocks base method
func (m *MockControllerServer) DeleteVolume(arg0 context.Context, arg1 *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
ret := m.ctrl.Call(m, "DeleteVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.DeleteVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteVolume indicates an expected call of DeleteVolume
func (mr *MockControllerServerMockRecorder) DeleteVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolume", reflect.TypeOf((*MockControllerServer)(nil).DeleteVolume), arg0, arg1)
}
// GetCapacity mocks base method
func (m *MockControllerServer) GetCapacity(arg0 context.Context, arg1 *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
ret := m.ctrl.Call(m, "GetCapacity", arg0, arg1)
ret0, _ := ret[0].(*csi.GetCapacityResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCapacity indicates an expected call of GetCapacity
func (mr *MockControllerServerMockRecorder) GetCapacity(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCapacity", reflect.TypeOf((*MockControllerServer)(nil).GetCapacity), arg0, arg1)
}
// ListSnapshots mocks base method
func (m *MockControllerServer) ListSnapshots(arg0 context.Context, arg1 *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
ret := m.ctrl.Call(m, "ListSnapshots", arg0, arg1)
ret0, _ := ret[0].(*csi.ListSnapshotsResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListSnapshots indicates an expected call of ListSnapshots
func (mr *MockControllerServerMockRecorder) ListSnapshots(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockControllerServer)(nil).ListSnapshots), arg0, arg1)
}
// ListVolumes mocks base method
func (m *MockControllerServer) ListVolumes(arg0 context.Context, arg1 *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
ret := m.ctrl.Call(m, "ListVolumes", arg0, arg1)
ret0, _ := ret[0].(*csi.ListVolumesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerGetVolume mocks base method
func (m *MockControllerServer) ControllerGetVolume(arg0 context.Context, arg1 *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerGetVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerGetVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerGetVolume indicates an expected call of ControllerGetVolume
func (mr *MockControllerServerMockRecorder) ControllerGetVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetVolume), arg0, arg1)
}
// ListVolumes indicates an expected call of ListVolumes
func (mr *MockControllerServerMockRecorder) ListVolumes(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListVolumes", reflect.TypeOf((*MockControllerServer)(nil).ListVolumes), arg0, arg1)
}
// ValidateVolumeCapabilities mocks base method
func (m *MockControllerServer) ValidateVolumeCapabilities(arg0 context.Context, arg1 *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "ValidateVolumeCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.ValidateVolumeCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ValidateVolumeCapabilities indicates an expected call of ValidateVolumeCapabilities
func (mr *MockControllerServerMockRecorder) ValidateVolumeCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateVolumeCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ValidateVolumeCapabilities), arg0, arg1)
}
// MockNodeServer is a mock of NodeServer interface
type MockNodeServer struct {
ctrl *gomock.Controller
recorder *MockNodeServerMockRecorder
}
// MockNodeServerMockRecorder is the mock recorder for MockNodeServer
type MockNodeServerMockRecorder struct {
mock *MockNodeServer
}
// NewMockNodeServer creates a new mock instance
func NewMockNodeServer(ctrl *gomock.Controller) *MockNodeServer {
mock := &MockNodeServer{ctrl: ctrl}
mock.recorder = &MockNodeServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockNodeServer) EXPECT() *MockNodeServerMockRecorder {
return m.recorder
}
// NodeExpandVolume mocks base method
func (m *MockNodeServer) NodeExpandVolume(arg0 context.Context, arg1 *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeExpandVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeExpandVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeExpandVolume indicates an expected call of NodeExpandVolume
func (mr *MockNodeServerMockRecorder) NodeExpandVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeExpandVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeExpandVolume), arg0, arg1)
}
// NodeGetCapabilities mocks base method
func (m *MockNodeServer) NodeGetCapabilities(arg0 context.Context, arg1 *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "NodeGetCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetCapabilities indicates an expected call of NodeGetCapabilities
func (mr *MockNodeServerMockRecorder) NodeGetCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetCapabilities", reflect.TypeOf((*MockNodeServer)(nil).NodeGetCapabilities), arg0, arg1)
}
// NodeGetInfo mocks base method
func (m *MockNodeServer) NodeGetInfo(arg0 context.Context, arg1 *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) {
ret := m.ctrl.Call(m, "NodeGetInfo", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetInfoResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetInfo indicates an expected call of NodeGetInfo
func (mr *MockNodeServerMockRecorder) NodeGetInfo(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetInfo", reflect.TypeOf((*MockNodeServer)(nil).NodeGetInfo), arg0, arg1)
}
// NodeGetVolumeStats mocks base method
func (m *MockNodeServer) NodeGetVolumeStats(arg0 context.Context, arg1 *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) {
ret := m.ctrl.Call(m, "NodeGetVolumeStats", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetVolumeStatsResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetVolumeStats indicates an expected call of NodeGetVolumeStats
func (mr *MockNodeServerMockRecorder) NodeGetVolumeStats(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetVolumeStats", reflect.TypeOf((*MockNodeServer)(nil).NodeGetVolumeStats), arg0, arg1)
}
// NodePublishVolume mocks base method
func (m *MockNodeServer) NodePublishVolume(arg0 context.Context, arg1 *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodePublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodePublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodePublishVolume indicates an expected call of NodePublishVolume
func (mr *MockNodeServerMockRecorder) NodePublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodePublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodePublishVolume), arg0, arg1)
}
// NodeStageVolume mocks base method
func (m *MockNodeServer) NodeStageVolume(arg0 context.Context, arg1 *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeStageVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeStageVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeStageVolume indicates an expected call of NodeStageVolume
func (mr *MockNodeServerMockRecorder) NodeStageVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeStageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeStageVolume), arg0, arg1)
}
// NodeUnpublishVolume mocks base method
func (m *MockNodeServer) NodeUnpublishVolume(arg0 context.Context, arg1 *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeUnpublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeUnpublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeUnpublishVolume indicates an expected call of NodeUnpublishVolume
func (mr *MockNodeServerMockRecorder) NodeUnpublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnpublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnpublishVolume), arg0, arg1)
}
// NodeUnstageVolume mocks base method
func (m *MockNodeServer) NodeUnstageVolume(arg0 context.Context, arg1 *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeUnstageVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeUnstageVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeUnstageVolume indicates an expected call of NodeUnstageVolume
func (mr *MockNodeServerMockRecorder) NodeUnstageVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnstageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnstageVolume), arg0, arg1)
}


@ -0,0 +1,74 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"net"
"google.golang.org/grpc"
)
type MockCSIDriverServers struct {
Controller *MockControllerServer
Identity *MockIdentityServer
Node *MockNodeServer
}
type MockCSIDriver struct {
CSIDriver
conn *grpc.ClientConn
interceptor grpc.UnaryServerInterceptor
}
func NewMockCSIDriver(servers *MockCSIDriverServers, interceptor grpc.UnaryServerInterceptor) *MockCSIDriver {
return &MockCSIDriver{
CSIDriver: CSIDriver{
servers: &CSIDriverServers{
Controller: servers.Controller,
Node: servers.Node,
Identity: servers.Identity,
},
},
interceptor: interceptor,
}
}
// StartOnAddress starts a new gRPC server listening on given address.
func (m *MockCSIDriver) StartOnAddress(network, address string) error {
l, err := net.Listen(network, address)
if err != nil {
return err
}
if err := m.CSIDriver.Start(l, m.interceptor); err != nil {
l.Close()
return err
}
return nil
}
// Start starts a new gRPC server listening on a random TCP loopback port.
func (m *MockCSIDriver) Start() error {
// Listen on a port assigned by the net package
return m.StartOnAddress("tcp", "127.0.0.1:0")
}
func (m *MockCSIDriver) Close() {
m.conn.Close()
m.server.Stop()
}
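// Example usage (sketch; assumes a *testing.T for the gomock controller and
// omits expectation setup):
//
//	ctrl := gomock.NewController(t)
//	defer ctrl.Finish()
//	d := NewMockCSIDriver(&MockCSIDriverServers{
//		Identity:   NewMockIdentityServer(ctrl),
//		Controller: NewMockControllerServer(ctrl),
//		Node:       NewMockNodeServer(ctrl),
//	}, nil /* use the default interceptor */)
//	if err := d.Start(); err != nil {
//		t.Fatal(err)
//	}
//	defer d.Stop() // Stop rather than Close: conn is only set when a client connection exists
//	endpoint := d.Address() // e.g. "127.0.0.1:<port>", ready to be dialed by a CSI client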


@ -0,0 +1,105 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"strings"
"sync"
"github.com/container-storage-interface/spec/lib/go/csi"
)
type SnapshotCache interface {
Add(snapshot Snapshot)
Delete(i int)
List(ready bool) []csi.Snapshot
FindSnapshot(k, v string) (int, Snapshot)
}
type Snapshot struct {
Name string
Parameters map[string]string
SnapshotCSI csi.Snapshot
}
type snapshotCache struct {
snapshotsRWL sync.RWMutex
snapshots []Snapshot
}
func NewSnapshotCache() SnapshotCache {
return &snapshotCache{
snapshots: make([]Snapshot, 0),
}
}
func (snap *snapshotCache) Add(snapshot Snapshot) {
snap.snapshotsRWL.Lock()
defer snap.snapshotsRWL.Unlock()
snap.snapshots = append(snap.snapshots, snapshot)
}
func (snap *snapshotCache) Delete(i int) {
snap.snapshotsRWL.Lock()
defer snap.snapshotsRWL.Unlock()
copy(snap.snapshots[i:], snap.snapshots[i+1:])
snap.snapshots = snap.snapshots[:len(snap.snapshots)-1]
}
func (snap *snapshotCache) List(ready bool) []csi.Snapshot {
snap.snapshotsRWL.RLock()
defer snap.snapshotsRWL.RUnlock()
snapshots := make([]csi.Snapshot, 0)
for _, v := range snap.snapshots {
if v.SnapshotCSI.GetReadyToUse() {
snapshots = append(snapshots, v.SnapshotCSI)
}
}
return snapshots
}
func (snap *snapshotCache) FindSnapshot(k, v string) (int, Snapshot) {
snap.snapshotsRWL.RLock()
defer snap.snapshotsRWL.RUnlock()
snapshotIdx := -1
for i, vi := range snap.snapshots {
switch k {
case "id":
if strings.EqualFold(v, vi.SnapshotCSI.GetSnapshotId()) {
return i, vi
}
case "sourceVolumeId":
if strings.EqualFold(v, vi.SnapshotCSI.SourceVolumeId) {
return i, vi
}
case "name":
if vi.Name == v {
return i, vi
}
}
}
return snapshotIdx, Snapshot{}
}
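// Example usage (sketch):
//
//	cache := NewSnapshotCache()
//	cache.Add(Snapshot{
//		Name:        "snap-1",
//		SnapshotCSI: csi.Snapshot{SnapshotId: "id-1", SourceVolumeId: "vol-1", ReadyToUse: true},
//	})
//	i, snap := cache.FindSnapshot("name", "snap-1") // i >= 0 on a match, -1 otherwise
//	ready := cache.List(true)                       // note: only ReadyToUse snapshots are
//	                                                // returned; the parameter is ignored here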


@ -0,0 +1,844 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package service
import (
"fmt"
"path"
"reflect"
"strconv"
"github.com/container-storage-interface/spec/lib/go/csi"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/klog/v2"
)
const (
MaxStorageCapacity = tib
ReadOnlyKey = "readonly"
)
func (s *service) CreateVolume(
ctx context.Context,
req *csi.CreateVolumeRequest) (
*csi.CreateVolumeResponse, error) {
if len(req.Name) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Name cannot be empty")
}
if req.VolumeCapabilities == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
}
// Check to see if the volume already exists.
if i, v := s.findVolByName(ctx, req.Name); i >= 0 {
// The requested volume name already exists; check whether the existing
// volume's capacity is at least the capacity of the new request.
if v.GetCapacityBytes() < req.GetCapacityRange().GetRequiredBytes() {
return nil, status.Error(codes.AlreadyExists,
fmt.Sprintf("Volume with name %s already exists", req.GetName()))
}
return &csi.CreateVolumeResponse{Volume: &v}, nil
}
// If no capacity is specified then use 100GiB
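// (if both RequiredBytes and LimitBytes are set, LimitBytes wins because it is applied last)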
capacity := gib100
if cr := req.CapacityRange; cr != nil {
if rb := cr.RequiredBytes; rb > 0 {
capacity = rb
}
if lb := cr.LimitBytes; lb > 0 {
capacity = lb
}
}
// Check for maximum available capacity
if capacity >= MaxStorageCapacity {
return nil, status.Errorf(codes.OutOfRange, "Requested capacity %d exceeds maximum allowed %d", capacity, MaxStorageCapacity)
}
var v csi.Volume
// Create volume from content source if provided.
if req.GetVolumeContentSource() != nil {
switch req.GetVolumeContentSource().GetType().(type) {
case *csi.VolumeContentSource_Snapshot:
sid := req.GetVolumeContentSource().GetSnapshot().GetSnapshotId()
// Check if the source snapshot exists.
if snapID, _ := s.snapshots.FindSnapshot("id", sid); snapID >= 0 {
v = s.newVolumeFromSnapshot(req.Name, capacity, snapID)
} else {
return nil, status.Errorf(codes.NotFound, "Requested source snapshot %s not found", sid)
}
case *csi.VolumeContentSource_Volume:
vid := req.GetVolumeContentSource().GetVolume().GetVolumeId()
// Check if the source volume exists.
if volID, _ := s.findVolNoLock("id", vid); volID >= 0 {
v = s.newVolumeFromVolume(req.Name, capacity, volID)
} else {
return nil, status.Errorf(codes.NotFound, "Requested source volume %s not found", vid)
}
}
} else {
v = s.newVolume(req.Name, capacity)
}
// Add the created volume to the service's in-mem volume slice.
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
s.vols = append(s.vols, v)
MockVolumes[v.GetVolumeId()] = Volume{
VolumeCSI: v,
NodeID: "",
ISStaged: false,
ISPublished: false,
StageTargetPath: "",
TargetPath: "",
}
if hookVal, hookMsg := s.execHook("CreateVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.CreateVolumeResponse{Volume: &v}, nil
}
func (s *service) DeleteVolume(
ctx context.Context,
req *csi.DeleteVolumeRequest) (
*csi.DeleteVolumeResponse, error) {
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
// If the volume is not specified, return error
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if hookVal, hookMsg := s.execHook("DeleteVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// If the volume does not exist then return an idempotent response.
i, _ := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return &csi.DeleteVolumeResponse{}, nil
}
// This delete logic preserves order and prevents potential memory
// leaks. The slice's elements may not be pointers, but the structs
// themselves have fields that are.
copy(s.vols[i:], s.vols[i+1:])
s.vols[len(s.vols)-1] = csi.Volume{}
s.vols = s.vols[:len(s.vols)-1]
klog.V(5).InfoS("mock delete volume", "volumeID", req.VolumeId)
if hookVal, hookMsg := s.execHook("DeleteVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.DeleteVolumeResponse{}, nil
}
func (s *service) ControllerPublishVolume(
ctx context.Context,
req *csi.ControllerPublishVolumeRequest) (
*csi.ControllerPublishVolumeResponse, error) {
if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.NodeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Node ID cannot be empty")
}
if req.VolumeCapability == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
}
if req.NodeId != s.nodeID {
return nil, status.Errorf(codes.NotFound, "Not matching Node ID %s to Mock Node ID %s", req.NodeId, s.nodeID)
}
if hookVal, hookMsg := s.execHook("ControllerPublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// devPathKey is the key in the volume's attributes that is set to a
// mock device path if the volume has been published by the controller
// to the specified node.
devPathKey := path.Join(req.NodeId, "dev")
// Check to see if the volume is already published.
if device := v.VolumeContext[devPathKey]; device != "" {
var volRo bool
var roVal string
if ro, ok := v.VolumeContext[ReadOnlyKey]; ok {
roVal = ro
}
if roVal == "true" {
volRo = true
} else {
volRo = false
}
// Check if readonly flag is compatible with the publish request.
if req.GetReadonly() != volRo {
return nil, status.Error(codes.AlreadyExists, "Volume published but has incompatible readonly flag")
}
return &csi.ControllerPublishVolumeResponse{
PublishContext: map[string]string{
"device": device,
"readonly": roVal,
},
}, nil
}
// Check attach limit before publishing only if attach limit is set.
if s.config.AttachLimit > 0 && s.getAttachCount(devPathKey) >= s.config.AttachLimit {
return nil, status.Error(codes.ResourceExhausted, "Cannot attach any more volumes to this node")
}
var roVal string
if req.GetReadonly() {
roVal = "true"
} else {
roVal = "false"
}
// Publish the volume.
device := "/dev/mock"
v.VolumeContext[devPathKey] = device
v.VolumeContext[ReadOnlyKey] = roVal
s.vols[i] = v
if volInfo, ok := MockVolumes[req.VolumeId]; ok {
volInfo.ISControllerPublished = true
MockVolumes[req.VolumeId] = volInfo
}
if hookVal, hookMsg := s.execHook("ControllerPublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerPublishVolumeResponse{
PublishContext: map[string]string{
"device": device,
"readonly": roVal,
},
}, nil
}
func (s *service) ControllerUnpublishVolume(
ctx context.Context,
req *csi.ControllerUnpublishVolumeRequest) (
*csi.ControllerUnpublishVolumeResponse, error) {
if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
nodeID := req.NodeId
if len(nodeID) == 0 {
// If node id is empty, no failure as per Spec
nodeID = s.nodeID
}
if nodeID != s.nodeID {
return nil, status.Errorf(codes.NotFound, "Node ID %s does not match expected Node ID %s", nodeID, s.nodeID)
}
if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
// Not an error: a non-existent volume is not published.
// See also https://github.com/kubernetes-csi/external-attacher/pull/165
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
// devPathKey is the key in the volume's attributes that is set to a
// mock device path if the volume has been published by the controller
// to the specified node.
devPathKey := path.Join(nodeID, "dev")
// Check to see if the volume is already unpublished.
if v.VolumeContext[devPathKey] == "" {
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
// Unpublish the volume.
delete(v.VolumeContext, devPathKey)
delete(v.VolumeContext, ReadOnlyKey)
s.vols[i] = v
if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
func (s *service) ValidateVolumeCapabilities(
ctx context.Context,
req *csi.ValidateVolumeCapabilitiesRequest) (
*csi.ValidateVolumeCapabilitiesResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.VolumeCapabilities) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
}
i, _ := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if hookVal, hookMsg := s.execHook("ValidateVolumeCapabilities"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeContext: req.GetVolumeContext(),
VolumeCapabilities: req.GetVolumeCapabilities(),
Parameters: req.GetParameters(),
},
}, nil
}
func (s *service) ControllerGetVolume(
ctx context.Context,
req *csi.ControllerGetVolumeRequest) (
*csi.ControllerGetVolumeResponse, error) {
if hookVal, hookMsg := s.execHook("GetVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
resp := &csi.ControllerGetVolumeResponse{
Status: &csi.ControllerGetVolumeResponse_VolumeStatus{
VolumeCondition: &csi.VolumeCondition{},
},
}
i, v := s.findVolByID(ctx, req.VolumeId)
if i < 0 {
resp.Status.VolumeCondition.Abnormal = true
resp.Status.VolumeCondition.Message = "volume not found"
return resp, status.Error(codes.NotFound, req.VolumeId)
}
resp.Volume = &v
if !s.config.DisableAttach {
resp.Status.PublishedNodeIds = []string{
s.nodeID,
}
}
if hookVal, hookMsg := s.execHook("GetVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func (s *service) ListVolumes(
ctx context.Context,
req *csi.ListVolumesRequest) (
*csi.ListVolumesResponse, error) {
if hookVal, hookMsg := s.execHook("ListVolumesStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// Copy the mock volumes into a new slice in order to avoid
// locking the service's volume slice for the duration of the
// ListVolumes RPC.
var vols []csi.Volume
func() {
s.volsRWL.RLock()
defer s.volsRWL.RUnlock()
vols = make([]csi.Volume, len(s.vols))
copy(vols, s.vols)
}()
var (
ulenVols = int32(len(vols))
maxEntries = req.MaxEntries
startingToken int32
)
if v := req.StartingToken; v != "" {
i, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%s: %v",
v, err)
}
startingToken = int32(i)
}
if startingToken > ulenVols {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d > len(vols)=%d",
startingToken, ulenVols)
}
// Discern the number of remaining entries.
rem := ulenVols - startingToken
// If maxEntries is 0 or greater than the number of remaining entries then
// set maxEntries to the number of remaining entries.
if maxEntries == 0 || maxEntries > rem {
maxEntries = rem
}
var (
i int
j = startingToken
entries = make(
[]*csi.ListVolumesResponse_Entry,
maxEntries)
)
for i = 0; i < len(entries); i++ {
volumeStatus := &csi.ListVolumesResponse_VolumeStatus{
VolumeCondition: &csi.VolumeCondition{},
}
if !s.config.DisableAttach {
volumeStatus.PublishedNodeIds = []string{
s.nodeID,
}
}
entries[i] = &csi.ListVolumesResponse_Entry{
Volume: &vols[j],
Status: volumeStatus,
}
j++
}
var nextToken string
if n := startingToken + int32(i); n < ulenVols {
nextToken = fmt.Sprintf("%d", n)
}
if hookVal, hookMsg := s.execHook("ListVolumesEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ListVolumesResponse{
Entries: entries,
NextToken: nextToken,
}, nil
}
func (s *service) GetCapacity(
ctx context.Context,
req *csi.GetCapacityRequest) (
*csi.GetCapacityResponse, error) {
if hookVal, hookMsg := s.execHook("GetCapacity"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.GetCapacityResponse{
AvailableCapacity: MaxStorageCapacity,
}, nil
}
func (s *service) ControllerGetCapabilities(
ctx context.Context,
req *csi.ControllerGetCapabilitiesRequest) (
*csi.ControllerGetCapabilitiesResponse, error) {
if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
caps := []*csi.ControllerServiceCapability{
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES_PUBLISHED_NODES,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_CAPACITY,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CLONE_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_VOLUME_CONDITION,
},
},
},
}
if !s.config.DisableAttach {
caps = append(caps, &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
},
})
}
if !s.config.DisableControllerExpansion {
caps = append(caps, &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_EXPAND_VOLUME,
},
},
})
}
if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerGetCapabilitiesResponse{
Capabilities: caps,
}, nil
}
func (s *service) CreateSnapshot(ctx context.Context,
req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
// Check arguments
if len(req.GetName()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot Name cannot be empty")
}
if len(req.GetSourceVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot SourceVolumeId cannot be empty")
}
// Check to see if the snapshot already exists.
if i, v := s.snapshots.FindSnapshot("name", req.GetName()); i >= 0 {
// Requested snapshot name already exists
if v.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() || !reflect.DeepEqual(v.Parameters, req.GetParameters()) {
return nil, status.Error(codes.AlreadyExists,
fmt.Sprintf("Snapshot with name %s already exists", req.GetName()))
}
return &csi.CreateSnapshotResponse{Snapshot: &v.SnapshotCSI}, nil
}
// Create the snapshot and add it to the service's in-mem snapshot slice.
snapshot := s.newSnapshot(req.GetName(), req.GetSourceVolumeId(), req.GetParameters())
s.snapshots.Add(snapshot)
if hookVal, hookMsg := s.execHook("CreateSnapshotEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.CreateSnapshotResponse{Snapshot: &snapshot.SnapshotCSI}, nil
}
func (s *service) DeleteSnapshot(ctx context.Context,
req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
// If the snapshot is not specified, return error
if len(req.SnapshotId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot ID cannot be empty")
}
if hookVal, hookMsg := s.execHook("DeleteSnapshotStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// If the snapshot does not exist then return an idempotent response.
i, _ := s.snapshots.FindSnapshot("id", req.SnapshotId)
if i < 0 {
return &csi.DeleteSnapshotResponse{}, nil
}
// This delete logic preserves order and prevents potential memory
// leaks. The slice's elements may not be pointers, but the structs
// themselves have fields that are.
s.snapshots.Delete(i)
klog.V(5).InfoS("mock delete snapshot", "SnapshotId", req.SnapshotId)
if hookVal, hookMsg := s.execHook("DeleteSnapshotEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.DeleteSnapshotResponse{}, nil
}
func (s *service) ListSnapshots(ctx context.Context,
req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if hookVal, hookMsg := s.execHook("ListSnapshots"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// case 1: SnapshotId is not empty, return snapshots that match the snapshot id.
if len(req.GetSnapshotId()) != 0 {
return getSnapshotById(s, req)
}
// case 2: SourceVolumeId is not empty, return snapshots that match the source volume id.
if len(req.GetSourceVolumeId()) != 0 {
return getSnapshotByVolumeId(s, req)
}
// case 3: no parameter is set, so we return all the snapshots.
return getAllSnapshots(s, req)
}
func (s *service) ControllerExpandVolume(
ctx context.Context,
req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if req.CapacityRange == nil {
return nil, status.Error(codes.InvalidArgument, "Request capacity cannot be empty")
}
if hookVal, hookMsg := s.execHook("ControllerExpandVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if s.config.DisableOnlineExpansion && MockVolumes[v.GetVolumeId()].ISControllerPublished {
return nil, status.Error(codes.FailedPrecondition, "volume is published and online volume expansion is not supported")
}
requestBytes := req.CapacityRange.RequiredBytes
if v.CapacityBytes > requestBytes {
return nil, status.Error(codes.InvalidArgument, "cannot change volume capacity to a smaller size")
}
resp := &csi.ControllerExpandVolumeResponse{
CapacityBytes: requestBytes,
NodeExpansionRequired: s.config.NodeExpansionRequired,
}
// Check to see if the volume already satisfied request size.
if v.CapacityBytes == requestBytes {
klog.V(5).InfoS("volume capacity sufficient, no need to expand", "requested", requestBytes, "current", v.CapacityBytes, "volumeID", v.VolumeId)
return resp, nil
}
// Update volume's capacity to the requested size.
v.CapacityBytes = requestBytes
s.vols[i] = v
if hookVal, hookMsg := s.execHook("ControllerExpandVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func getSnapshotById(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if len(req.GetSnapshotId()) != 0 {
i, snapshot := s.snapshots.FindSnapshot("id", req.GetSnapshotId())
if i < 0 {
return &csi.ListSnapshotsResponse{}, nil
}
if len(req.GetSourceVolumeId()) != 0 {
if snapshot.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() {
return &csi.ListSnapshotsResponse{}, nil
}
}
return &csi.ListSnapshotsResponse{
Entries: []*csi.ListSnapshotsResponse_Entry{
{
Snapshot: &snapshot.SnapshotCSI,
},
},
}, nil
}
return nil, nil
}
func getSnapshotByVolumeId(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if len(req.GetSourceVolumeId()) != 0 {
i, snapshot := s.snapshots.FindSnapshot("sourceVolumeId", req.SourceVolumeId)
if i < 0 {
return &csi.ListSnapshotsResponse{}, nil
}
return &csi.ListSnapshotsResponse{
Entries: []*csi.ListSnapshotsResponse_Entry{
{
Snapshot: &snapshot.SnapshotCSI,
},
},
}, nil
}
return nil, nil
}
func getAllSnapshots(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
// Copy the mock snapshots into a new slice in order to avoid
// locking the service's snapshot slice for the duration of the
// ListSnapshots RPC.
readyToUse := true
snapshots := s.snapshots.List(readyToUse)
var (
ulenSnapshots = int32(len(snapshots))
maxEntries = req.MaxEntries
startingToken int32
)
if v := req.StartingToken; v != "" {
i, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%s: %v",
v, err)
}
startingToken = int32(i)
}
if startingToken > ulenSnapshots {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d > len(snapshots)=%d",
startingToken, ulenSnapshots)
}
// Discern the number of remaining entries.
rem := ulenSnapshots - startingToken
// If maxEntries is 0 or greater than the number of remaining entries then
// set maxEntries to the number of remaining entries.
if maxEntries == 0 || maxEntries > rem {
maxEntries = rem
}
var (
i int
j = startingToken
entries = make(
[]*csi.ListSnapshotsResponse_Entry,
maxEntries)
)
for i = 0; i < len(entries); i++ {
entries[i] = &csi.ListSnapshotsResponse_Entry{
Snapshot: &snapshots[j],
}
j++
}
var nextToken string
if n := startingToken + int32(i); n < ulenSnapshots {
nextToken = fmt.Sprintf("%d", n)
}
return &csi.ListSnapshotsResponse{
Entries: entries,
NextToken: nextToken,
}, nil
}
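The pagination scheme implemented above (a numeric StartingToken, an empty NextToken on the last page) can be exercised from a plain CSI client. A minimal sketch, assuming an already established *grpc.ClientConn to the mock driver; all names are illustrative:

package example

import (
    "context"
    "fmt"

    "github.com/container-storage-interface/spec/lib/go/csi"
    "google.golang.org/grpc"
)

// listAllVolumes pages through ListVolumes two entries at a time,
// following NextToken until the listing is exhausted.
func listAllVolumes(ctx context.Context, conn *grpc.ClientConn) error {
    client := csi.NewControllerClient(conn)
    token := ""
    for {
        resp, err := client.ListVolumes(ctx, &csi.ListVolumesRequest{
            MaxEntries:    2, // small page size to exercise the token logic
            StartingToken: token,
        })
        if err != nil {
            return err
        }
        for _, entry := range resp.Entries {
            fmt.Println(entry.Volume.VolumeId)
        }
        // An empty NextToken means the last page was returned.
        if resp.NextToken == "" {
            return nil
        }
        token = resp.NextToken
    }
}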


@ -0,0 +1,90 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package service
import (
"golang.org/x/net/context"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/protobuf/ptypes/wrappers"
)
func (s *service) GetPluginInfo(
ctx context.Context,
req *csi.GetPluginInfoRequest) (
*csi.GetPluginInfoResponse, error) {
return &csi.GetPluginInfoResponse{
Name: s.config.DriverName,
VendorVersion: VendorVersion,
Manifest: Manifest,
}, nil
}
func (s *service) Probe(
ctx context.Context,
req *csi.ProbeRequest) (
*csi.ProbeResponse, error) {
return &csi.ProbeResponse{
Ready: &wrappers.BoolValue{Value: true},
}, nil
}
func (s *service) GetPluginCapabilities(
ctx context.Context,
req *csi.GetPluginCapabilitiesRequest) (
*csi.GetPluginCapabilitiesResponse, error) {
volExpType := csi.PluginCapability_VolumeExpansion_ONLINE
if s.config.DisableOnlineExpansion {
volExpType = csi.PluginCapability_VolumeExpansion_OFFLINE
}
capabilities := []*csi.PluginCapability{
{
Type: &csi.PluginCapability_Service_{
Service: &csi.PluginCapability_Service{
Type: csi.PluginCapability_Service_CONTROLLER_SERVICE,
},
},
},
{
Type: &csi.PluginCapability_VolumeExpansion_{
VolumeExpansion: &csi.PluginCapability_VolumeExpansion{
Type: volExpType,
},
},
},
}
if s.config.EnableTopology {
capabilities = append(capabilities,
&csi.PluginCapability{
Type: &csi.PluginCapability_Service_{
Service: &csi.PluginCapability_Service{
Type: csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS,
},
},
})
}
return &csi.GetPluginCapabilitiesResponse{
Capabilities: capabilities,
}, nil
}
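Callers would typically inspect these identity responses before talking to the controller service. A hedged sketch of such a check; the gRPC connection setup is assumed and not part of this change:

package example

import (
    "context"

    "github.com/container-storage-interface/spec/lib/go/csi"
    "google.golang.org/grpc"
)

// hasControllerService reports whether the plugin advertises the
// CONTROLLER_SERVICE capability returned by GetPluginCapabilities above.
func hasControllerService(ctx context.Context, conn *grpc.ClientConn) (bool, error) {
    identity := csi.NewIdentityClient(conn)
    resp, err := identity.GetPluginCapabilities(ctx, &csi.GetPluginCapabilitiesRequest{})
    if err != nil {
        return false, err
    }
    for _, c := range resp.Capabilities {
        if svc := c.GetService(); svc != nil && svc.Type == csi.PluginCapability_Service_CONTROLLER_SERVICE {
            return true, nil
        }
    }
    return false, nil
}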


@ -0,0 +1,450 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package service
import (
"fmt"
"path"
"strconv"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"golang.org/x/net/context"
"github.com/container-storage-interface/spec/lib/go/csi"
)
func (s *service) NodeStageVolume(
ctx context.Context,
req *csi.NodeStageVolumeRequest) (
*csi.NodeStageVolumeResponse, error) {
device, ok := req.PublishContext["device"]
if !ok {
if s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetStagingTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty")
}
if req.GetVolumeCapability() == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty")
}
exists, err := s.config.IO.DirExists(req.StagingTargetPath)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
return nil, status.Errorf(codes.Internal, "staging target path %s does not exist", req.StagingTargetPath)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// nodeStgPathKey is the key in the volume's attributes that is set to a
// mock stage path if the volume has been published by the node
nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath)
// Check to see if the volume has already been staged.
if v.VolumeContext[nodeStgPathKey] != "" {
// TODO: Check for the capabilities to be equal. Return "ALREADY_EXISTS"
// if the capabilities don't match.
return &csi.NodeStageVolumeResponse{}, nil
}
// Stage the volume.
v.VolumeContext[nodeStgPathKey] = device
s.vols[i] = v
if hookVal, hookMsg := s.execHook("NodeStageVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeStageVolumeResponse{}, nil
}
func (s *service) NodeUnstageVolume(
ctx context.Context,
req *csi.NodeUnstageVolumeRequest) (
*csi.NodeUnstageVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetStagingTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty")
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// nodeStgPathKey is the key in the volume's attributes that is set to a
// mock stage path if the volume has been published by the node
nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath)
// Check to see if the volume has already been unstaged.
if v.VolumeContext[nodeStgPathKey] == "" {
return &csi.NodeUnstageVolumeResponse{}, nil
}
// Unpublish the volume.
delete(v.VolumeContext, nodeStgPathKey)
s.vols[i] = v
if hookVal, hookMsg := s.execHook("NodeUnstageVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeUnstageVolumeResponse{}, nil
}
func (s *service) NodePublishVolume(
ctx context.Context,
req *csi.NodePublishVolumeRequest) (
*csi.NodePublishVolumeResponse, error) {
if hookVal, hookMsg := s.execHook("NodePublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
ephemeralVolume := req.GetVolumeContext()["csi.storage.k8s.io/ephemeral"] == "true"
device, ok := req.PublishContext["device"]
if !ok {
if ephemeralVolume || s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty")
}
if req.GetVolumeCapability() == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty")
}
// May happen with old (or, at this time, even the current) Kubernetes
// although it shouldn't (https://github.com/kubernetes/kubernetes/issues/75535).
exists, err := s.config.IO.DirExists(req.TargetPath)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !s.config.PermissiveTargetPath && exists {
return nil, status.Errorf(codes.Internal, "target path %s already exists", req.TargetPath)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 && !ephemeralVolume {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if i >= 0 && ephemeralVolume {
return nil, status.Error(codes.AlreadyExists, req.VolumeId)
}
// nodeMntPathKey is the key in the volume's attributes that is set to a
// mock mount path if the volume has been published by the node
nodeMntPathKey := path.Join(s.nodeID, req.TargetPath)
// Check to see if the volume has already been published.
if v.VolumeContext[nodeMntPathKey] != "" {
// Requests marked read-only fail because the mock driver
// only supports read-write volumes.
if req.Readonly {
return nil, status.Error(codes.AlreadyExists, req.VolumeId)
}
return &csi.NodePublishVolumeResponse{}, nil
}
// Publish the volume.
if ephemeralVolume {
MockVolumes[req.VolumeId] = Volume{
ISEphemeral: true,
}
} else {
if req.GetTargetPath() != "" {
exists, err := s.config.IO.DirExists(req.GetTargetPath())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
// If the target path does not exist, create the directory in which the volume will be published.
if err = s.config.IO.Mkdir(req.TargetPath); err != nil {
msg := fmt.Sprintf("NodePublishVolume: could not create target dir %q: %v", req.TargetPath, err)
return nil, status.Error(codes.Internal, msg)
}
}
v.VolumeContext[nodeMntPathKey] = req.GetTargetPath()
} else {
v.VolumeContext[nodeMntPathKey] = device
}
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodePublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodePublishVolumeResponse{}, nil
}
func (s *service) NodeUnpublishVolume(
ctx context.Context,
req *csi.NodeUnpublishVolumeRequest) (
*csi.NodeUnpublishVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty")
}
if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
ephemeralVolume := MockVolumes[req.VolumeId].ISEphemeral
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 && !ephemeralVolume {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if ephemeralVolume {
delete(MockVolumes, req.VolumeId)
} else {
// nodeMntPathKey is the key in the volume's attributes that is set to a
// mock mount path if the volume has been published by the node
nodeMntPathKey := path.Join(s.nodeID, req.TargetPath)
// Check to see if the volume has already been unpublished.
if v.VolumeContext[nodeMntPathKey] == "" {
return &csi.NodeUnpublishVolumeResponse{}, nil
}
// Delete any created paths
err := s.config.IO.RemoveAll(v.VolumeContext[nodeMntPathKey])
if err != nil {
return nil, status.Errorf(codes.Internal, "Unable to delete previously created target directory")
}
// Unpublish the volume.
delete(v.VolumeContext, nodeMntPathKey)
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeUnpublishVolumeResponse{}, nil
}
func (s *service) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetVolumePath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty")
}
if hookVal, hookMsg := s.execHook("NodeExpandVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// TODO: NodeExpandVolume MUST be called after a successful NodeStageVolume, as we have the STAGE_UNSTAGE_VOLUME node capability.
resp := &csi.NodeExpandVolumeResponse{}
var requestCapacity int64 = 0
if req.GetCapacityRange() != nil {
requestCapacity = req.CapacityRange.GetRequiredBytes()
resp.CapacityBytes = requestCapacity
}
// fsCapacityKey is the key in the volume's attributes that is set to the file system's size.
fsCapacityKey := path.Join(s.nodeID, req.GetVolumePath(), "size")
// Update volume's fs capacity to requested size.
if requestCapacity > 0 {
v.VolumeContext[fsCapacityKey] = strconv.FormatInt(requestCapacity, 10)
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodeExpandVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func (s *service) NodeGetCapabilities(
ctx context.Context,
req *csi.NodeGetCapabilitiesRequest) (
*csi.NodeGetCapabilitiesResponse, error) {
if hookVal, hookMsg := s.execHook("NodeGetCapabilities"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
capabilities := []*csi.NodeServiceCapability{
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_UNKNOWN,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_VOLUME_CONDITION,
},
},
},
}
if s.config.NodeExpansionRequired {
capabilities = append(capabilities, &csi.NodeServiceCapability{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_EXPAND_VOLUME,
},
},
})
}
return &csi.NodeGetCapabilitiesResponse{
Capabilities: capabilities,
}, nil
}
func (s *service) NodeGetInfo(ctx context.Context,
req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) {
if hookVal, hookMsg := s.execHook("NodeGetInfo"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
csiNodeResponse := &csi.NodeGetInfoResponse{
NodeId: s.nodeID,
}
if s.config.AttachLimit > 0 {
csiNodeResponse.MaxVolumesPerNode = s.config.AttachLimit
}
if s.config.EnableTopology {
csiNodeResponse.AccessibleTopology = &csi.Topology{
Segments: map[string]string{
TopologyKey: TopologyValue,
},
}
}
return csiNodeResponse, nil
}
func (s *service) NodeGetVolumeStats(ctx context.Context,
req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) {
resp := &csi.NodeGetVolumeStatsResponse{
VolumeCondition: &csi.VolumeCondition{},
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetVolumePath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty")
}
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
resp.VolumeCondition.Abnormal = true
resp.VolumeCondition.Message = "Volume not found"
return resp, status.Error(codes.NotFound, req.VolumeId)
}
nodeMntPathKey := path.Join(s.nodeID, req.VolumePath)
_, exists := v.VolumeContext[nodeMntPathKey]
if !exists {
msg := fmt.Sprintf("volume %q doest not exist on the specified path %q", req.VolumeId, req.VolumePath)
resp.VolumeCondition.Abnormal = true
resp.VolumeCondition.Message = msg
return resp, status.Errorf(codes.NotFound, msg)
}
if hookVal, hookMsg := s.execHook("NodeGetVolumeStatsEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
resp.Usage = []*csi.VolumeUsage{
{
Total: v.GetCapacityBytes(),
Unit: csi.VolumeUsage_BYTES,
},
}
return resp, nil
}
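The node service above expects the kubelet's usual call order: NodeStageVolume with the controller's published "device", then NodePublishVolume against a target path. A rough sketch of that sequence driven by a bare CSI client; the paths and capability values are assumptions for illustration:

package example

import (
    "context"

    "github.com/container-storage-interface/spec/lib/go/csi"
    "google.golang.org/grpc"
)

// stageAndPublish walks a volume through NodeStageVolume and
// NodePublishVolume, mirroring what the kubelet does.
func stageAndPublish(ctx context.Context, conn *grpc.ClientConn, volumeID string) error {
    node := csi.NewNodeClient(conn)
    volCap := &csi.VolumeCapability{
        AccessType: &csi.VolumeCapability_Mount{
            Mount: &csi.VolumeCapability_MountVolume{},
        },
        AccessMode: &csi.VolumeCapability_AccessMode{
            Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
        },
    }
    if _, err := node.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
        VolumeId:          volumeID,
        StagingTargetPath: "/tmp/staging", // must already exist, see the DirExists check above
        PublishContext:    map[string]string{"device": "/dev/mock"},
        VolumeCapability:  volCap,
    }); err != nil {
        return err
    }
    _, err := node.NodePublishVolume(ctx, &csi.NodePublishVolumeRequest{
        VolumeId:         volumeID,
        TargetPath:       "/tmp/target", // created on demand by NodePublishVolume
        PublishContext:   map[string]string{"device": "/dev/mock"},
        VolumeCapability: volCap,
    })
    return err
}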


@ -0,0 +1,276 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package service
import (
"fmt"
"os"
"strings"
"sync"
"sync/atomic"
"github.com/container-storage-interface/spec/lib/go/csi"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"k8s.io/kubernetes/test/e2e/storage/drivers/csi-test/mock/cache"
"github.com/golang/protobuf/ptypes"
)
const (
// Name is the name of the CSI plug-in.
Name = "io.kubernetes.storage.mock"
// VendorVersion is the version returned by GetPluginInfo.
VendorVersion = "0.3.0"
// TopologyKey simulates a per-node topology.
TopologyKey = Name + "/node"
// TopologyValue is the one, fixed node on which the driver runs.
TopologyValue = "some-mock-node"
)
// Manifest is the SP's manifest.
var Manifest = map[string]string{
"url": "https://k8s.io/kubernetes/test/e2e/storage/drivers/csi-test/mock",
}
type Config struct {
DisableAttach bool
DriverName string
AttachLimit int64
NodeExpansionRequired bool
DisableControllerExpansion bool
DisableOnlineExpansion bool
PermissiveTargetPath bool
EnableTopology bool
IO DirIO
}
// DirIO is an abstraction over direct os calls.
type DirIO interface {
// DirExists returns false if the path doesn't exist, true if it exists and is a directory, an error otherwise.
DirExists(path string) (bool, error)
// Mkdir creates the directory, but not its parents, with 0755 permissions.
Mkdir(path string) error
// RemoveAll removes the path and everything contained inside it. It's not an error if the path does not exist.
RemoveAll(path string) error
}
type OSDirIO struct{}
func (o OSDirIO) DirExists(path string) (bool, error) {
info, err := os.Stat(path)
switch {
case err == nil && !info.IsDir():
return false, fmt.Errorf("%s: not a directory", path)
case err == nil:
return true, nil
case os.IsNotExist(err):
return false, nil
default:
return false, err
}
}
func (o OSDirIO) Mkdir(path string) error {
return os.Mkdir(path, os.FileMode(0755))
}
func (o OSDirIO) RemoveAll(path string) error {
return os.RemoveAll(path)
}
// Service is the CSI Mock service provider.
type Service interface {
csi.ControllerServer
csi.IdentityServer
csi.NodeServer
}
type service struct {
sync.Mutex
nodeID string
vols []csi.Volume
volsRWL sync.RWMutex
volsNID uint64
snapshots cache.SnapshotCache
snapshotsNID uint64
config Config
}
type Volume struct {
VolumeCSI csi.Volume
NodeID string
ISStaged bool
ISPublished bool
ISEphemeral bool
ISControllerPublished bool
StageTargetPath string
TargetPath string
}
var MockVolumes map[string]Volume
// New returns a new Service.
func New(config Config) Service {
s := &service{
nodeID: config.DriverName,
config: config,
}
if s.config.IO == nil {
s.config.IO = OSDirIO{}
}
s.snapshots = cache.NewSnapshotCache()
s.vols = []csi.Volume{
s.newVolume("Mock Volume 1", gib100),
s.newVolume("Mock Volume 2", gib100),
s.newVolume("Mock Volume 3", gib100),
}
MockVolumes = map[string]Volume{}
s.snapshots.Add(s.newSnapshot("Mock Snapshot 1", "1", map[string]string{"Description": "snapshot 1"}))
s.snapshots.Add(s.newSnapshot("Mock Snapshot 2", "2", map[string]string{"Description": "snapshot 2"}))
s.snapshots.Add(s.newSnapshot("Mock Snapshot 3", "3", map[string]string{"Description": "snapshot 3"}))
return s
}
const (
kib int64 = 1024
mib int64 = kib * 1024
gib int64 = mib * 1024
gib100 int64 = gib * 100
tib int64 = gib * 1024
tib100 int64 = tib * 100
)
func (s *service) newVolume(name string, capacity int64) csi.Volume {
vol := csi.Volume{
VolumeId: fmt.Sprintf("%d", atomic.AddUint64(&s.volsNID, 1)),
VolumeContext: map[string]string{"name": name},
CapacityBytes: capacity,
}
s.setTopology(&vol)
return vol
}
func (s *service) newVolumeFromSnapshot(name string, capacity int64, snapshotID int) csi.Volume {
vol := s.newVolume(name, capacity)
vol.ContentSource = &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: fmt.Sprintf("%d", snapshotID),
},
},
}
s.setTopology(&vol)
return vol
}
func (s *service) newVolumeFromVolume(name string, capacity int64, volumeID int) csi.Volume {
vol := s.newVolume(name, capacity)
vol.ContentSource = &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Volume{
Volume: &csi.VolumeContentSource_VolumeSource{
VolumeId: fmt.Sprintf("%d", volumeID),
},
},
}
s.setTopology(&vol)
return vol
}
func (s *service) setTopology(vol *csi.Volume) {
if s.config.EnableTopology {
vol.AccessibleTopology = []*csi.Topology{
{
Segments: map[string]string{
TopologyKey: TopologyValue,
},
},
}
}
}
func (s *service) findVol(k, v string) (volIdx int, volInfo csi.Volume) {
s.volsRWL.RLock()
defer s.volsRWL.RUnlock()
return s.findVolNoLock(k, v)
}
func (s *service) findVolNoLock(k, v string) (volIdx int, volInfo csi.Volume) {
volIdx = -1
for i, vi := range s.vols {
switch k {
case "id":
if strings.EqualFold(v, vi.GetVolumeId()) {
return i, vi
}
case "name":
if n, ok := vi.VolumeContext["name"]; ok && strings.EqualFold(v, n) {
return i, vi
}
}
}
return
}
func (s *service) findVolByName(
ctx context.Context, name string) (int, csi.Volume) {
return s.findVol("name", name)
}
func (s *service) findVolByID(
ctx context.Context, id string) (int, csi.Volume) {
return s.findVol("id", id)
}
func (s *service) newSnapshot(name, sourceVolumeId string, parameters map[string]string) cache.Snapshot {
ptime := ptypes.TimestampNow()
return cache.Snapshot{
Name: name,
Parameters: parameters,
SnapshotCSI: csi.Snapshot{
SnapshotId: fmt.Sprintf("%d", atomic.AddUint64(&s.snapshotsNID, 1)),
CreationTime: ptime,
SourceVolumeId: sourceVolumeId,
ReadyToUse: true,
},
}
}
// getAttachCount returns the number of attached volumes on the node.
func (s *service) getAttachCount(devPathKey string) int64 {
var count int64
for _, v := range s.vols {
if device := v.VolumeContext[devPathKey]; device != "" {
count++
}
}
return count
}
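// execHook is deliberately a no-op here: when the driver is embedded in
// the e2e.test binary, hooks run in the test binary's gRPC interceptor
// instead (see interceptGRPC in the mock driver), so this stub always
// reports OK.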
func (s *service) execHook(hookName string) (codes.Code, string) {
return codes.OK, ""
}


@ -37,13 +37,16 @@ package drivers
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
"github.com/onsi/ginkgo"
"google.golang.org/grpc/codes"
v1 "k8s.io/api/core/v1"
storagev1 "k8s.io/api/storage/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
@ -51,14 +54,21 @@ import (
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/klog/v2"
"k8s.io/kubernetes/test/e2e/framework"
e2enode "k8s.io/kubernetes/test/e2e/framework/node"
e2epod "k8s.io/kubernetes/test/e2e/framework/pod"
e2eskipper "k8s.io/kubernetes/test/e2e/framework/skipper"
e2evolume "k8s.io/kubernetes/test/e2e/framework/volume"
mockdriver "k8s.io/kubernetes/test/e2e/storage/drivers/csi-test/driver"
mockservice "k8s.io/kubernetes/test/e2e/storage/drivers/csi-test/mock/service"
"k8s.io/kubernetes/test/e2e/storage/drivers/proxy"
storageframework "k8s.io/kubernetes/test/e2e/storage/framework"
"k8s.io/kubernetes/test/e2e/storage/utils"
"google.golang.org/grpc"
)
const (
@ -66,6 +76,9 @@ const (
GCEPDCSIDriverName = "pd.csi.storage.gke.io"
// GCEPDCSIZoneTopologyKey is the key of GCE Persistent Disk CSI zone topology
GCEPDCSIZoneTopologyKey = "topology.gke.io/zone"
// Prefix of the mock driver grpc log
grpcCallPrefix = "gRPCCall:"
)
// hostpathCSI
@ -232,10 +245,42 @@ type mockCSIDriver struct {
attachLimit int
enableTopology bool
enableNodeExpansion bool
javascriptHooks map[string]string
hooks Hooks
tokenRequests []storagev1.TokenRequest
requiresRepublish *bool
fsGroupPolicy *storagev1.FSGroupPolicy
embedded bool
calls MockCSICalls
embeddedCSIDriver *mockdriver.CSIDriver
// Additional values set during PrepareTest
clientSet kubernetes.Interface
driverNamespace *v1.Namespace
}
// Hooks to be executed while handling gRPC calls.
//
// At the moment, only generic pre- and post-function call
// hooks are implemented. Those hooks can cast the request and
// response values if needed. More hooks inside specific
// functions could be added if needed.
type Hooks struct {
// Pre is called before invoking the mock driver's implementation of a method.
// If either a non-nil reply or error are returned, then those are returned to the caller.
Pre func(ctx context.Context, method string, request interface{}) (reply interface{}, err error)
// Post is called after invoking the mock driver's implementation of a method.
// What it returns is used as actual result.
Post func(ctx context.Context, method string, request, reply interface{}, err error) (finalReply interface{}, finalErr error)
}
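For example, a test can use a Pre hook to inject a transient error and observe how Kubernetes retries. A minimal sketch, assuming grpc's status package and sync/atomic are imported alongside the imports above; the counter and the method matching are illustrative, not part of this change:

// failFirstCreateVolume returns Hooks that reject the first
// CreateVolume call with ResourceExhausted and let all other calls
// fall through to the mock implementation (nil, nil means "continue").
func failFirstCreateVolume() *Hooks {
    var calls int32
    return &Hooks{
        Pre: func(ctx context.Context, method string, request interface{}) (interface{}, error) {
            if strings.HasSuffix(method, "/CreateVolume") && atomic.AddInt32(&calls, 1) == 1 {
                return nil, status.Error(codes.ResourceExhausted, "injected error")
            }
            return nil, nil
        },
    }
}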
// MockCSITestDriver provides additional functions specific to the CSI mock driver.
type MockCSITestDriver interface {
storageframework.DynamicPVTestDriver
// GetCalls returns all currently observed gRPC calls. Only valid
// after PrepareTest.
GetCalls() ([]MockCSICall, error)
}
// CSIMockDriverOpts defines options used for csi driver
@ -249,10 +294,96 @@ type CSIMockDriverOpts struct {
EnableResizing bool
EnableNodeExpansion bool
EnableSnapshot bool
JavascriptHooks map[string]string
TokenRequests []storagev1.TokenRequest
RequiresRepublish *bool
FSGroupPolicy *storagev1.FSGroupPolicy
// Embedded defines whether the CSI mock driver runs
// inside the cluster (false, the default) or just a proxy
// runs inside the cluster and all gRPC calls are handled
// inside the e2e.test binary.
Embedded bool
// Hooks that will be called if (and only if!) the embedded
// mock driver is used. Beware that hooks are invoked
// asynchronously in different goroutines.
Hooks Hooks
}
// MockCSICall is a dummy structure that parses just the volume context and error code out of a logged CSI call.
type MockCSICall struct {
json string // full log entry
Method string
Request struct {
VolumeContext map[string]string `json:"volume_context"`
}
FullError struct {
Code codes.Code `json:"code"`
Message string `json:"message"`
}
Error string
}
// MockCSICalls is thread-safe storage for MockCSICall instances.
type MockCSICalls struct {
calls []MockCSICall
mutex sync.Mutex
}
// Get returns all currently recorded calls.
func (c *MockCSICalls) Get() []MockCSICall {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.calls[:]
}
// Add appends one new call at the end.
func (c *MockCSICalls) Add(call MockCSICall) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.calls = append(c.calls, call)
}
// LogGRPC takes the individual parameters of a gRPC call from the mock CSI driver and records them as a new MockCSICall.
func (c *MockCSICalls) LogGRPC(method string, request, reply interface{}, err error) {
// Encoding to JSON and decoding mirrors the traditional way of capturing calls.
// Probably could be simplified now...
logMessage := struct {
Method string
Request interface{}
Response interface{}
// Error as string, for backward compatibility.
// "" on no error.
Error string
// Full error dump, to be able to parse out full gRPC error code and message separately in a test.
FullError error
}{
Method: method,
Request: request,
Response: reply,
FullError: err,
}
if err != nil {
logMessage.Error = err.Error()
}
msg, _ := json.Marshal(logMessage)
call := MockCSICall{
json: string(msg),
}
json.Unmarshal(msg, &call)
klog.Infof("%s %s", grpcCallPrefix, string(msg))
// Trim gRPC service name, i.e. "/csi.v1.Identity/Probe" -> "Probe"
methodParts := strings.Split(call.Method, "/")
call.Method = methodParts[len(methodParts)-1]
c.Add(call)
}
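In the non-embedded case, GetCalls later parses these lines back out of the driver logs. A produced log line looks roughly like this (illustrative, abbreviated):

gRPCCall: {"Method":"/csi.v1.Identity/Probe","Request":{},"Response":{"ready":{"value":true}},"Error":"","FullError":null}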
var _ storageframework.TestDriver = &mockCSIDriver{}
@ -260,7 +391,7 @@ var _ storageframework.DynamicPVTestDriver = &mockCSIDriver{}
var _ storageframework.SnapshottableTestDriver = &mockCSIDriver{}
// InitMockCSIDriver returns a mockCSIDriver that implements TestDriver interface
func InitMockCSIDriver(driverOpts CSIMockDriverOpts) storageframework.TestDriver {
func InitMockCSIDriver(driverOpts CSIMockDriverOpts) MockCSITestDriver {
driverManifests := []string{
"test/e2e/testing-manifests/storage-csi/external-attacher/rbac.yaml",
"test/e2e/testing-manifests/storage-csi/external-provisioner/rbac.yaml",
@ -268,7 +399,11 @@ func InitMockCSIDriver(driverOpts CSIMockDriverOpts) storageframework.TestDriver
"test/e2e/testing-manifests/storage-csi/external-snapshotter/rbac.yaml",
"test/e2e/testing-manifests/storage-csi/mock/csi-mock-rbac.yaml",
"test/e2e/testing-manifests/storage-csi/mock/csi-storageclass.yaml",
"test/e2e/testing-manifests/storage-csi/mock/csi-mock-driver.yaml",
}
if driverOpts.Embedded {
driverManifests = append(driverManifests, "test/e2e/testing-manifests/storage-csi/mock/csi-mock-proxy.yaml")
} else {
driverManifests = append(driverManifests, "test/e2e/testing-manifests/storage-csi/mock/csi-mock-driver.yaml")
}
if driverOpts.RegisterDriver {
@ -309,10 +444,11 @@ func InitMockCSIDriver(driverOpts CSIMockDriverOpts) storageframework.TestDriver
attachable: !driverOpts.DisableAttach,
attachLimit: driverOpts.AttachLimit,
enableNodeExpansion: driverOpts.EnableNodeExpansion,
javascriptHooks: driverOpts.JavascriptHooks,
tokenRequests: driverOpts.TokenRequests,
requiresRepublish: driverOpts.RequiresRepublish,
fsGroupPolicy: driverOpts.FSGroupPolicy,
embedded: driverOpts.Embedded,
hooks: driverOpts.Hooks,
}
}
@ -340,62 +476,108 @@ func (m *mockCSIDriver) GetSnapshotClass(config *storageframework.PerTestConfig,
}
func (m *mockCSIDriver) PrepareTest(f *framework.Framework) (*storageframework.PerTestConfig, func()) {
m.clientSet = f.ClientSet
// Create secondary namespace which will be used for creating driver
driverNamespace := utils.CreateDriverNamespace(f)
driverns := driverNamespace.Name
m.driverNamespace = utils.CreateDriverNamespace(f)
driverns := m.driverNamespace.Name
testns := f.Namespace.Name
ginkgo.By("deploying csi mock driver")
cancelLogging := utils.StartPodLogs(f, driverNamespace)
if m.embedded {
ginkgo.By("deploying csi mock proxy")
} else {
ginkgo.By("deploying csi mock driver")
}
cancelLogging := utils.StartPodLogs(f, m.driverNamespace)
cs := f.ClientSet
// pods should be scheduled on the node
node, err := e2enode.GetRandomReadySchedulableNode(cs)
framework.ExpectNoError(err)
embeddedCleanup := func() {}
containerArgs := []string{}
if m.embedded {
// Run embedded CSI driver.
//
// For now we start exactly one instance which implements controller,
// node and identity services. It matches the one pod that we run
// inside the cluster. The name and namespace of that pod are deterministic,
// so we know what to connect to.
//
// Long-term we could also deploy one central controller and multiple
// node instances, with knowledge about provisioned volumes shared in
// this process.
podname := "csi-mockplugin-0"
containername := "mock"
ctx, cancel := context.WithCancel(context.Background())
serviceConfig := mockservice.Config{
DisableAttach: !m.attachable,
DriverName: "csi-mock-" + f.UniqueName,
AttachLimit: int64(m.attachLimit),
NodeExpansionRequired: m.enableNodeExpansion,
EnableTopology: m.enableTopology,
IO: proxy.PodDirIO{
F: f,
Namespace: m.driverNamespace.Name,
PodName: podname,
ContainerName: "busybox",
},
}
s := mockservice.New(serviceConfig)
servers := &mockdriver.CSIDriverServers{
Controller: s,
Identity: s,
Node: s,
}
m.embeddedCSIDriver = mockdriver.NewCSIDriver(servers)
l, err := proxy.Listen(ctx, f.ClientSet, f.ClientConfig(),
proxy.Addr{
Namespace: m.driverNamespace.Name,
PodName: podname,
ContainerName: containername,
Port: 9000,
},
)
framework.ExpectNoError(err, "start connecting to proxy pod")
err = m.embeddedCSIDriver.Start(l, m.interceptGRPC)
framework.ExpectNoError(err, "start mock driver")
embeddedCleanup = func() {
// Kill all goroutines and delete resources of the mock driver.
m.embeddedCSIDriver.Stop()
l.Close()
cancel()
}
} else {
// When using the mock driver inside the cluster it has to be reconfigured
// via command line parameters.
containerArgs = append(containerArgs, "--name=csi-mock-"+f.UniqueName)
if !m.attachable {
containerArgs = append(containerArgs, "--disable-attach")
}
if m.enableTopology {
containerArgs = append(containerArgs, "--enable-topology")
}
if m.attachLimit > 0 {
containerArgs = append(containerArgs, "--attach-limit", strconv.Itoa(m.attachLimit))
}
if m.enableNodeExpansion {
containerArgs = append(containerArgs, "--node-expand-required=true")
}
}
config := &storageframework.PerTestConfig{
Driver: m,
Prefix: "mock",
Framework: f,
ClientNodeSelection: e2epod.NodeSelection{Name: node.Name},
DriverNamespace: driverNamespace,
}
containerArgs := []string{"--name=csi-mock-" + f.UniqueName}
if !m.attachable {
containerArgs = append(containerArgs, "--disable-attach")
}
if m.enableTopology {
containerArgs = append(containerArgs, "--enable-topology")
}
if m.attachLimit > 0 {
containerArgs = append(containerArgs, "--attach-limit", strconv.Itoa(m.attachLimit))
}
if m.enableNodeExpansion {
containerArgs = append(containerArgs, "--node-expand-required=true")
}
// Create a config map with javascript hooks. Create it even when javascriptHooks
// are empty, so we can unconditionally add it to the mock pod.
const hooksConfigMapName = "mock-driver-hooks"
hooksYaml, err := yaml.Marshal(m.javascriptHooks)
framework.ExpectNoError(err)
hooks := &v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: hooksConfigMapName,
},
Data: map[string]string{
"hooks.yaml": string(hooksYaml),
},
}
_, err = f.ClientSet.CoreV1().ConfigMaps(driverns).Create(context.TODO(), hooks, metav1.CreateOptions{})
framework.ExpectNoError(err)
if len(m.javascriptHooks) > 0 {
containerArgs = append(containerArgs, "--hooks-file=/etc/hooks/hooks.yaml")
DriverNamespace: m.driverNamespace,
}
o := utils.PatchCSIOptions{
@ -416,7 +598,7 @@ func (m *mockCSIDriver) PrepareTest(f *framework.Framework) (*storageframework.P
RequiresRepublish: m.requiresRepublish,
FSGroupPolicy: m.fsGroupPolicy,
}
cleanup, err := utils.CreateFromManifests(f, driverNamespace, func(item interface{}) error {
cleanup, err := utils.CreateFromManifests(f, m.driverNamespace, func(item interface{}) error {
return utils.PatchCSIDeployment(f, o, item)
}, m.manifests...)
@ -424,7 +606,7 @@ func (m *mockCSIDriver) PrepareTest(f *framework.Framework) (*storageframework.P
framework.Failf("deploying csi mock driver: %v", err)
}
cleanupFunc := generateDriverCleanupFunc(
driverCleanupFunc := generateDriverCleanupFunc(
f,
"mock",
testns,
@ -432,9 +614,81 @@ func (m *mockCSIDriver) PrepareTest(f *framework.Framework) (*storageframework.P
cleanup,
cancelLogging)
cleanupFunc := func() {
embeddedCleanup()
driverCleanupFunc()
}
return config, cleanupFunc
}
func (m *mockCSIDriver) interceptGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
defer func() {
// Always log the call and its final result,
// regardless whether the result was from the real
// implementation or a hook.
m.calls.LogGRPC(info.FullMethod, req, resp, err)
}()
if m.hooks.Pre != nil {
resp, err = m.hooks.Pre(ctx, info.FullMethod, req)
if resp != nil || err != nil {
return
}
}
resp, err = handler(ctx, req)
if m.hooks.Post != nil {
resp, err = m.hooks.Post(ctx, info.FullMethod, req, resp, err)
}
return
}
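Post hooks can rewrite the reply or the error after the real handler ran. A sketch under the same assumptions as the Pre hook example above (the csi spec package import is assumed); illustrative only:

// ignoreUnpublishErrors returns Hooks whose Post callback swallows
// NodeUnpublishVolume failures and reports success instead.
func ignoreUnpublishErrors() Hooks {
    return Hooks{
        Post: func(ctx context.Context, method string, request, reply interface{}, err error) (interface{}, error) {
            if strings.HasSuffix(method, "/NodeUnpublishVolume") && err != nil {
                return &csi.NodeUnpublishVolumeResponse{}, nil
            }
            return reply, err
        },
    }
}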
func (m *mockCSIDriver) GetCalls() ([]MockCSICall, error) {
if m.embedded {
return m.calls.Get(), nil
}
if m.driverNamespace == nil {
return nil, errors.New("PrepareTest not called yet")
}
// Name of the CSI driver pod (it's in a StatefulSet with a stable name).
driverPodName := "csi-mockplugin-0"
// Name of the CSI driver container.
driverContainerName := "mock"
// Load logs of driver pod
log, err := e2epod.GetPodLogs(m.clientSet, m.driverNamespace.Name, driverPodName, driverContainerName)
if err != nil {
return nil, fmt.Errorf("could not load CSI driver logs: %s", err)
}
logLines := strings.Split(log, "\n")
var calls []MockCSICall
for _, line := range logLines {
index := strings.Index(line, grpcCallPrefix)
if index == -1 {
continue
}
line = line[index+len(grpcCallPrefix):]
call := MockCSICall{
json: line,
}
err := json.Unmarshal([]byte(line), &call)
if err != nil {
framework.Logf("Could not parse CSI driver log line %q: %s", line, err)
continue
}
// Trim gRPC service name, i.e. "/csi.v1.Identity/Probe" -> "Probe"
methodParts := strings.Split(call.Method, "/")
call.Method = methodParts[len(methodParts)-1]
calls = append(calls, call)
}
return calls, nil
}
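Tests typically filter the returned calls by method. A small helper sketch; the name is illustrative:

// countCalls returns how often the given CSI method was observed.
// Method names already have the gRPC service prefix trimmed,
// e.g. "NodeStageVolume".
func countCalls(driver MockCSITestDriver, method string) (int, error) {
    calls, err := driver.GetCalls()
    if err != nil {
        return 0, err
    }
    count := 0
    for _, call := range calls {
        if call.Method == method {
            count++
        }
    }
    return count, nil
}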
// gce-pd
type gcePDCSIDriver struct {
driverInfo storageframework.DriverInfo


@ -0,0 +1,82 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"fmt"
"k8s.io/kubernetes/test/e2e/framework"
"k8s.io/kubernetes/test/e2e/storage/drivers/csi-test/mock/service"
)
type PodDirIO struct {
F *framework.Framework
Namespace string
PodName string
ContainerName string
}
var _ service.DirIO = PodDirIO{}
func (p PodDirIO) DirExists(path string) (bool, error) {
stdout, stderr, err := p.execute([]string{
"sh",
"-c",
fmt.Sprintf("if ! [ -e '%s' ]; then echo notexist; elif [ -d '%s' ]; then echo dir; else echo nodir; fi", path, path),
})
if err != nil {
return false, fmt.Errorf("error executing dir test commands: stderr=%q, %v", stderr, err)
}
switch stdout {
case "notexist":
return false, nil
case "nodir":
return false, fmt.Errorf("%s: not a directory", path)
case "dir":
return true, nil
default:
return false, fmt.Errorf("unexpected output from dir test commands: %q", stdout)
}
}
func (p PodDirIO) Mkdir(path string) error {
_, stderr, err := p.execute([]string{"mkdir", path})
if err != nil {
return fmt.Errorf("mkdir %q: stderr=%q, %v", path, stderr, err)
}
return nil
}
func (p PodDirIO) RemoveAll(path string) error {
_, stderr, err := p.execute([]string{"rm", "-rf", path})
if err != nil {
return fmt.Errorf("rm -rf %q: stderr=%q, %v", path, stderr, err)
}
return nil
}
func (p PodDirIO) execute(command []string) (string, string, error) {
return p.F.ExecWithOptions(framework.ExecOptions{
Command: command,
Namespace: p.Namespace,
PodName: p.PodName,
ContainerName: p.ContainerName,
CaptureStdout: true,
CaptureStderr: true,
Quiet: true,
})
}
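This is how the embedded mock service reaches into the cluster for its directory checks. A hypothetical direct use, with the namespace assumed to match the deployed proxy pod and the pod and container names taken from the driver setup above:

io := proxy.PodDirIO{
    F:             f,               // the *framework.Framework of the running test
    Namespace:     driverNamespace, // assumed, e.g. m.driverNamespace.Name
    PodName:       "csi-mockplugin-0",
    ContainerName: "busybox",
}
exists, err := io.DirExists("/var/lib/kubelet/plugins/csi-mock")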


@ -0,0 +1,336 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
"k8s.io/klog/v2"
)
// Maximum number of forwarded connections. In practice we don't
// need more than one per sidecar and kubelet. Keeping this reasonably
// small ensures that we don't establish connections through the apiserver
// and the remote kernel which then aren't needed.
const maxConcurrentConnections = 10
// Listen creates a listener which returns new connections whenever someone connects
// to a socat or mock driver proxy instance running inside the given pod.
//
// socat must be started with "<listen>,fork TCP-LISTEN:<port>,reuseport"
// for this to work. "<listen>" can be anything that accepts connections,
// for example "UNIX-LISTEN:/csi/csi.sock". In this mode, socat will
// accept exactly one connection on the given port for each connection
// that socat itself accepted.
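// As a concrete (illustrative) invocation matching that pattern, with
// port 9000 as in the proxy manifest below:
//
//	socat UNIX-LISTEN:/csi/csi.sock,fork TCP-LISTEN:9000,reuseport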
//
// Listening stops when the context is done or Close() is called.
func Listen(ctx context.Context, clientset kubernetes.Interface, restConfig *rest.Config, addr Addr) (net.Listener, error) {
// We connect through port forwarding. Strictly
// speaking this is overkill because we don't need a local
// port. But this way we can reuse client-go/tools/portforward
// instead of having to replicate handleConnection
// in our own code.
restClient := clientset.CoreV1().RESTClient()
if restConfig.GroupVersion == nil {
restConfig.GroupVersion = &schema.GroupVersion{}
}
if restConfig.NegotiatedSerializer == nil {
restConfig.NegotiatedSerializer = scheme.Codecs
}
// The setup code around the actual portforward is from
// https://github.com/kubernetes/kubernetes/blob/c652ffbe4a29143623a1aaec39f745575f7e43ad/staging/src/k8s.io/kubectl/pkg/cmd/portforward/portforward.go
req := restClient.Post().
Resource("pods").
Namespace(addr.Namespace).
Name(addr.PodName).
SubResource("portforward")
transport, upgrader, err := spdy.RoundTripperFor(restConfig)
if err != nil {
return nil, fmt.Errorf("create round tripper: %v", err)
}
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())
prefix := fmt.Sprintf("port forwarding for %s", addr)
ctx, cancel := context.WithCancel(ctx)
l := &listener{
connections: make(chan *connection),
ctx: ctx,
cancel: cancel,
addr: addr,
}
var connectionsCreated, connectionsClosed int32
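// Only the polling loop below increments connectionsCreated, while
// connection.Close increments connectionsClosed atomically from other
// goroutines; that is why only the closed counter is loaded atomically here.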
runForwarding := func() {
klog.V(2).Infof("%s: starting connection polling", prefix)
defer klog.V(2).Infof("%s: connection polling ended", prefix)
// This delay determines how quickly we notice when someone has
// connected inside the cluster. With socat, we cannot make this too small
// because otherwise we get many rejected connections. With the mock
// driver as proxy, that doesn't happen as long as we don't ask for
// too many concurrent connections, because the mock driver keeps the
// listening port open at all times and the Linux kernel automatically
// accepts our connection requests.
tryConnect := time.NewTicker(100 * time.Millisecond)
defer tryConnect.Stop()
for {
select {
case <-ctx.Done():
return
case <-tryConnect.C:
currentClosed := atomic.LoadInt32(&connectionsClosed)
openConnections := connectionsCreated - currentClosed
if openConnections >= maxConcurrentConnections {
break
}
klog.V(5).Infof("%s: trying to create a new connection #%d, %d open", prefix, connectionsCreated, openConnections)
stream, err := dial(ctx, fmt.Sprintf("%s #%d", prefix, connectionsCreated), dialer, addr.Port)
if err != nil {
klog.Errorf("%s: no connection: %v", prefix, err)
break
}
// Make the connection available to Accept below.
klog.V(5).Infof("%s: created a new connection #%d", prefix, connectionsCreated)
l.connections <- &connection{
stream: stream,
addr: addr,
counter: connectionsCreated,
closed: &connectionsClosed,
}
connectionsCreated++
}
}
}
// Portforwarding and polling for connections run in the background.
go func() {
for {
running := false
pod, err := clientset.CoreV1().Pods(addr.Namespace).Get(ctx, addr.PodName, metav1.GetOptions{})
if err != nil {
klog.V(5).Infof("checking for container %q in pod %s/%s: %v", addr.ContainerName, addr.Namespace, addr.PodName, err)
}
for i, status := range pod.Status.ContainerStatuses {
if pod.Spec.Containers[i].Name == addr.ContainerName &&
status.State.Running != nil {
running = true
break
}
}
if running {
klog.V(2).Infof("container %q in pod %s/%s is running", addr.ContainerName, addr.Namespace, addr.PodName)
runForwarding()
}
select {
case <-ctx.Done():
return
// Sleep a bit before restarting. This is
// where we potentially wait for the pod to
// start.
case <-time.After(1 * time.Second):
}
}
}()
return l, nil
}
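// A minimal usage sketch (hypothetical; clientset, restConfig and the gRPC
// server setup are assumed): serve a locally running gRPC service, such as
// the embedded CSI mock driver, over connections initiated inside the pod.
//
//	l, err := Listen(ctx, clientset, restConfig, Addr{
//		Namespace:     "default",
//		PodName:       "csi-mockplugin-0",
//		ContainerName: "mock",
//		Port:          9000,
//	})
//	if err != nil {
//		return err
//	}
//	defer l.Close()
//	server := grpc.NewServer()
//	// register the CSI identity/controller/node services on server here
//	go server.Serve(l)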
// Addr contains all relevant parameters for a certain port in a pod.
// The container must be running before connections are attempted.
type Addr struct {
Namespace, PodName, ContainerName string
Port int
}
var _ net.Addr = Addr{}
func (a Addr) Network() string {
return "port-forwarding"
}
func (a Addr) String() string {
return fmt.Sprintf("%s/%s:%d", a.Namespace, a.PodName, a.Port)
}
type stream struct {
httpstream.Stream
streamConn httpstream.Connection
}
func dial(ctx context.Context, prefix string, dialer httpstream.Dialer, port int) (s *stream, finalErr error) {
streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
return nil, fmt.Errorf("dialer failed: %v", err)
}
requestID := "1"
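// The port-forward protocol pairs the error and data streams created below
// through this request ID header; because each dial creates its own SPDY
// connection with a single stream pair, a constant ID is sufficient.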
defer func() {
if finalErr != nil {
streamConn.Close()
}
}()
// create error stream
headers := http.Header{}
headers.Set(v1.StreamType, v1.StreamTypeError)
headers.Set(v1.PortHeader, fmt.Sprintf("%d", port))
headers.Set(v1.PortForwardRequestIDHeader, requestID)
// We're not writing to this stream, just reading an error message from it.
// This happens asynchronously.
errorStream, err := streamConn.CreateStream(headers)
if err != nil {
return nil, fmt.Errorf("error creating error stream: %v", err)
}
errorStream.Close()
go func() {
message, err := ioutil.ReadAll(errorStream)
switch {
case err != nil:
klog.Errorf("%s: error reading from error stream: %v", prefix, err)
case len(message) > 0:
klog.Errorf("%s: an error occurred connecting to the remote port: %v", prefix, string(message))
}
}()
// create data stream
headers.Set(v1.StreamType, v1.StreamTypeData)
dataStream, err := streamConn.CreateStream(headers)
if err != nil {
return nil, fmt.Errorf("error creating data stream: %v", err)
}
return &stream{
Stream: dataStream,
streamConn: streamConn,
}, nil
}
func (s *stream) Close() {
s.Stream.Close()
s.streamConn.Close()
}
type listener struct {
addr Addr
connections chan *connection
ctx context.Context
cancel func()
}
var _ net.Listener = &listener{}
func (l *listener) Close() error {
klog.V(5).Infof("forward listener for %s: closing", l.addr)
l.cancel()
return nil
}
func (l *listener) Accept() (net.Conn, error) {
select {
case <-l.ctx.Done():
return nil, errors.New("listening was stopped")
case c := <-l.connections:
klog.V(5).Infof("forward listener for %s: got a new connection #%d", l.addr, c.counter)
return c, nil
}
}
type connection struct {
stream *stream
addr Addr
counter int32
closed *int32
mutex sync.Mutex
}
var _ net.Conn = &connection{}
func (c *connection) LocalAddr() net.Addr {
return c.addr
}
func (c *connection) RemoteAddr() net.Addr {
return c.addr
}
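// The underlying port-forward stream offers no deadline support, so the
// deadline setters below are deliberate no-ops.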
func (c *connection) SetDeadline(t time.Time) error {
return nil
}
func (c *connection) SetReadDeadline(t time.Time) error {
return nil
}
func (c *connection) SetWriteDeadline(t time.Time) error {
return nil
}
func (c *connection) Read(b []byte) (int, error) {
n, err := c.stream.Read(b)
if errors.Is(err, io.EOF) {
klog.V(5).Infof("forward connection #%d for %s: remote side closed the stream", c.counter, c.addr)
}
return n, err
}
func (c *connection) Write(b []byte) (int, error) {
n, err := c.stream.Write(b)
if errors.Is(err, io.EOF) {
klog.V(5).Infof("forward connection #%d for %s: remote side closed the stream", c.counter, c.addr)
}
return n, err
}
func (c *connection) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed != nil {
// Do the logging and book-keeping only once. The function itself may be called more than once.
klog.V(5).Infof("forward connection #%d for %s: closing our side", c.counter, c.addr)
atomic.AddInt32(c.closed, 1)
c.closed = nil
}
c.stream.Close()
return nil
}
func (l *listener) Addr() net.Addr {
return l.addr
}

View File

@ -24,6 +24,9 @@ spec:
- "-v=5"
# Needed for fsGroup support.
- "--default-fstype=ext4"
# We don't need much concurrency and having many goroutines
# makes the klog.Fatal output during shutdown very long.
- "--worker-threads=5"
env:
- name: ADDRESS
value: /csi/csi.sock
@ -50,7 +53,7 @@ spec:
- mountPath: /registration
name: registration-dir
- name: mock
image: k8s.gcr.io/sig-storage/mock-driver:v4.0.2
image: k8s.gcr.io/sig-storage/mock-driver:v4.1.0
args:
- "--name=mock.storage.k8s.io"
- "-v=3" # enabled the gRPC call logging
@ -68,10 +71,9 @@ spec:
- mountPath: /csi
name: socket-dir
- mountPath: /var/lib/kubelet/pods
mountPropagation: Bidirectional
name: mountpoint-dir
- name: hooks
mountPath: /etc/hooks
name: kubelet-pods-dir
- mountPath: /var/lib/kubelet/plugins/kubernetes.io/csi
name: kubelet-csi-dir
volumes:
- hostPath:
path: /var/lib/kubelet/plugins/csi-mock
@ -79,12 +81,15 @@ spec:
name: socket-dir
- hostPath:
path: /var/lib/kubelet/pods
type: Directory
# mock driver doesn't make mounts and therefore doesn't need mount propagation.
# mountPropagation: Bidirectional
name: kubelet-pods-dir
- hostPath:
path: /var/lib/kubelet/plugins/kubernetes.io/csi
type: DirectoryOrCreate
name: mountpoint-dir
name: kubelet-csi-dir
- hostPath:
path: /var/lib/kubelet/plugins_registry
type: Directory
name: registration-dir
- name: hooks
configMap:
name: mock-driver-hooks

View File

@ -0,0 +1,105 @@
kind: StatefulSet
apiVersion: apps/v1
metadata:
name: csi-mockplugin
spec:
selector:
matchLabels:
app: csi-mockplugin
replicas: 1
template:
metadata:
labels:
app: csi-mockplugin
spec:
serviceAccountName: csi-mock
containers:
- name: csi-provisioner
image: k8s.gcr.io/sig-storage/csi-provisioner:v2.1.0
args:
- "--csi-address=$(ADDRESS)"
# Topology support is needed for the pod rescheduling test
# ("storage capacity" in csi_mock_volume.go).
- "--feature-gates=Topology=true"
- "-v=5"
- "--timeout=1m"
# Needed for fsGroup support.
- "--default-fstype=ext4"
# We don't need much concurrency and having many goroutines
# makes the klog.Fatal output during shutdown very long.
- "--worker-threads=5"
env:
- name: ADDRESS
value: /csi/csi.sock
volumeMounts:
- mountPath: /csi
name: socket-dir
- name: driver-registrar
image: k8s.gcr.io/sig-storage/csi-node-driver-registrar:v2.1.0
args:
- --v=5
- --csi-address=/csi/csi.sock
- --kubelet-registration-path=/var/lib/kubelet/plugins/csi-mock/csi.sock
- --timeout=1m
env:
- name: KUBE_NODE_NAME
valueFrom:
fieldRef:
apiVersion: v1
fieldPath: spec.nodeName
volumeMounts:
- mountPath: /csi
name: socket-dir
- mountPath: /registration
name: registration-dir
- name: mock
image: k8s.gcr.io/sig-storage/mock-driver:v4.1.0
args:
# -v=3 shows when connections get established. Higher log levels print information about
# transferred bytes, but cannot print message content (no gRPC parsing), so this is usually
# not interesting.
- -v=3
- -proxy-endpoint=tcp://:9000
env:
- name: CSI_ENDPOINT
value: /csi/csi.sock
ports:
- containerPort: 9000
name: socat
volumeMounts:
- mountPath: /csi
name: socket-dir
# The busybox container is needed for running shell commands which
# test for directories or create them. It needs additional privileges
# for that.
- name: busybox
image: k8s.gcr.io/busybox
securityContext:
privileged: true
command:
- sleep
- "100000"
volumeMounts:
- mountPath: /var/lib/kubelet/pods
name: kubelet-pods-dir
- mountPath: /var/lib/kubelet/plugins/kubernetes.io/csi
name: kubelet-csi-dir
volumes:
- hostPath:
path: /var/lib/kubelet/plugins/csi-mock
type: DirectoryOrCreate
name: socket-dir
- hostPath:
path: /var/lib/kubelet/pods
type: Directory
# mock driver doesn't make mounts and therefore doesn't need mount propagation.
# mountPropagation: Bidirectional
name: kubelet-pods-dir
- hostPath:
path: /var/lib/kubelet/plugins/kubernetes.io/csi
type: DirectoryOrCreate
name: kubelet-csi-dir
- hostPath:
path: /var/lib/kubelet/plugins_registry
type: Directory
name: registration-dir

1
vendor/modules.txt vendored
View File

@ -496,6 +496,7 @@ github.com/golang/groupcache/lru
# github.com/golang/mock => github.com/golang/mock v1.4.4
github.com/golang/mock/gomock
# github.com/golang/protobuf v1.4.3 => github.com/golang/protobuf v1.4.3
## explicit
# github.com/golang/protobuf => github.com/golang/protobuf v1.4.3
github.com/golang/protobuf/jsonpb
github.com/golang/protobuf/proto